121121from mypy .subtypes import (
122122 find_member ,
123123 is_equivalent ,
124- is_proper_subtype ,
125124 is_same_type ,
126125 is_subtype ,
127126 non_method_protocol_members ,
@@ -2094,7 +2093,7 @@ def infer_function_type_arguments_using_context(
20942093 )
20952094
20962095 def _infer_constraints_from_context (
2097- self , callee : CallableType , error_context : Context , erase : bool = True
2096+ self , callee : CallableType , error_context : Context
20982097 ) -> list [Constraint ]:
20992098 """Unify callable return type to type context to infer type vars.
21002099
@@ -2105,47 +2104,35 @@ def _infer_constraints_from_context(
21052104 ctx = self .type_context [- 1 ]
21062105 if not ctx :
21072106 return []
2108- # The return type may have references to type metavariables that
2109- # we are inferring right now. We must consider them as indeterminate
2110- # and they are not potential results; thus we replace them with the
2111- # special ErasedType type. On the other hand, class type variables are
2112- # valid results.
2113- if erase :
2114- erased_ctx = replace_meta_vars (ctx , ErasedType ())
2115- else :
2116- erased_ctx = ctx
2107+ # if is_overlapping_none(ret_type) and is_overlapping_none(ctx):
2108+ # # If both the context and the return type are optional, unwrap the optional,
2109+ # # since in 99% cases this is what a user expects. In other words, we replace
2110+ # # Optional[T] <: Optional[int]
2111+ # # with
2112+ # # T <: int
2113+ # # while the former would infer T <: Optional[int].
2114+ # ret_type = remove_optional(ret_type)
2115+ # erased_ctx = remove_optional(erased_ctx)
2116+ # #
2117+ # # TODO: Instead of this hack and the one below, we need to use outer and
2118+ # # inner contexts at the same time. This is however not easy because of two
2119+ # # reasons:
2120+ # # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables
2121+ # # on both sides. (This is not too hard.)
2122+ # # * We need to update all the inference "infrastructure", so that all
2123+ # # variables in an expression are inferred at the same time.
2124+ # # (And this is hard, also we need to be careful with lambdas that require
2125+ # # two passes.)
21172126 ret_type = callee .ret_type
2118- if is_overlapping_none (ret_type ) and is_overlapping_none (ctx ):
2119- # If both the context and the return type are optional, unwrap the optional,
2120- # since in 99% cases this is what a user expects. In other words, we replace
2121- # Optional[T] <: Optional[int]
2122- # with
2123- # T <: int
2124- # while the former would infer T <: Optional[int].
2125- ret_type = remove_optional (ret_type )
2126- erased_ctx = remove_optional (erased_ctx )
2127- #
2128- # TODO: Instead of this hack and the one below, we need to use outer and
2129- # inner contexts at the same time. This is however not easy because of two
2130- # reasons:
2131- # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables
2132- # on both sides. (This is not too hard.)
2133- # * We need to update all the inference "infrastructure", so that all
2134- # variables in an expression are inferred at the same time.
2135- # (And this is hard, also we need to be careful with lambdas that require
2136- # two passes.)
2137- # ret_as_union = make_simplified_union([ret_type])
2138- # erased_ctx_as_union = make_simplified_union([ctx])
2139- # if isinstance(ret_as_union, UnionType) and isinstance(erased_ctx_as_union, UnionType):
2140- # new_ret = [val for val in ret_as_union.items if val not in erased_ctx_as_union.items]
2141- # new_ctx = [val for val in erased_ctx_as_union.items if val not in ret_as_union.items]
2142- # ret_type = make_simplified_union(new_ret)
2143- # erased_ctx = make_simplified_union(new_ctx)
2127+ if isinstance (ret_type , UnionType ) and isinstance (ctx , UnionType ):
2128+ new_ret = [val for val in ret_type .items if val not in ctx .items ]
2129+ new_ctx = [val for val in ctx .items if val not in ret_type .items ]
2130+ ret_type = make_simplified_union (new_ret )
2131+ ctx = make_simplified_union (new_ctx )
21442132
21452133 proper_ret = get_proper_type (ret_type )
2146- if (
2147- isinstance (proper_ret , TypeVarType )
2148- or isinstance (proper_ret , UnionType )
2134+ if isinstance (proper_ret , TypeVarType ) or (
2135+ isinstance (proper_ret , UnionType )
21492136 and all (isinstance (get_proper_type (u ), TypeVarType ) for u in proper_ret .items )
21502137 ):
21512138 # Another special case: the return type is a type variable. If it's unrestricted,
@@ -2174,6 +2161,12 @@ def _infer_constraints_from_context(
21742161 if not is_generic_instance (ctx ) and not is_literal_type_like (ctx ):
21752162 return []
21762163
2164+ # The return type may have references to type metavariables that
2165+ # we are inferring right now. We must consider them as indeterminate
2166+ # and they are not potential results; thus we replace them with the
2167+ # special ErasedType type. On the other hand, class type variables are
2168+ # valid results.
2169+ erased_ctx = replace_meta_vars (ctx , ErasedType ())
21772170 constraints = infer_constraints (ret_type , erased_ctx , SUBTYPE_OF )
21782171 return constraints
21792172
@@ -2278,12 +2271,7 @@ def infer_function_type_arguments(
22782271 context = self .argument_infer_context (),
22792272 )
22802273
2281- extra_constraints = self ._infer_constraints_from_context (
2282- callee_type , context , erase = False
2283- )
2284- erased_constraints = self ._infer_constraints_from_context (
2285- callee_type , context , erase = True
2286- )
2274+ extra_constraints = self ._infer_constraints_from_context (callee_type , context )
22872275
22882276 _outer_solution = solve_constraints (
22892277 callee_type .variables ,
@@ -2314,132 +2302,106 @@ def infer_function_type_arguments(
23142302 allow_polymorphic = False ,
23152303 )
23162304
2317- _erased_outer_solution = solve_constraints (
2318- callee_type .variables ,
2319- erased_constraints ,
2320- strict = self .chk .in_checked_function (),
2321- allow_polymorphic = False ,
2322- )
2323-
2324- _erased_joint_solution = solve_constraints (
2325- callee_type .variables ,
2326- constraints + erased_constraints ,
2327- strict = self .chk .in_checked_function (),
2328- allow_polymorphic = False ,
2329- )
2330-
2331- _erased_reverse_joint_solution = solve_constraints (
2332- callee_type .variables ,
2333- erased_constraints + constraints ,
2334- strict = self .chk .in_checked_function (),
2335- allow_polymorphic = False ,
2336- )
2337-
23382305 # Now, we select the solution to use.
23392306 # Note: Since joint uses both outer and inner constraints,
23402307 # and solution discovered by joint is also a solution for outer and inner.
23412308 # therefore, we can pick either inner or outer as a substitute for joint,
23422309 # and then try to solve again using only the inner constraints.
23432310 # joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1])
23442311 # reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1])
2345- outer_solution = ( self . _filter_args ( _outer_solution [ 0 ]), _outer_solution [ 1 ])
2346- inner_solution = ( self . _filter_args ( _inner_solution [ 0 ]), _inner_solution [ 1 ])
2312+ outer_solution = _outer_solution
2313+ inner_solution = _inner_solution
23472314 joint_solution = _joint_solution
23482315 reverse_joint_solution = _reverse_joint_solution
2349- target_solution = _erased_reverse_joint_solution
2316+ target_solution = _reverse_joint_solution
23502317
23512318 if True : # compute the outer and target return types.
2352- # Only substitute non-Uninhabited and non-erased types.
2353- new_args : list [Type | None ] = []
2354- for arg in outer_solution [0 ]:
2355- if has_uninhabited_component (arg ) or has_erased_component (arg ):
2356- new_args .append (None )
2357- else :
2358- new_args .append (arg )
2359- # Don't show errors after we have only used the outer context for inference.
2360- # We will use argument context to infer more variables.
2361- outer_callee = self .apply_generic_arguments (
2362- callee_type , new_args , context , skip_unsatisfied = True
2363- )
2364- outer_ret_type = get_proper_type (outer_callee .ret_type )
2319+ if True :
2320+ outer_callee = self .apply_generic_arguments (
2321+ callee_type , outer_solution [0 ], context , skip_unsatisfied = True
2322+ )
2323+ outer_ret_type = get_proper_type (outer_callee .ret_type )
23652324
2366- # Only substitute non-Uninhabited and non-erased types.
2367- new_args : list [Type | None ] = []
2368- for arg in target_solution [0 ]:
2369- if has_uninhabited_component (arg ) or has_erased_component (arg ):
2370- new_args .append (None )
2371- else :
2372- new_args .append (arg )
2373- # Don't show errors after we have only used the outer context for inference.
2374- # We will use argument context to infer more variables.
2375- target_callee = self .apply_generic_arguments (
2376- callee_type , new_args , context , skip_unsatisfied = True
2377- )
2378- target_ret_type = get_proper_type (target_callee .ret_type )
2325+ target_callee = self .apply_generic_arguments (
2326+ callee_type , target_solution [0 ], context , skip_unsatisfied = True
2327+ )
2328+ target_ret_type = get_proper_type (target_callee .ret_type )
2329+ else :
2330+ # Only substitute non-Uninhabited and non-erased types.
2331+ new_args : list [Type | None ] = []
2332+ for arg in outer_solution [0 ]:
2333+ if has_uninhabited_component (arg ) or has_erased_component (arg ):
2334+ new_args .append (None )
2335+ else :
2336+ new_args .append (arg )
2337+ # Don't show errors after we have only used the outer context for inference.
2338+ # We will use argument context to infer more variables.
2339+ outer_callee = self .apply_generic_arguments (
2340+ callee_type , new_args , context , skip_unsatisfied = True
2341+ )
2342+ outer_ret_type = get_proper_type (outer_callee .ret_type )
2343+
2344+ # Only substitute non-Uninhabited and non-erased types.
2345+ new_args : list [Type | None ] = []
2346+ for arg in target_solution [0 ]:
2347+ if has_uninhabited_component (arg ) or has_erased_component (arg ):
2348+ new_args .append (None )
2349+ else :
2350+ new_args .append (arg )
2351+ # Don't show errors after we have only used the outer context for inference.
2352+ # We will use argument context to infer more variables.
2353+ target_callee = self .apply_generic_arguments (
2354+ callee_type , new_args , context , skip_unsatisfied = True
2355+ )
2356+ target_ret_type = get_proper_type (target_callee .ret_type )
23792357
23802358 use_joint = True
23812359 use_outer = True
23822360 use_inner = True
23832361 # check if we can use the joint solution, otherwise fallback to outer_solution
2384- for outer_tp , inner_tp , joint_tp in zip (
2385- outer_solution [0 ], inner_solution [0 ], target_solution [0 ]
2386- ):
2387- if joint_tp is None and outer_tp is not None :
2388- use_joint = False
2389- if has_erased_component (joint_tp ) and not has_erased_component (inner_tp ):
2390- # If the joint solution is erased, but outer is not, we use outer.
2391- use_joint = False
2392- if has_erased_component (outer_tp ) and not has_erased_component (inner_tp ):
2393- use_outer = False
2394- if has_erased_component (inner_tp ):
2395- use_inner = False
2396-
2397- combined_solution = []
2398- for outer_tp , joint_tp in zip (outer_solution [0 ], target_solution [0 ]):
2399- if (
2400- outer_tp is not None
2401- and joint_tp is not None
2402- and is_proper_subtype (outer_tp , joint_tp )
2403- ):
2404- # If outer is a subtype of joint, we can use joint.
2405- combined_solution .append (outer_tp )
2406- else :
2407- # Otherwise, we use joint.
2408- combined_solution .append (joint_tp )
2409-
2410- # new_vars, new_constraints = self.intersect_solutions(outer_solution[0], target_solution[0])
2411- # intersected_solution = solve_constraints(
2412- # new_vars,
2413- # new_constraints,
2414- # strict=self.chk.in_checked_function(),
2415- # allow_polymorphic=False,
2416- # )
2417-
2418- # if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2419- # if all(
2420- # (joint_tp is None and outer_tp is None)
2421- # or (
2422- # (joint_tp is not None and outer_tp is not None)
2423- # and (
2424- # is_subtype(outer_tp, joint_tp)
2425- # or (
2426- # isinstance(outer_tp, UnionType)
2427- # and any(is_subtype(val, joint_tp) for val in outer_tp.items)
2428- # )
2429- # )
2430- # )
2431- # for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0])
2362+ # for outer_tp, inner_tp, joint_tp in zip(
2363+ # outer_solution[0], inner_solution[0], target_solution[0]
24322364 # ):
2433- # use_joint = False
2434- # use_outer = True
2365+ # if joint_tp is None and outer_tp is not None:
2366+ # use_joint = False
2367+ # if has_erased_component(joint_tp) and not has_erased_component(inner_tp):
2368+ # # If the joint solution is erased, but outer is not, we use outer.
2369+ # use_joint = False
2370+ # if has_erased_component(outer_tp) and not has_erased_component(inner_tp):
2371+ # use_outer = False
2372+ # if has_erased_component(inner_tp):
2373+ # use_inner = False
2374+
2375+ if any (tp is None for tp in inner_solution [0 ]):
2376+ use_inner = False
2377+ if any (tp is None for tp in outer_solution [0 ]):
2378+ use_outer = False
2379+ if any (tp is None for tp in joint_solution [0 ]):
2380+ use_joint = False
24352381
2436- # if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2437- if is_subtype (outer_ret_type , target_ret_type ) or (
2438- isinstance (outer_ret_type , UnionType )
2439- and any (is_subtype (val , target_ret_type ) for val in outer_ret_type .items )
2382+ if (
2383+ # if the joint failed to solve use the outer solution instead.
2384+ # any(joint_tp is None and outer_tp is not None for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0]))
2385+ # If the outer solution is more concrete than the joint solution, use the outer solution.
2386+ # This also applies if the outer solution is a union type where at least one member
2387+ # is a subtype of the target return type.
2388+ is_subtype (outer_ret_type , target_ret_type )
2389+ or (
2390+ isinstance (outer_ret_type , UnionType )
2391+ and any (is_subtype (val , target_ret_type ) for val in outer_ret_type .items )
2392+ )
24402393 ):
24412394 use_joint = False
2442- use_outer = True
2395+ # use_outer = True
2396+
2397+ if use_joint :
2398+ target_solution = reverse_joint_solution
2399+ elif use_outer :
2400+ target_solution = outer_solution
2401+ elif use_inner :
2402+ target_solution = inner_solution
2403+ else :
2404+ raise AssertionError
24432405
24442406 # what if the outer context is a union type?
24452407 # we may have a case like:
@@ -2450,17 +2412,12 @@ def infer_function_type_arguments(
24502412 _num = arg_pass_nums
24512413 _c0 = constraints
24522414 _c1 = extra_constraints
2453- _c2 = erased_constraints
24542415
24552416 _x0 = _outer_solution [0 ]
24562417 _x2 = _inner_solution [0 ]
24572418 _x3 = _joint_solution [0 ]
24582419 _x4 = _reverse_joint_solution [0 ]
24592420
2460- _e0 = _erased_outer_solution [0 ]
2461- _e2 = _erased_joint_solution [0 ]
2462- _e3 = _erased_reverse_joint_solution [0 ]
2463-
24642421 _s1 = outer_solution [0 ]
24652422 _s2 = inner_solution [0 ]
24662423 _s3 = joint_solution [0 ]
@@ -2471,7 +2428,7 @@ def infer_function_type_arguments(
24712428 _r2 = target_ret_type
24722429 _y1 = outer_callee
24732430 _y2 = target_callee
2474- _u0 = use_inner , use_outer , use_joint
2431+ _u0 = use_outer , use_joint
24752432
24762433 if use_joint :
24772434 new_inferred_args = target_solution [0 ]
@@ -2483,7 +2440,7 @@ def infer_function_type_arguments(
24832440 # ]
24842441 elif use_outer :
24852442 # If we cannot use the joint solution, fallback to outer_solution
2486- new_inferred_args = outer_solution [0 ]
2443+ new_inferred_args = target_solution [0 ]
24872444
24882445 # Only substitute non-Uninhabited and non-erased types.
24892446 new_args : list [Type | None ] = []
0 commit comments