@@ -2134,6 +2134,14 @@ def _infer_constraints_from_context(
21342134 # variables in an expression are inferred at the same time.
21352135 # (And this is hard, also we need to be careful with lambdas that require
21362136 # two passes.)
2137+ # ret_as_union = make_simplified_union([ret_type])
2138+ # erased_ctx_as_union = make_simplified_union([ctx])
2139+ # if isinstance(ret_as_union, UnionType) and isinstance(erased_ctx_as_union, UnionType):
2140+ # new_ret = [val for val in ret_as_union.items if val not in erased_ctx_as_union.items]
2141+ # new_ctx = [val for val in erased_ctx_as_union.items if val not in ret_as_union.items]
2142+ # ret_type = make_simplified_union(new_ret)
2143+ # erased_ctx = make_simplified_union(new_ctx)
2144+
21372145 proper_ret = get_proper_type (ret_type )
21382146 if (
21392147 isinstance (proper_ret , TypeVarType )
@@ -2184,6 +2192,33 @@ def _filter_args(self, args: list[Type | None]) -> list[Type | None]:
21842192 # new_args.append(arg)
21852193 return new_args
21862194
2195+ def intersect_solutions (self , sol1 : list [Type | None ], sol2 : list [Type | None ]):
2196+ # first, ensure that the None-patterns agree
2197+ assert len (sol1 ) == len (sol2 )
2198+
2199+ virtual_vars = []
2200+ constraints = []
2201+
2202+ for i , (tp1 , tp2 ) in enumerate (zip (sol1 , sol2 )):
2203+ new_id = TypeVarId .new (- 1 )
2204+ name = f"V{ i } "
2205+ new_tvar = TypeVarType (
2206+ name ,
2207+ name ,
2208+ new_id ,
2209+ values = [],
2210+ upper_bound = self .object_type (),
2211+ default = AnyType (TypeOfAny .from_omitted_generics ),
2212+ )
2213+ virtual_vars .append (new_tvar )
2214+ if tp1 is not None :
2215+ c1 = Constraint (new_tvar , SUBTYPE_OF , tp1 )
2216+ constraints .append (c1 )
2217+ if tp2 is not None :
2218+ c2 = Constraint (new_tvar , SUBTYPE_OF , tp2 )
2219+ constraints .append (c2 )
2220+ return virtual_vars , constraints
2221+
21872222 def infer_function_type_arguments (
21882223 self ,
21892224 callee_type : CallableType ,
@@ -2311,9 +2346,37 @@ def infer_function_type_arguments(
23112346 inner_solution = (self ._filter_args (_inner_solution [0 ]), _inner_solution [1 ])
23122347 joint_solution = _joint_solution
23132348 reverse_joint_solution = _reverse_joint_solution
2314-
23152349 target_solution = _erased_reverse_joint_solution
23162350
2351+ if True : # compute the outer and target return types.
2352+ # Only substitute non-Uninhabited and non-erased types.
2353+ new_args : list [Type | None ] = []
2354+ for arg in outer_solution [0 ]:
2355+ if has_uninhabited_component (arg ) or has_erased_component (arg ):
2356+ new_args .append (None )
2357+ else :
2358+ new_args .append (arg )
2359+ # Don't show errors after we have only used the outer context for inference.
2360+ # We will use argument context to infer more variables.
2361+ outer_callee = self .apply_generic_arguments (
2362+ callee_type , new_args , context , skip_unsatisfied = True
2363+ )
2364+ outer_ret_type = get_proper_type (outer_callee .ret_type )
2365+
2366+ # Only substitute non-Uninhabited and non-erased types.
2367+ new_args : list [Type | None ] = []
2368+ for arg in target_solution [0 ]:
2369+ if has_uninhabited_component (arg ) or has_erased_component (arg ):
2370+ new_args .append (None )
2371+ else :
2372+ new_args .append (arg )
2373+ # Don't show errors after we have only used the outer context for inference.
2374+ # We will use argument context to infer more variables.
2375+ target_callee = self .apply_generic_arguments (
2376+ callee_type , new_args , context , skip_unsatisfied = True
2377+ )
2378+ target_ret_type = get_proper_type (target_callee .ret_type )
2379+
23172380 use_joint = True
23182381 use_outer = True
23192382 use_inner = True
@@ -2344,18 +2407,46 @@ def infer_function_type_arguments(
23442407 # Otherwise, we use joint.
23452408 combined_solution .append (joint_tp )
23462409
2410+ # new_vars, new_constraints = self.intersect_solutions(outer_solution[0], target_solution[0])
2411+ # intersected_solution = solve_constraints(
2412+ # new_vars,
2413+ # new_constraints,
2414+ # strict=self.chk.in_checked_function(),
2415+ # allow_polymorphic=False,
2416+ # )
2417+
23472418 # if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2348- if all (
2349- (joint_tp is None and outer_tp is None )
2350- or (
2351- (joint_tp is not None and outer_tp is not None )
2352- and is_subtype (outer_tp , joint_tp )
2353- )
2354- for outer_tp , joint_tp in zip (outer_solution [0 ], target_solution [0 ])
2419+ # if all(
2420+ # (joint_tp is None and outer_tp is None)
2421+ # or (
2422+ # (joint_tp is not None and outer_tp is not None)
2423+ # and (
2424+ # is_subtype(outer_tp, joint_tp)
2425+ # or (
2426+ # isinstance(outer_tp, UnionType)
2427+ # and any(is_subtype(val, joint_tp) for val in outer_tp.items)
2428+ # )
2429+ # )
2430+ # )
2431+ # for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0])
2432+ # ):
2433+ # use_joint = False
2434+ # use_outer = True
2435+
2436+ # if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2437+ if is_subtype (outer_ret_type , target_ret_type ) or (
2438+ isinstance (outer_ret_type , UnionType )
2439+ and any (is_subtype (val , target_ret_type ) for val in outer_ret_type .items )
23552440 ):
23562441 use_joint = False
23572442 use_outer = True
23582443
2444+ # what if the outer context is a union type?
2445+ # we may have a case like:
2446+ # outer : int | Literal["foo"]
2447+ # inner: Literal["foo"]? (which gets translated into str later)
2448+ # here, we would want `Literal["foo"]` to be used as the solution,
2449+
23592450 _num = arg_pass_nums
23602451 _c0 = constraints
23612452 _c1 = extra_constraints
@@ -2375,6 +2466,11 @@ def infer_function_type_arguments(
23752466 _s3 = joint_solution [0 ]
23762467 _s4 = reverse_joint_solution [0 ]
23772468 _t0 = target_solution [0 ]
2469+
2470+ _r1 = outer_ret_type
2471+ _r2 = target_ret_type
2472+ _y1 = outer_callee
2473+ _y2 = target_callee
23782474 _u0 = use_inner , use_outer , use_joint
23792475
23802476 if use_joint :
0 commit comments