121121from mypy .subtypes import (
122122 find_member ,
123123 is_equivalent ,
124+ is_proper_subtype ,
124125 is_same_type ,
125126 is_subtype ,
126127 non_method_protocol_members ,
@@ -2093,7 +2094,7 @@ def infer_function_type_arguments_using_context(
20932094 )
20942095
20952096 def _infer_constraints_from_context (
2096- self , callee : CallableType , error_context : Context
2097+ self , callee : CallableType , error_context : Context , erase : bool = True
20972098 ) -> list [Constraint ]:
20982099 """Unify callable return type to type context to infer type vars.
20992100
@@ -2109,7 +2110,10 @@ def _infer_constraints_from_context(
21092110 # and they are not potential results; thus we replace them with the
21102111 # special ErasedType type. On the other hand, class type variables are
21112112 # valid results.
2112- erased_ctx = replace_meta_vars (ctx , ErasedType ())
2113+ if erase :
2114+ erased_ctx = replace_meta_vars (ctx , ErasedType ())
2115+ else :
2116+ erased_ctx = ctx
21132117 ret_type = callee .ret_type
21142118 if is_overlapping_none (ret_type ) and is_overlapping_none (ctx ):
21152119 # If both the context and the return type are optional, unwrap the optional,
@@ -2168,10 +2172,16 @@ def _infer_constraints_from_context(
21682172 def _filter_args (self , args : list [Type | None ]) -> list [Type | None ]:
21692173 new_args : list [Type | None ] = []
21702174 for arg in args :
2171- if has_uninhabited_component ( arg ) or has_erased_component ( arg ) :
2175+ if arg is None :
21722176 new_args .append (None )
2177+ continue
21732178 else :
2179+ arg = replace_meta_vars (arg , ErasedType ())
21742180 new_args .append (arg )
2181+ # if has_erased_component(arg) or has_uninhabited_component(arg):
2182+ # new_args.append(None)
2183+ # else:
2184+ # new_args.append(arg)
21752185 return new_args
21762186
21772187 def infer_function_type_arguments (
@@ -2224,18 +2234,6 @@ def infer_function_type_arguments(
22242234 new_inferred_args = None
22252235
22262236 if True : # NEW CODE
2227- extra_constraints = self ._infer_constraints_from_context (callee_type , context )
2228-
2229- # outer_solution
2230- _outer_solution = solve_constraints (
2231- callee_type .variables ,
2232- extra_constraints ,
2233- strict = self .chk .in_checked_function (),
2234- allow_polymorphic = False ,
2235- )
2236- outer_solution = (self ._filter_args (_outer_solution [0 ]), _outer_solution [1 ])
2237-
2238- # inner solution
22392237 constraints = infer_constraints_for_callable (
22402238 callee_type ,
22412239 pass1_args ,
@@ -2244,38 +2242,84 @@ def infer_function_type_arguments(
22442242 formal_to_actual ,
22452243 context = self .argument_infer_context (),
22462244 )
2247- inner_solution = solve_constraints (
2245+
2246+ extra_constraints = self ._infer_constraints_from_context (
2247+ callee_type , context , erase = False
2248+ )
2249+ erased_constraints = self ._infer_constraints_from_context (
2250+ callee_type , context , erase = True
2251+ )
2252+
2253+ _outer_solution = solve_constraints (
2254+ callee_type .variables ,
2255+ extra_constraints ,
2256+ strict = self .chk .in_checked_function (),
2257+ allow_polymorphic = False ,
2258+ )
2259+
2260+ _inner_solution = solve_constraints (
22482261 callee_type .variables ,
22492262 constraints ,
22502263 strict = self .chk .in_checked_function (),
22512264 allow_polymorphic = False ,
22522265 )
2266+ # NOTE: The order of constraints is important here!
2267+ # solve(outer + inner) and solve(inner + outer) may yield different results.
2268+ _joint_solution = solve_constraints (
2269+ callee_type .variables ,
2270+ constraints + extra_constraints ,
2271+ strict = self .chk .in_checked_function (),
2272+ allow_polymorphic = False ,
2273+ )
22532274
2254- joint_solution = solve_constraints (
2275+ _reverse_joint_solution = solve_constraints (
22552276 callee_type .variables ,
22562277 extra_constraints + constraints ,
22572278 strict = self .chk .in_checked_function (),
22582279 allow_polymorphic = False ,
22592280 )
22602281
2261- # check if we can use the joint solution, otherwise fallback to outer_solution
2262- for var1 , var2 in zip (
2263- outer_solution [0 ], joint_solution [0 ]
2264- ): # tuple[Type | None, Type | None]
2265- if var2 is None and var1 is not None :
2266- # using both constraints did not find a solution for this variable
2267- # so we fallback to outer_solution, apply the solution, and then recompute the inner part
2268- use_joint = False
2269- break
2270- else :
2271- use_joint = True
2282+ _erased_outer_solution = solve_constraints (
2283+ callee_type .variables ,
2284+ erased_constraints ,
2285+ strict = self .chk .in_checked_function (),
2286+ allow_polymorphic = False ,
2287+ )
2288+
2289+ _erased_joint_solution = solve_constraints (
2290+ callee_type .variables ,
2291+ constraints + erased_constraints ,
2292+ strict = self .chk .in_checked_function (),
2293+ allow_polymorphic = False ,
2294+ )
2295+
2296+ _erased_reverse_joint_solution = solve_constraints (
2297+ callee_type .variables ,
2298+ erased_constraints + constraints ,
2299+ strict = self .chk .in_checked_function (),
2300+ allow_polymorphic = False ,
2301+ )
2302+
2303+ # Now, we select the solution to use.
2304+ # Note: Since joint uses both outer and inner constraints,
2305+ # and solution discovered by joint is also a solution for outer and inner.
2306+ # therefore, we can pick either inner or outer as a substitute for joint,
2307+ # and then try to solve again using only the inner constraints.
2308+ # joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1])
2309+ # reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1])
2310+ outer_solution = (self ._filter_args (_outer_solution [0 ]), _outer_solution [1 ])
2311+ inner_solution = (self ._filter_args (_inner_solution [0 ]), _inner_solution [1 ])
2312+ joint_solution = _joint_solution
2313+ reverse_joint_solution = _reverse_joint_solution
2314+
2315+ target_solution = _erased_reverse_joint_solution
22722316
22732317 use_joint = True
22742318 use_outer = True
22752319 use_inner = True
2276-
2320+ # check if we can use the joint solution, otherwise fallback to outer_solution
22772321 for outer_tp , inner_tp , joint_tp in zip (
2278- outer_solution [0 ], inner_solution [0 ], joint_solution [0 ]
2322+ outer_solution [0 ], inner_solution [0 ], target_solution [0 ]
22792323 ):
22802324 if joint_tp is None and outer_tp is not None :
22812325 use_joint = False
@@ -2287,8 +2331,55 @@ def infer_function_type_arguments(
22872331 if has_erased_component (inner_tp ):
22882332 use_inner = False
22892333
2334+ combined_solution = []
2335+ for outer_tp , joint_tp in zip (outer_solution [0 ], target_solution [0 ]):
2336+ if (
2337+ outer_tp is not None
2338+ and joint_tp is not None
2339+ and is_proper_subtype (outer_tp , joint_tp )
2340+ ):
2341+ # If outer is a subtype of joint, we can use joint.
2342+ combined_solution .append (outer_tp )
2343+ else :
2344+ # Otherwise, we use joint.
2345+ combined_solution .append (joint_tp )
2346+
2347+ _num = arg_pass_nums
2348+ _c0 = constraints
2349+ _c1 = extra_constraints
2350+ _c2 = erased_constraints
2351+
2352+ _x0 = _outer_solution [0 ]
2353+ _x2 = _inner_solution [0 ]
2354+ _x3 = _joint_solution [0 ]
2355+ _x4 = _reverse_joint_solution [0 ]
2356+
2357+ _e0 = _erased_outer_solution [0 ]
2358+ _e2 = _erased_joint_solution [0 ]
2359+ _e3 = _erased_reverse_joint_solution [0 ]
2360+
2361+ _s1 = outer_solution [0 ]
2362+ _s2 = inner_solution [0 ]
2363+ _s3 = joint_solution [0 ]
2364+ _s4 = reverse_joint_solution [0 ]
2365+ _s5 = combined_solution
2366+
2367+ _u0 = use_inner , use_outer , use_joint
2368+
2369+ # if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2370+ if all (
2371+ (joint_tp is None and outer_tp is None )
2372+ or (
2373+ (joint_tp is not None and outer_tp is not None )
2374+ and is_subtype (outer_tp , joint_tp )
2375+ )
2376+ for outer_tp , joint_tp in zip (outer_solution [0 ], target_solution [0 ])
2377+ ):
2378+ use_joint = False
2379+ use_outer = True
2380+
22902381 if use_joint :
2291- new_inferred_args = joint_solution [0 ]
2382+ new_inferred_args = target_solution [0 ]
22922383 # inferred_args = [
22932384 # # Usually, joint_tp <: outer_tp (since superset of constraints),
22942385 # # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]`
0 commit comments