Skip to content

Commit c519711

Browse files
superbobrycopybara-github
authored andcommitted
Added generalized signature matching
This is necessary to properly support callback protocols. Note that the implementation does not currently support typing.ParamSpec. This will be added in a follow up. PiperOrigin-RevId: 827866351
1 parent 9e87d0f commit c519711

File tree

4 files changed

+677
-7
lines changed

4 files changed

+677
-7
lines changed

pytype/matcher.py

Lines changed: 147 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,143 @@ def _match_type_against_type(self, left, other_type, subst, view):
947947
% (type(left), type(other_type))
948948
)
949949

950+
def _match_signatures(self, sig, other_sig, subst, view):
951+
"""Checks if signature `sig` is compatible with signature `other_sig`."""
952+
# Track unconsumed required parameters in sig that could accept keyword
953+
# arguments (i.e., non-positional-only parameters without defaults).
954+
unconsumed_kw_params = {
955+
*sig.param_names[sig.posonly_count :],
956+
*sig.kwonly_params,
957+
} - sig.defaults.keys()
958+
959+
new_substs = []
960+
961+
# Track the index in sig.param_names as we consume positional parameters.
962+
sig_pos_index = 0
963+
964+
# For normal (positional-or-keyword) parameters in other_sig, sig must be
965+
# prepared to receive these as either positional or keyword arguments.
966+
for other_param_idx, other_param in enumerate(other_sig.param_names):
967+
other_type = other_sig.annotations.get(
968+
other_param, self.ctx.convert.unsolvable
969+
)
970+
# Check if this is a normal parameter (not positional-only) in other_sig.
971+
other_is_normal = other_param_idx >= other_sig.posonly_count
972+
973+
# Check that there is a corresponding positional parameter in sig.
974+
# Note that the name does not have to agree.
975+
if sig_pos_index < len(sig.param_names):
976+
sig_param = sig.param_names[sig_pos_index]
977+
other_is_optional = other_param in other_sig.defaults
978+
sig_is_required = sig_param not in sig.defaults
979+
if other_is_optional and sig_is_required:
980+
return None
981+
sig_type = sig.annotations.get(sig_param, self.ctx.convert.unsolvable)
982+
sig_pos_index += 1
983+
unconsumed_kw_params.discard(sig_param)
984+
elif sig.varargs_name:
985+
varargs_type = sig.annotations.get(
986+
sig.varargs_name, self.ctx.convert.unsolvable
987+
)
988+
sig_type = self.ctx.convert.get_element_type(varargs_type)
989+
if sig_type is None:
990+
sig_type = self.ctx.convert.unsolvable
991+
else:
992+
return None
993+
new_subst = self._instantiate_and_match(other_type, sig_type, subst, view)
994+
if new_subst is None:
995+
return None
996+
new_substs.append(new_subst)
997+
998+
# Normal parameters can also be passed by keyword, so check that there
999+
# is a corresponding keyword parameter in sig.
1000+
if other_is_normal:
1001+
# Check if sig can accept this parameter by keyword
1002+
if other_param in sig.kwonly_params or other_param in sig.param_names:
1003+
if other_param in sig.param_names:
1004+
kw_param_idx = sig.param_names.index(other_param)
1005+
else:
1006+
kw_param_idx = len(sig.param_names) + sig.kwonly_params.index(
1007+
other_param
1008+
)
1009+
1010+
# Check if it's the same parameter we just matched positionally
1011+
pos_param_idx = sig_pos_index - 1 if sig_pos_index > 0 else -1
1012+
if pos_param_idx >= 0 and pos_param_idx < len(sig.param_names):
1013+
if kw_param_idx != pos_param_idx:
1014+
# This has to be either the same parameter as the positional one
1015+
# (i.e. they are at the same index in sig), or they both have to
1016+
# be optional (because only one can be passed).
1017+
pos_param = sig.param_names[pos_param_idx]
1018+
if (
1019+
pos_param not in sig.defaults
1020+
or other_param not in sig.defaults
1021+
):
1022+
return None
1023+
# Also check type compatibility for the keyword parameter
1024+
kw_type = sig.annotations.get(
1025+
other_param, self.ctx.convert.unsolvable
1026+
)
1027+
new_subst = self._instantiate_and_match(
1028+
other_type, kw_type, subst, view
1029+
)
1030+
if new_subst is None:
1031+
return None
1032+
elif pos_param_idx == -1:
1033+
# Matched via varargs - no additional check needed
1034+
pass
1035+
elif not sig.kwargs_name:
1036+
# sig has no **kwargs to accept this keyword argument
1037+
return None
1038+
1039+
# sig cannot have unconsumed positional-only required parameters.
1040+
for i in range(sig_pos_index, sig.posonly_count):
1041+
if sig.param_names[i] not in sig.defaults:
1042+
return None
1043+
1044+
# Match other_sig's keyword-only parameters against sig.
1045+
for kwonly_param in other_sig.kwonly_params:
1046+
other_type = other_sig.annotations.get(
1047+
kwonly_param, self.ctx.convert.unsolvable
1048+
)
1049+
if kwonly_param in sig.kwonly_params or kwonly_param in sig.param_names:
1050+
other_is_optional = kwonly_param in other_sig.defaults
1051+
sig_is_required = kwonly_param not in sig.defaults
1052+
if other_is_optional and sig_is_required:
1053+
return None
1054+
sig_type = sig.annotations.get(
1055+
kwonly_param, self.ctx.convert.unsolvable
1056+
)
1057+
unconsumed_kw_params.discard(kwonly_param)
1058+
elif sig.kwargs_name:
1059+
kwargs_type = sig.annotations.get(
1060+
sig.kwargs_name, self.ctx.convert.unsolvable
1061+
)
1062+
sig_type = self.ctx.convert.get_element_type(kwargs_type)
1063+
if sig_type is None:
1064+
sig_type = self.ctx.convert.unsolvable
1065+
else:
1066+
return None
1067+
new_subst = self._instantiate_and_match(other_type, sig_type, subst, view)
1068+
if new_subst is None:
1069+
return None
1070+
new_substs.append(new_subst)
1071+
1072+
# sig cannot have unconsumed required parameters (those not covered by
1073+
# *args or **kwargs).
1074+
if not sig.varargs_name and not sig.kwargs_name:
1075+
if unconsumed_kw_params:
1076+
return None
1077+
1078+
sig_ret = sig.annotations.get("return", self.ctx.convert.unsolvable)
1079+
other_ret = other_sig.annotations.get("return", self.ctx.convert.unsolvable)
1080+
new_subst = self._instantiate_and_match(sig_ret, other_ret, subst, view)
1081+
if new_subst is None:
1082+
return None
1083+
new_substs.append(new_subst)
1084+
1085+
return self._merge_substs(subst, new_substs)
1086+
9501087
def _match_type_against_callback_protocol(
9511088
self, left, other_type, subst, view
9521089
):
@@ -961,17 +1098,20 @@ def _match_type_against_callback_protocol(
9611098
):
9621099
return None
9631100
new_substs = []
1101+
left_signatures = self._get_signatures(left, subst, view)
9641102
for expected_method in method_var.data:
9651103
signatures = function.get_signatures(expected_method)
9661104
for sig in signatures:
9671105
sig = sig.drop_first_parameter() # drop `self`
968-
expected_callable = self.ctx.pytd_convert.signature_to_callable(sig)
969-
new_subst = self._match_type_against_type(
970-
left, expected_callable, subst, view
971-
)
972-
if new_subst is not None:
973-
# For a set of overloaded signatures, only one needs to match.
974-
new_substs.append(new_subst)
1106+
found = False
1107+
for left_sig in left_signatures:
1108+
new_subst = self._match_signatures(left_sig, sig, subst, view)
1109+
if new_subst is not None:
1110+
# For a set of overloaded signatures, only one needs to match.
1111+
found = True
1112+
new_substs.append(new_subst)
1113+
break
1114+
if found:
9751115
break
9761116
else:
9771117
# Every method_var binding must have a matching signature.

