@@ -1716,33 +1716,9 @@ def check_callable_call(
17161716 callee = callee .copy_modified (ret_type = fresh_ret_type )
17171717
17181718 if callee .is_generic ():
1719- need_refresh = any (
1720- isinstance ( v , ( ParamSpecType , TypeVarTupleType )) for v in callee . variables
1719+ callee , formal_to_actual = self . adjust_generic_callable_params_mapping (
1720+ callee , args , arg_kinds , arg_names , formal_to_actual , context
17211721 )
1722- callee = freshen_function_type_vars (callee )
1723- callee = self .infer_function_type_arguments_using_context (callee , context )
1724- if need_refresh :
1725- # Argument kinds etc. may have changed due to
1726- # ParamSpec or TypeVarTuple variables being replaced with an arbitrary
1727- # number of arguments; recalculate actual-to-formal map
1728- formal_to_actual = map_actuals_to_formals (
1729- arg_kinds ,
1730- arg_names ,
1731- callee .arg_kinds ,
1732- callee .arg_names ,
1733- lambda i : self .accept (args [i ]),
1734- )
1735- callee = self .infer_function_type_arguments (
1736- callee , args , arg_kinds , arg_names , formal_to_actual , need_refresh , context
1737- )
1738- if need_refresh :
1739- formal_to_actual = map_actuals_to_formals (
1740- arg_kinds ,
1741- arg_names ,
1742- callee .arg_kinds ,
1743- callee .arg_names ,
1744- lambda i : self .accept (args [i ]),
1745- )
17461722
17471723 param_spec = callee .param_spec ()
17481724 if (
@@ -2633,7 +2609,7 @@ def check_overload_call(
26332609 arg_types = self .infer_arg_types_in_empty_context (args )
26342610 # Step 1: Filter call targets to remove ones where the argument counts don't match
26352611 plausible_targets = self .plausible_overload_call_targets (
2636- arg_types , arg_kinds , arg_names , callee
2612+ args , arg_types , arg_kinds , arg_names , callee , context
26372613 )
26382614
26392615 # Step 2: If the arguments contain a union, we try performing union math first,
@@ -2751,12 +2727,52 @@ def check_overload_call(
27512727 self .chk .fail (message_registry .TOO_MANY_UNION_COMBINATIONS , context )
27522728 return result
27532729
2730+ def adjust_generic_callable_params_mapping (
2731+ self ,
2732+ callee : CallableType ,
2733+ args : list [Expression ],
2734+ arg_kinds : list [ArgKind ],
2735+ arg_names : Sequence [str | None ] | None ,
2736+ formal_to_actual : list [list [int ]],
2737+ context : Context ,
2738+ ) -> tuple [CallableType , list [list [int ]]]:
2739+ need_refresh = any (
2740+ isinstance (v , (ParamSpecType , TypeVarTupleType )) for v in callee .variables
2741+ )
2742+ callee = freshen_function_type_vars (callee )
2743+ callee = self .infer_function_type_arguments_using_context (callee , context )
2744+ if need_refresh :
2745+ # Argument kinds etc. may have changed due to
2746+ # ParamSpec or TypeVarTuple variables being replaced with an arbitrary
2747+ # number of arguments; recalculate actual-to-formal map
2748+ formal_to_actual = map_actuals_to_formals (
2749+ arg_kinds ,
2750+ arg_names ,
2751+ callee .arg_kinds ,
2752+ callee .arg_names ,
2753+ lambda i : self .accept (args [i ]),
2754+ )
2755+ callee = self .infer_function_type_arguments (
2756+ callee , args , arg_kinds , arg_names , formal_to_actual , need_refresh , context
2757+ )
2758+ if need_refresh :
2759+ formal_to_actual = map_actuals_to_formals (
2760+ arg_kinds ,
2761+ arg_names ,
2762+ callee .arg_kinds ,
2763+ callee .arg_names ,
2764+ lambda i : self .accept (args [i ]),
2765+ )
2766+ return callee , formal_to_actual
2767+
27542768 def plausible_overload_call_targets (
27552769 self ,
2770+ args : list [Expression ],
27562771 arg_types : list [Type ],
27572772 arg_kinds : list [ArgKind ],
27582773 arg_names : Sequence [str | None ] | None ,
27592774 overload : Overloaded ,
2775+ context : Context ,
27602776 ) -> list [CallableType ]:
27612777 """Returns all overload call targets that having matching argument counts.
27622778
@@ -2790,6 +2806,10 @@ def has_shape(typ: Type) -> bool:
27902806 formal_to_actual = map_actuals_to_formals (
27912807 arg_kinds , arg_names , typ .arg_kinds , typ .arg_names , lambda i : arg_types [i ]
27922808 )
2809+ if typ .is_generic ():
2810+ typ , formal_to_actual = self .adjust_generic_callable_params_mapping (
2811+ typ , args , arg_kinds , arg_names , formal_to_actual , context
2812+ )
27932813
27942814 with self .msg .filter_errors ():
27952815 if self .check_argument_count (
0 commit comments