diff --git a/pytype/matcher.py b/pytype/matcher.py index 578a46376..230e43751 100644 --- a/pytype/matcher.py +++ b/pytype/matcher.py @@ -947,6 +947,143 @@ def _match_type_against_type(self, left, other_type, subst, view): % (type(left), type(other_type)) ) + def _match_signatures(self, sig, other_sig, subst, view): + """Checks if signature `sig` is compatible with signature `other_sig`.""" + # Track unconsumed required parameters in sig that could accept keyword + # arguments (i.e., non-positional-only parameters without defaults). + unconsumed_kw_params = { + *sig.param_names[sig.posonly_count :], + *sig.kwonly_params, + } - sig.defaults.keys() + + new_substs = [] + + # Track the index in sig.param_names as we consume positional parameters. + sig_pos_index = 0 + + # For normal (positional-or-keyword) parameters in other_sig, sig must be + # prepared to receive these as either positional or keyword arguments. + for other_param_idx, other_param in enumerate(other_sig.param_names): + other_type = other_sig.annotations.get( + other_param, self.ctx.convert.unsolvable + ) + # Check if this is a normal parameter (not positional-only) in other_sig. + other_is_normal = other_param_idx >= other_sig.posonly_count + + # Check that there is a corresponding positional parameter in sig. + # Note that the name does not have to agree. + if sig_pos_index < len(sig.param_names): + sig_param = sig.param_names[sig_pos_index] + other_is_optional = other_param in other_sig.defaults + sig_is_required = sig_param not in sig.defaults + if other_is_optional and sig_is_required: + return None + sig_type = sig.annotations.get(sig_param, self.ctx.convert.unsolvable) + sig_pos_index += 1 + unconsumed_kw_params.discard(sig_param) + elif sig.varargs_name: + varargs_type = sig.annotations.get( + sig.varargs_name, self.ctx.convert.unsolvable + ) + sig_type = self.ctx.convert.get_element_type(varargs_type) + if sig_type is None: + sig_type = self.ctx.convert.unsolvable + else: + return None + new_subst = self._instantiate_and_match(other_type, sig_type, subst, view) + if new_subst is None: + return None + new_substs.append(new_subst) + + # Normal parameters can also be passed by keyword, so check that there + # is a corresponding keyword parameter in sig. + if other_is_normal: + # Check if sig can accept this parameter by keyword + if other_param in sig.kwonly_params or other_param in sig.param_names: + if other_param in sig.param_names: + kw_param_idx = sig.param_names.index(other_param) + else: + kw_param_idx = len(sig.param_names) + sig.kwonly_params.index( + other_param + ) + + # Check if it's the same parameter we just matched positionally + pos_param_idx = sig_pos_index - 1 if sig_pos_index > 0 else -1 + if pos_param_idx >= 0 and pos_param_idx < len(sig.param_names): + if kw_param_idx != pos_param_idx: + # This has to be either the same parameter as the positional one + # (i.e. they are at the same index in sig), or they both have to + # be optional (because only one can be passed). + pos_param = sig.param_names[pos_param_idx] + if ( + pos_param not in sig.defaults + or other_param not in sig.defaults + ): + return None + # Also check type compatibility for the keyword parameter + kw_type = sig.annotations.get( + other_param, self.ctx.convert.unsolvable + ) + new_subst = self._instantiate_and_match( + other_type, kw_type, subst, view + ) + if new_subst is None: + return None + elif pos_param_idx == -1: + # Matched via varargs - no additional check needed + pass + elif not sig.kwargs_name: + # sig has no **kwargs to accept this keyword argument + return None + + # sig cannot have unconsumed positional-only required parameters. + for i in range(sig_pos_index, sig.posonly_count): + if sig.param_names[i] not in sig.defaults: + return None + + # Match other_sig's keyword-only parameters against sig. + for kwonly_param in other_sig.kwonly_params: + other_type = other_sig.annotations.get( + kwonly_param, self.ctx.convert.unsolvable + ) + if kwonly_param in sig.kwonly_params or kwonly_param in sig.param_names: + other_is_optional = kwonly_param in other_sig.defaults + sig_is_required = kwonly_param not in sig.defaults + if other_is_optional and sig_is_required: + return None + sig_type = sig.annotations.get( + kwonly_param, self.ctx.convert.unsolvable + ) + unconsumed_kw_params.discard(kwonly_param) + elif sig.kwargs_name: + kwargs_type = sig.annotations.get( + sig.kwargs_name, self.ctx.convert.unsolvable + ) + sig_type = self.ctx.convert.get_element_type(kwargs_type) + if sig_type is None: + sig_type = self.ctx.convert.unsolvable + else: + return None + new_subst = self._instantiate_and_match(other_type, sig_type, subst, view) + if new_subst is None: + return None + new_substs.append(new_subst) + + # sig cannot have unconsumed required parameters (those not covered by + # *args or **kwargs). + if not sig.varargs_name and not sig.kwargs_name: + if unconsumed_kw_params: + return None + + sig_ret = sig.annotations.get("return", self.ctx.convert.unsolvable) + other_ret = other_sig.annotations.get("return", self.ctx.convert.unsolvable) + new_subst = self._instantiate_and_match(sig_ret, other_ret, subst, view) + if new_subst is None: + return None + new_substs.append(new_subst) + + return self._merge_substs(subst, new_substs) + def _match_type_against_callback_protocol( self, left, other_type, subst, view ): @@ -961,17 +1098,20 @@ def _match_type_against_callback_protocol( ): return None new_substs = [] + left_signatures = self._get_signatures(left, subst, view) for expected_method in method_var.data: signatures = function.get_signatures(expected_method) for sig in signatures: sig = sig.drop_first_parameter() # drop `self` - expected_callable = self.ctx.pytd_convert.signature_to_callable(sig) - new_subst = self._match_type_against_type( - left, expected_callable, subst, view - ) - if new_subst is not None: - # For a set of overloaded signatures, only one needs to match. - new_substs.append(new_subst) + found = False + for left_sig in left_signatures: + new_subst = self._match_signatures(left_sig, sig, subst, view) + if new_subst is not None: + # For a set of overloaded signatures, only one needs to match. + found = True + new_substs.append(new_subst) + break + if found: break else: # Every method_var binding must have a matching signature. diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt index 4e79dac9c..483878fbd 100644 --- a/pytype/tests/CMakeLists.txt +++ b/pytype/tests/CMakeLists.txt @@ -242,6 +242,15 @@ py_test( .test_base ) +py_test( + NAME + test_signatures + SRCS + test_signatures.py + DEPS + .test_base +) + py_test( NAME test_slice1 diff --git a/pytype/tests/test_protocols2.py b/pytype/tests/test_protocols2.py index 44b477a28..83d2ed95f 100644 --- a/pytype/tests/test_protocols2.py +++ b/pytype/tests/test_protocols2.py @@ -676,6 +676,28 @@ def f3(x: int) -> str: pythonpath=[d.path], ) + def test_callback_protocol_kw_only(self): + self.CheckWithErrors(""" + from typing import Protocol + class Foo(Protocol): + def __call__(self, *, x: str) -> str: + return x + + def f1() -> str: + return '' + def f2(x: str) -> str: + return '' + def f3(x: int) -> str: + return str(x) + + def accepts_foo(f: Foo): + pass + + accepts_foo(f1) # wrong-arg-types + accepts_foo(f2) + accepts_foo(f3) # wrong-arg-types + """) + def test_class_matches_callback_protocol(self): self.CheckWithErrors(""" from typing_extensions import Protocol diff --git a/pytype/tests/test_signatures.py b/pytype/tests/test_signatures.py new file mode 100644 index 000000000..dd1f3719d --- /dev/null +++ b/pytype/tests/test_signatures.py @@ -0,0 +1,499 @@ +"""Tests for matching against signatures.""" + +from pytype.tests import test_base + + +class SignatureTest(test_base.BaseTest): # pylint: disable=missing-docstring + + def test_no_params(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self): ... + def f1(): ... + def f2(x: int): ... + _: P = f1 + _: P = f2 # annotation-type-mismatch + """) + + def test_params_are_contravariant(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: bool): ... + def f1(x: int): ... + def f2(x: bool): ... + def f3(x: str): ... + _: P = f1 + _: P = f2 + _: P = f3 # annotation-type-mismatch + """) + + def test_return_type_is_covariant(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self) -> int: ... + def f1() -> object: ... + def f2() -> bool: return True + _: P = f1 # annotation-type-mismatch + _: P = f2 + """) + + # Each overload on the LHS must have at least one matching overload on + # the RHS. + + def test_overloads1(self): + self.CheckWithErrors(""" + from typing import Protocol, overload + class P(Protocol): + def __call__(self) -> int: ... + @overload + def f() -> int: ... + @overload + def f(x: int) -> float: ... + _: P = f + """) + + def test_overloads2(self): + self.CheckWithErrors(""" + from typing import Protocol, overload + class P(Protocol): + @overload + def __call__(self) -> int: ... + @overload + def __call__(self, x: bool) -> int: ... + def __call__(self, *args, **kwargs): raise NotImplementedError + + @overload + def f1() -> int: ... + @overload + def f1(x: int) -> bool: ... + def f2(x: int) -> float: return 42.0 + + _: P = f1 + _: P = f2 # annotation-type-mismatch + """) + + # Cases involving different parameter kinds, all including positional-only + # parameters. For the tests with positive expectations we vary everything + # that's significant (e.g., names, types). For the tests with negative + # expectations we avoid that so we can control what we're testing. + + def test_positional_only_optional_to_required(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int, /) -> object: ... + def f(x: object, y: object = ..., /) -> int: return 42 + _: P = f + """) + + def test_positional_only_required_to_optional(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int = 42, /) -> int: ... + def f(x: int, y: int, /) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_positional_only_optional_to_none(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /) -> object: ... + def f(x: object, y: object = ..., /) -> int: return 42 + _: P = f + """) + + def test_positional_only_unexpected_argument(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int = 42, /) -> int: ... + def f(x: int, /) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_positional_only_to_positional_keyword(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int, /) -> object: ... + def f(x: object, /, y: object = ...) -> int: return 42 + _: P = f + """) + + def test_positional_only_called_by_keyword(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /, y: int = 42) -> int: ... + def f(x: int, y: int, /) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_positional_keyword_optional_not_passed(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /) -> object: ... + def f(x: object, /, y: object = ...) -> int: return 42 + _: P = f + """) + + def test_positional_keyword_unexpected_argument(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /, y: int = 42) -> int: ... + def f(x: int, /) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_args_accepts_positional_only(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int, /) -> object: ... + def f(x: object, /, *args: object) -> int: return 42 + _: P = f + """) + + def test_args_not_required(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int = 42, /) -> object: ... + def f(x: object, /, *args: object) -> int: return 42 + _: P = f + """) + + def test_args_not_required_single_param(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /) -> object: ... + def f(x: object, /, *args: object) -> int: return 42 + _: P = f + """) + + def test_no_args_cant_accept_arbitrary_positional(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /, *args: int) -> int: ... + def f(x: int, y: int, /) -> object: ... + _: P = f # annotation-type-mismatch + """) + + def test_args_variance(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /, *args: int) -> object: ... + def f(x: object, /, *args: object) -> int: return 42 + _: P = f + """) + + def test_args_variance_negative(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: object, /, *args: object) -> int: ... + def f(x: int, /, *args: int) -> object: ... + _: P = f # annotation-type-mismatch + """) + + def test_args_with_extra_positional(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int, /, *args: int) -> object: ... + def f(x: object, /, *args: object) -> int: return 42 + _: P = f + """) + + # Dually for different parameter kinds, all including keyword-only + # parameters. + + def test_keyword_only_optional_to_required(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, y: int, x: int) -> object: ... + def f(*, x: object, y: object = ...) -> int: return 42 + _: P = f + """) + + def test_keyword_only_required_to_optional(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int, y: int = 42) -> int: ... + def f(*, x: int, y: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_keyword_only_optional_to_none(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int) -> object: ... + def f(*, x: object, y: object = ...) -> int: return 42 + _: P = f + """) + + def test_keyword_only_unexpected_argument(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int, y: int = 42) -> int: ... + def f(*, x: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_keyword_only_to_positional_keyword(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, y: int, x: int) -> object: ... + def f(x: object = ..., *, y: object) -> int: return 42 + _: P = f + """) + + def test_keyword_only_called_positionally(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int = 42, *, y: int) -> int: ... + def f(*, x: int, y: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_positional_keyword_to_keyword_only(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, y: int) -> object: ... + def f(x: object = ..., *, y: object) -> int: return 42 + _: P = f + """) + + def test_positional_keyword_unexpected_keyword_only(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int = 42, *, y: int) -> int: ... + def f(*, x: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_kwargs_accepts_keyword_only(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int, y: int) -> object: ... + def f(*, x: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_kwargs_accepts_optional_keyword_only(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int, y: int = 42) -> object: ... + def f(x: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_kwargs_not_required(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int) -> object: ... + def f(x: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_no_kwargs_cant_accept_arbitrary_keyword(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, *, x: int, **kwargs: int) -> int: ... + def f(*, x: int, y: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_kwargs_variance(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, **kwargs: int) -> object: ... + def f(x: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_kwargs_with_extra_keyword_only(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, *, y: int, **kwargs: int) -> object: ... + def f(x: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_args_kwargs_variance(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /, *args: int, **kwargs: int) -> object: ... + def f(x: object, /, *args: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_args_kwargs_with_keyword_only(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, /, *args: int, y: int, **kwargs: int) -> object: ... + def f(x: object, /, *args: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_kwargs_variance_negative(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: object, **kwargs: object) -> int: ... + def f(x: int, **kwargs: int) -> object: ... + _: P = f # annotation-type-mismatch + """) + + def test_kwargs_with_keyword_only_variance_negative(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: object, **kwargs: object) -> int: ... + def f(x: int, *, y: int, **kwargs: int) -> object: ... + _: P = f # annotation-type-mismatch + """) + + def test_args_kwargs_variance_negative(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: object, /, *args: object, **kwargs: object) -> int: ... + def f(x: int, /, *args: int, **kwargs: int) -> object: ... + _: P = f # annotation-type-mismatch + """) + + def test_args_kwargs_keyword_only_variance_negative(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: object, /, *args: object, **kwargs: object) -> int: ... + def f(x: int, /, *args: int, y: int, **kwargs: int) -> object: ... + _: P = f # annotation-type-mismatch + """) + + # Cases involving only normal parameters. + + def test_normal_params_optional_to_required(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int) -> object: ... + def f(x: object, y: object = ...) -> int: return 42 + _: P = f + """) + + def test_normal_params_required_to_optional(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int = 42) -> int: ... + def f(x: int, y: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_normal_params_optional_to_none(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int) -> object: ... + def f(x: object, y: object = ...) -> int: return 42 + _: P = f + """) + + def test_normal_params_unexpected_argument(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int = 42) -> int: ... + def f(x: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_normal_params_with_args_kwargs(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int) -> object: ... + def f(x: object, *args: object, **kwargs: object) -> int: return 42 + _: P = f + """) + + def test_normal_params_with_args_mismatch(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int) -> int: ... + def f(x: int, *args: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_normal_params_with_kwargs_mismatch(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int, y: int) -> int: ... + def f(x: int, **kwargs: int) -> float: return 3.14 + _: P = f # annotation-type-mismatch + """) + + # Weird cases where a parameter in the supertype can be accepted by + # either a positional-only or keyword-only parameter in the subtype. + # Both have to be optional because not both will be passed. + + def test_weird_positional_or_keyword_both_optional(self): + self.Check(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int) -> object: ... + def f(y: object = ..., /, *, x: object = ...) -> int: return 42 + _: P = f + """) + + def test_weird_positional_only_required(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int) -> int: ... + def f(y: int, /, *, x: int = 42) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + def test_weird_keyword_only_required(self): + self.CheckWithErrors(""" + from typing import Protocol + class P(Protocol): + def __call__(self, x: int) -> int: ... + def f(y: int = 42, /, *, x: int) -> int: return 42 + _: P = f # annotation-type-mismatch + """) + + +if __name__ == "__main__": + test_base.main()