pytype/tests/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,3 +1305,12 @@ py_test(
13051305
.test_base
13061306
pytype.pytd.pytd
13071307
)
1308+
1309+
py_test(
1310+
NAME
1311+
test_signatures
1312+
SRCS
1313+
test_signatures.py
1314+
DEPS
1315+
.test_base
1316+
)

pytype/tests/test_protocols2.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,28 @@ def f3(x: int) -> str:
676676
pythonpath=[d.path],
677677
)
678678

679+
def test_callback_protocol_kw_only(self):
680+
self.CheckWithErrors("""
681+
from typing import Protocol
682+
class Foo(Protocol):
683+
def __call__(self, *, x: str) -> str:
684+
return x
685+
686+
def f1() -> str:
687+
return ''
688+
def f2(x: str) -> str:
689+
return ''
690+
def f3(x: int) -> str:
691+
return str(x)
692+
693+
def accepts_foo(f: Foo):
694+
pass
695+
696+
accepts_foo(f1) # wrong-arg-types
697+
accepts_foo(f2)
698+
accepts_foo(f3) # wrong-arg-types
699+
""")
700+
679701
def test_class_matches_callback_protocol(self):
680702
self.CheckWithErrors("""
681703
from typing_extensions import Protocol

0 commit comments

Comments
 (0)