@@ -947,6 +947,143 @@ def _match_type_against_type(self, left, other_type, subst, view):
947947 % (type (left ), type (other_type ))
948948 )
949949
950+ def _match_signatures (self , sig , other_sig , subst , view ):
951+ """Checks if signature `sig` is compatible with signature `other_sig`."""
952+ # Track unconsumed required parameters in sig that could accept keyword
953+ # arguments (i.e., non-positional-only parameters without defaults).
954+ unconsumed_kw_params = {
955+ * sig .param_names [sig .posonly_count :],
956+ * sig .kwonly_params ,
957+ } - sig .defaults .keys ()
958+
959+ new_substs = []
960+
961+ # Track the index in sig.param_names as we consume positional parameters.
962+ sig_pos_index = 0
963+
964+ # For normal (positional-or-keyword) parameters in other_sig, sig must be
965+ # prepared to receive these as either positional or keyword arguments.
966+ for other_param_idx , other_param in enumerate (other_sig .param_names ):
967+ other_type = other_sig .annotations .get (
968+ other_param , self .ctx .convert .unsolvable
969+ )
970+ # Check if this is a normal parameter (not positional-only) in other_sig.
971+ other_is_normal = other_param_idx >= other_sig .posonly_count
972+
973+ # Check that there is a corresponding positional parameter in sig.
974+ # Note that the name does not have to agree.
975+ if sig_pos_index < len (sig .param_names ):
976+ sig_param = sig .param_names [sig_pos_index ]
977+ other_is_optional = other_param in other_sig .defaults
978+ sig_is_required = sig_param not in sig .defaults
979+ if other_is_optional and sig_is_required :
980+ return None
981+ sig_type = sig .annotations .get (sig_param , self .ctx .convert .unsolvable )
982+ sig_pos_index += 1
983+ unconsumed_kw_params .discard (sig_param )
984+ elif sig .varargs_name :
985+ varargs_type = sig .annotations .get (
986+ sig .varargs_name , self .ctx .convert .unsolvable
987+ )
988+ sig_type = self .ctx .convert .get_element_type (varargs_type )
989+ if sig_type is None :
990+ sig_type = self .ctx .convert .unsolvable
991+ else :
992+ return None
993+ new_subst = self ._instantiate_and_match (other_type , sig_type , subst , view )
994+ if new_subst is None :
995+ return None
996+ new_substs .append (new_subst )
997+
998+ # Normal parameters can also be passed by keyword, so check that there
999+ # is a corresponding keyword parameter in sig.
1000+ if other_is_normal :
1001+ # Check if sig can accept this parameter by keyword
1002+ if other_param in sig .kwonly_params or other_param in sig .param_names :
1003+ if other_param in sig .param_names :
1004+ kw_param_idx = sig .param_names .index (other_param )
1005+ else :
1006+ kw_param_idx = len (sig .param_names ) + sig .kwonly_params .index (
1007+ other_param
1008+ )
1009+
1010+ # Check if it's the same parameter we just matched positionally
1011+ pos_param_idx = sig_pos_index - 1 if sig_pos_index > 0 else - 1
1012+ if pos_param_idx >= 0 and pos_param_idx < len (sig .param_names ):
1013+ if kw_param_idx != pos_param_idx :
1014+ # This has to be either the same parameter as the positional one
1015+ # (i.e. they are at the same index in sig), or they both have to
1016+ # be optional (because only one can be passed).
1017+ pos_param = sig .param_names [pos_param_idx ]
1018+ if (
1019+ pos_param not in sig .defaults
1020+ or other_param not in sig .defaults
1021+ ):
1022+ return None
1023+ # Also check type compatibility for the keyword parameter
1024+ kw_type = sig .annotations .get (
1025+ other_param , self .ctx .convert .unsolvable
1026+ )
1027+ new_subst = self ._instantiate_and_match (
1028+ other_type , kw_type , subst , view
1029+ )
1030+ if new_subst is None :
1031+ return None
1032+ elif pos_param_idx == - 1 :
1033+ # Matched via varargs - no additional check needed
1034+ pass
1035+ elif not sig .kwargs_name :
1036+ # sig has no **kwargs to accept this keyword argument
1037+ return None
1038+
1039+ # sig cannot have unconsumed positional-only required parameters.
1040+ for i in range (sig_pos_index , sig .posonly_count ):
1041+ if sig .param_names [i ] not in sig .defaults :
1042+ return None
1043+
1044+ # Match other_sig's keyword-only parameters against sig.
1045+ for kwonly_param in other_sig .kwonly_params :
1046+ other_type = other_sig .annotations .get (
1047+ kwonly_param , self .ctx .convert .unsolvable
1048+ )
1049+ if kwonly_param in sig .kwonly_params or kwonly_param in sig .param_names :
1050+ other_is_optional = kwonly_param in other_sig .defaults
1051+ sig_is_required = kwonly_param not in sig .defaults
1052+ if other_is_optional and sig_is_required :
1053+ return None
1054+ sig_type = sig .annotations .get (
1055+ kwonly_param , self .ctx .convert .unsolvable
1056+ )
1057+ unconsumed_kw_params .discard (kwonly_param )
1058+ elif sig .kwargs_name :
1059+ kwargs_type = sig .annotations .get (
1060+ sig .kwargs_name , self .ctx .convert .unsolvable
1061+ )
1062+ sig_type = self .ctx .convert .get_element_type (kwargs_type )
1063+ if sig_type is None :
1064+ sig_type = self .ctx .convert .unsolvable
1065+ else :
1066+ return None
1067+ new_subst = self ._instantiate_and_match (other_type , sig_type , subst , view )
1068+ if new_subst is None :
1069+ return None
1070+ new_substs .append (new_subst )
1071+
1072+ # sig cannot have unconsumed required parameters (those not covered by
1073+ # *args or **kwargs).
1074+ if not sig .varargs_name and not sig .kwargs_name :
1075+ if unconsumed_kw_params :
1076+ return None
1077+
1078+ sig_ret = sig .annotations .get ("return" , self .ctx .convert .unsolvable )
1079+ other_ret = other_sig .annotations .get ("return" , self .ctx .convert .unsolvable )
1080+ new_subst = self ._instantiate_and_match (sig_ret , other_ret , subst , view )
1081+ if new_subst is None :
1082+ return None
1083+ new_substs .append (new_subst )
1084+
1085+ return self ._merge_substs (subst , new_substs )
1086+
9501087 def _match_type_against_callback_protocol (
9511088 self , left , other_type , subst , view
9521089 ):
@@ -961,17 +1098,20 @@ def _match_type_against_callback_protocol(
9611098 ):
9621099 return None
9631100 new_substs = []
1101+ left_signatures = self ._get_signatures (left , subst , view )
9641102 for expected_method in method_var .data :
9651103 signatures = function .get_signatures (expected_method )
9661104 for sig in signatures :
9671105 sig = sig .drop_first_parameter () # drop `self`
968- expected_callable = self .ctx .pytd_convert .signature_to_callable (sig )
969- new_subst = self ._match_type_against_type (
970- left , expected_callable , subst , view
971- )
972- if new_subst is not None :
973- # For a set of overloaded signatures, only one needs to match.
974- new_substs .append (new_subst )
1106+ found = False
1107+ for left_sig in left_signatures :
1108+ new_subst = self ._match_signatures (left_sig , sig , subst , view )
1109+ if new_subst is not None :
1110+ # For a set of overloaded signatures, only one needs to match.
1111+ found = True
1112+ new_substs .append (new_subst )
1113+ break
1114+ if found :
9751115 break
9761116 else :
9771117 # Every method_var binding must have a matching signature.
0 commit comments