Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 147 additions & 7 deletions pytype/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,143 @@ def _match_type_against_type(self, left, other_type, subst, view):
% (type(left), type(other_type))
)

def _match_signatures(self, sig, other_sig, subst, view):
"""Checks if signature `sig` is compatible with signature `other_sig`."""
# Track unconsumed required parameters in sig that could accept keyword
# arguments (i.e., non-positional-only parameters without defaults).
unconsumed_kw_params = {
*sig.param_names[sig.posonly_count :],
*sig.kwonly_params,
} - sig.defaults.keys()

new_substs = []

# Track the index in sig.param_names as we consume positional parameters.
sig_pos_index = 0

# For normal (positional-or-keyword) parameters in other_sig, sig must be
# prepared to receive these as either positional or keyword arguments.
for other_param_idx, other_param in enumerate(other_sig.param_names):
other_type = other_sig.annotations.get(
other_param, self.ctx.convert.unsolvable
)
# Check if this is a normal parameter (not positional-only) in other_sig.
other_is_normal = other_param_idx >= other_sig.posonly_count

# Check that there is a corresponding positional parameter in sig.
# Note that the name does not have to agree.
if sig_pos_index < len(sig.param_names):
sig_param = sig.param_names[sig_pos_index]
other_is_optional = other_param in other_sig.defaults
sig_is_required = sig_param not in sig.defaults
if other_is_optional and sig_is_required:
return None
sig_type = sig.annotations.get(sig_param, self.ctx.convert.unsolvable)
sig_pos_index += 1
unconsumed_kw_params.discard(sig_param)
elif sig.varargs_name:
varargs_type = sig.annotations.get(
sig.varargs_name, self.ctx.convert.unsolvable
)
sig_type = self.ctx.convert.get_element_type(varargs_type)
if sig_type is None:
sig_type = self.ctx.convert.unsolvable
else:
return None
new_subst = self._instantiate_and_match(other_type, sig_type, subst, view)
if new_subst is None:
return None
new_substs.append(new_subst)

# Normal parameters can also be passed by keyword, so check that there
# is a corresponding keyword parameter in sig.
if other_is_normal:
# Check if sig can accept this parameter by keyword
if other_param in sig.kwonly_params or other_param in sig.param_names:
if other_param in sig.param_names:
kw_param_idx = sig.param_names.index(other_param)
else:
kw_param_idx = len(sig.param_names) + sig.kwonly_params.index(
other_param
)

# Check if it's the same parameter we just matched positionally
pos_param_idx = sig_pos_index - 1 if sig_pos_index > 0 else -1
if pos_param_idx >= 0 and pos_param_idx < len(sig.param_names):
if kw_param_idx != pos_param_idx:
# This has to be either the same parameter as the positional one
# (i.e. they are at the same index in sig), or they both have to
# be optional (because only one can be passed).
pos_param = sig.param_names[pos_param_idx]
if (
pos_param not in sig.defaults
or other_param not in sig.defaults
):
return None
# Also check type compatibility for the keyword parameter
kw_type = sig.annotations.get(
other_param, self.ctx.convert.unsolvable
)
new_subst = self._instantiate_and_match(
other_type, kw_type, subst, view
)
if new_subst is None:
return None
elif pos_param_idx == -1:
# Matched via varargs - no additional check needed
pass
elif not sig.kwargs_name:
# sig has no **kwargs to accept this keyword argument
return None

# sig cannot have unconsumed positional-only required parameters.
for i in range(sig_pos_index, sig.posonly_count):
if sig.param_names[i] not in sig.defaults:
return None

# Match other_sig's keyword-only parameters against sig.
for kwonly_param in other_sig.kwonly_params:
other_type = other_sig.annotations.get(
kwonly_param, self.ctx.convert.unsolvable
)
if kwonly_param in sig.kwonly_params or kwonly_param in sig.param_names:
other_is_optional = kwonly_param in other_sig.defaults
sig_is_required = kwonly_param not in sig.defaults
if other_is_optional and sig_is_required:
return None
sig_type = sig.annotations.get(
kwonly_param, self.ctx.convert.unsolvable
)
unconsumed_kw_params.discard(kwonly_param)
elif sig.kwargs_name:
kwargs_type = sig.annotations.get(
sig.kwargs_name, self.ctx.convert.unsolvable
)
sig_type = self.ctx.convert.get_element_type(kwargs_type)
if sig_type is None:
sig_type = self.ctx.convert.unsolvable
else:
return None
new_subst = self._instantiate_and_match(other_type, sig_type, subst, view)
if new_subst is None:
return None
new_substs.append(new_subst)

# sig cannot have unconsumed required parameters (those not covered by
# *args or **kwargs).
if not sig.varargs_name and not sig.kwargs_name:
if unconsumed_kw_params:
return None

sig_ret = sig.annotations.get("return", self.ctx.convert.unsolvable)
other_ret = other_sig.annotations.get("return", self.ctx.convert.unsolvable)
new_subst = self._instantiate_and_match(sig_ret, other_ret, subst, view)
if new_subst is None:
return None
new_substs.append(new_subst)

return self._merge_substs(subst, new_substs)

def _match_type_against_callback_protocol(
self, left, other_type, subst, view
):
Expand All @@ -961,17 +1098,20 @@ def _match_type_against_callback_protocol(
):
return None
new_substs = []
left_signatures = self._get_signatures(left, subst, view)
for expected_method in method_var.data:
signatures = function.get_signatures(expected_method)
for sig in signatures:
sig = sig.drop_first_parameter() # drop `self`
expected_callable = self.ctx.pytd_convert.signature_to_callable(sig)
new_subst = self._match_type_against_type(
left, expected_callable, subst, view
)
if new_subst is not None:
# For a set of overloaded signatures, only one needs to match.
new_substs.append(new_subst)
found = False
for left_sig in left_signatures:
new_subst = self._match_signatures(left_sig, sig, subst, view)
if new_subst is not None:
# For a set of overloaded signatures, only one needs to match.
found = True
new_substs.append(new_subst)
break
if found:
break
else:
# Every method_var binding must have a matching signature.
Expand Down
9 changes: 9 additions & 0 deletions pytype/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,15 @@ py_test(
.test_base
)

py_test(
NAME
test_signatures
SRCS
test_signatures.py
DEPS
.test_base
)

py_test(
NAME
test_slice1
Expand Down
22 changes: 22 additions & 0 deletions pytype/tests/test_protocols2.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,28 @@ def f3(x: int) -> str:
pythonpath=[d.path],
)

def test_callback_protocol_kw_only(self):
self.CheckWithErrors("""
from typing import Protocol
class Foo(Protocol):
def __call__(self, *, x: str) -> str:
return x
def f1() -> str:
return ''
def f2(x: str) -> str:
return ''
def f3(x: int) -> str:
return str(x)
def accepts_foo(f: Foo):
pass
accepts_foo(f1) # wrong-arg-types
accepts_foo(f2)
accepts_foo(f3) # wrong-arg-types
""")

def test_class_matches_callback_protocol(self):
self.CheckWithErrors("""
from typing_extensions import Protocol
Expand Down
Loading
Loading