@@ -1781,18 +1781,6 @@ def check_callable_call(
17811781 isinstance (v , (ParamSpecType , TypeVarTupleType )) for v in callee .variables
17821782 )
17831783 callee = freshen_function_type_vars (callee )
1784- # callee = self.infer_function_type_arguments_using_context(callee, context)
1785- # if need_refresh:
1786- # # Argument kinds etc. may have changed due to
1787- # # ParamSpec or TypeVarTuple variables being replaced with an arbitrary
1788- # # number of arguments; recalculate actual-to-formal map
1789- # formal_to_actual = map_actuals_to_formals(
1790- # arg_kinds,
1791- # arg_names,
1792- # callee.arg_kinds,
1793- # callee.arg_names,
1794- # lambda i: self.accept(args[i]),
1795- # )
17961784 callee = self .infer_function_type_arguments (
17971785 callee , args , arg_kinds , arg_names , formal_to_actual , need_refresh , context
17981786 )
@@ -2125,6 +2113,11 @@ def _infer_constraints_from_context(
21252113 # # two passes.)
21262114 ret_type = callee .ret_type
21272115 if isinstance (ret_type , UnionType ) and isinstance (ctx , UnionType ):
2116+ # If both the context and the return type are unions, we simplify shared items
2117+ # e.g. T | None <: int | None => T <: int
2118+ # since the former would infer T <: int | None.
2119+ # whereas the latter would infer the more precise T <: int.
2120+
21282121 new_ret = [val for val in ret_type .items if val not in ctx .items ]
21292122 new_ctx = [val for val in ctx .items if val not in ret_type .items ]
21302123 ret_type = make_simplified_union (new_ret )
@@ -2170,48 +2163,6 @@ def _infer_constraints_from_context(
21702163 constraints = infer_constraints (ret_type , erased_ctx , SUBTYPE_OF )
21712164 return constraints
21722165
2173- def _filter_args (self , args : list [Type | None ]) -> list [Type | None ]:
2174- new_args : list [Type | None ] = []
2175- for arg in args :
2176- if arg is None :
2177- new_args .append (None )
2178- continue
2179- else :
2180- arg = replace_meta_vars (arg , ErasedType ())
2181- new_args .append (arg )
2182- # if has_erased_component(arg) or has_uninhabited_component(arg):
2183- # new_args.append(None)
2184- # else:
2185- # new_args.append(arg)
2186- return new_args
2187-
2188- def intersect_solutions (self , sol1 : list [Type | None ], sol2 : list [Type | None ]):
2189- # first, ensure that the None-patterns agree
2190- assert len (sol1 ) == len (sol2 )
2191-
2192- virtual_vars = []
2193- constraints = []
2194-
2195- for i , (tp1 , tp2 ) in enumerate (zip (sol1 , sol2 )):
2196- new_id = TypeVarId .new (- 1 )
2197- name = f"V{ i } "
2198- new_tvar = TypeVarType (
2199- name ,
2200- name ,
2201- new_id ,
2202- values = [],
2203- upper_bound = self .object_type (),
2204- default = AnyType (TypeOfAny .from_omitted_generics ),
2205- )
2206- virtual_vars .append (new_tvar )
2207- if tp1 is not None :
2208- c1 = Constraint (new_tvar , SUBTYPE_OF , tp1 )
2209- constraints .append (c1 )
2210- if tp2 is not None :
2211- c2 = Constraint (new_tvar , SUBTYPE_OF , tp2 )
2212- constraints .append (c2 )
2213- return virtual_vars , constraints
2214-
22152166 def infer_function_type_arguments (
22162167 self ,
22172168 callee_type : CallableType ,
@@ -2258,11 +2209,9 @@ def infer_function_type_arguments(
22582209 context = self .argument_infer_context (),
22592210 strict = self .chk .in_checked_function (),
22602211 )
2261- old_inferred_args = inferred_args
2262- new_inferred_args = None
22632212
22642213 if True : # NEW CODE
2265- constraints = infer_constraints_for_callable (
2214+ inner_constraints = infer_constraints_for_callable (
22662215 callee_type ,
22672216 pass1_args ,
22682217 arg_kinds ,
@@ -2271,49 +2220,48 @@ def infer_function_type_arguments(
22712220 context = self .argument_infer_context (),
22722221 )
22732222
2274- extra_constraints = self ._infer_constraints_from_context (callee_type , context )
2223+ outer_constraints = self ._infer_constraints_from_context (callee_type , context )
22752224
2276- _outer_solution = solve_constraints (
2225+ outer_solution = solve_constraints (
22772226 callee_type .variables ,
2278- extra_constraints ,
2227+ outer_constraints ,
22792228 strict = self .chk .in_checked_function (),
22802229 allow_polymorphic = False ,
22812230 )
22822231
2283- _inner_solution = solve_constraints (
2232+ inner_solution = solve_constraints (
22842233 callee_type .variables ,
2285- constraints ,
2234+ inner_constraints ,
22862235 strict = self .chk .in_checked_function (),
22872236 allow_polymorphic = False ,
22882237 )
22892238 # NOTE: The order of constraints is important here!
22902239 # solve(outer + inner) and solve(inner + outer) may yield different results.
2291- _joint_solution = solve_constraints (
2240+
2241+ joint_solution = solve_constraints (
22922242 callee_type .variables ,
2293- constraints + extra_constraints ,
2243+ outer_constraints + inner_constraints ,
22942244 strict = self .chk .in_checked_function (),
22952245 allow_polymorphic = False ,
22962246 )
22972247
2298- _reverse_joint_solution = solve_constraints (
2248+ reverse_joint_solution = solve_constraints (
22992249 callee_type .variables ,
2300- extra_constraints + constraints ,
2250+ inner_constraints + outer_constraints ,
23012251 strict = self .chk .in_checked_function (),
23022252 allow_polymorphic = False ,
23032253 )
23042254
2255+ target_solution = joint_solution
2256+
23052257 # Now, we select the solution to use.
23062258 # Note: Since joint uses both outer and inner constraints,
23072259 # and solution discovered by joint is also a solution for outer and inner.
23082260 # therefore, we can pick either inner or outer as a substitute for joint,
23092261 # and then try to solve again using only the inner constraints.
2310- # joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1])
2311- # reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1])
2312- outer_solution = _outer_solution
2313- inner_solution = _inner_solution
2314- joint_solution = _joint_solution
2315- reverse_joint_solution = _reverse_joint_solution
2316- target_solution = _reverse_joint_solution
2262+ use_joint = True
2263+ use_outer = True
2264+ use_inner = True
23172265
23182266 if True : # compute the outer and target return types.
23192267 if True :
@@ -2355,44 +2303,17 @@ def infer_function_type_arguments(
23552303 )
23562304 target_ret_type = get_proper_type (target_callee .ret_type )
23572305
2358- use_joint = True
2359- use_outer = True
2360- use_inner = True
2361- # check if we can use the joint solution, otherwise fallback to outer_solution
2362- # for outer_tp, inner_tp, joint_tp in zip(
2363- # outer_solution[0], inner_solution[0], target_solution[0]
2364- # ):
2365- # if joint_tp is None and outer_tp is not None:
2366- # use_joint = False
2367- # if has_erased_component(joint_tp) and not has_erased_component(inner_tp):
2368- # # If the joint solution is erased, but outer is not, we use outer.
2369- # use_joint = False
2370- # if has_erased_component(outer_tp) and not has_erased_component(inner_tp):
2371- # use_outer = False
2372- # if has_erased_component(inner_tp):
2373- # use_inner = False
2374-
2375- if any (tp is None for tp in inner_solution [0 ]):
2376- use_inner = False
2377- if any (tp is None for tp in outer_solution [0 ]):
2378- use_outer = False
2379- if any (tp is None for tp in joint_solution [0 ]):
2380- use_joint = False
2381-
23822306 if (
2383- # if the joint failed to solve use the outer solution instead.
2384- # any(joint_tp is None and outer_tp is not None for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0]))
2307+ # joint constraints failed to produce a complete solution
2308+ None in joint_solution [0 ]
23852309 # If the outer solution is more concrete than the joint solution, use the outer solution.
2386- # This also applies if the outer solution is a union type where at least one member
2387- # is a subtype of the target return type.
2388- is_subtype (outer_ret_type , target_ret_type )
2389- or (
2310+ or is_subtype (outer_ret_type , target_ret_type )
2311+ or ( # HACK to fix testLiteralAndGenericWithUnion
23902312 isinstance (outer_ret_type , UnionType )
23912313 and any (is_subtype (val , target_ret_type ) for val in outer_ret_type .items )
23922314 )
23932315 ):
23942316 use_joint = False
2395- # use_outer = True
23962317
23972318 if use_joint :
23982319 target_solution = reverse_joint_solution
@@ -2403,52 +2324,38 @@ def infer_function_type_arguments(
24032324 else :
24042325 raise AssertionError
24052326
2406- # what if the outer context is a union type?
2407- # we may have a case like:
2408- # outer : int | Literal["foo"]
2409- # inner: Literal["foo"]? (which gets translated into str later)
2410- # here, we would want `Literal["foo"]` to be used as the solution,
2411-
2412- _num = arg_pass_nums
2413- _c0 = constraints
2414- _c1 = extra_constraints
2415-
2416- _x0 = _outer_solution [0 ]
2417- _x2 = _inner_solution [0 ]
2418- _x3 = _joint_solution [0 ]
2419- _x4 = _reverse_joint_solution [0 ]
2420-
2421- _s1 = outer_solution [0 ]
2422- _s2 = inner_solution [0 ]
2423- _s3 = joint_solution [0 ]
2424- _s4 = reverse_joint_solution [0 ]
2425- _t0 = target_solution [0 ]
2426-
2427- _r1 = outer_ret_type
2428- _r2 = target_ret_type
2429- _y1 = outer_callee
2430- _y2 = target_callee
2431- _u0 = use_outer , use_joint
2327+ if __debug__ :
2328+ _num = arg_pass_nums
2329+ _c0 = inner_constraints
2330+ _c1 = outer_constraints
2331+
2332+ _s1 = outer_solution [0 ]
2333+ _s2 = inner_solution [0 ]
2334+ _s3 = joint_solution [0 ]
2335+ _s4 = reverse_joint_solution [0 ]
2336+ _t0 = target_solution [0 ]
2337+
2338+ _r1 = outer_ret_type
2339+ _r2 = target_ret_type
2340+ _y1 = outer_callee
2341+ _y2 = target_callee
2342+ _u0 = use_outer , use_joint
24322343
24332344 if use_joint :
2434- new_inferred_args = target_solution [0 ]
2435- # inferred_args = [
2436- # # Usually, joint_tp <: outer_tp (since superset of constraints),
2437- # # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]`
2438- # (outer_tp if is_subtype(outer_tp, joint_tp) else joint_tp)
2439- # for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0])
2440- # ]
2345+ # inferred_args = target_solution[0]
2346+ pass
24412347 elif use_outer :
24422348 # If we cannot use the joint solution, fallback to outer_solution
2443- new_inferred_args = target_solution [0 ]
2349+ inferred_args = target_solution [0 ]
24442350
24452351 # Only substitute non-Uninhabited and non-erased types.
24462352 new_args : list [Type | None ] = []
2447- for arg in new_inferred_args :
2353+ for arg in inferred_args :
24482354 if has_uninhabited_component (arg ) or has_erased_component (arg ):
24492355 new_args .append (None )
24502356 else :
24512357 new_args .append (arg )
2358+
24522359 # Don't show errors after we have only used the outer context for inference.
24532360 # We will use argument context to infer more variables.
24542361 callee_type = self .apply_generic_arguments (
@@ -2465,7 +2372,7 @@ def infer_function_type_arguments(
24652372 callee_type .arg_names ,
24662373 lambda i : self .accept (args [i ]),
24672374 )
2468- new_inferred_args , _ = infer_function_type_arguments (
2375+ inferred_args , _ = infer_function_type_arguments (
24692376 callee_type ,
24702377 pass1_args ,
24712378 arg_kinds ,
@@ -2475,21 +2382,10 @@ def infer_function_type_arguments(
24752382 strict = self .chk .in_checked_function (),
24762383 )
24772384 elif use_inner :
2478- new_inferred_args = inner_solution [0 ]
2385+ # inferred_args = inner_solution[0]
2386+ pass
24792387 else :
24802388 raise RuntimeError ("No solution found for function type arguments" )
2481- else : # OLD CODE
2482- pass
2483-
2484- if True : # USE NEW CODE
2485- inferred_args = new_inferred_args
2486- else : # USE OLD CODE
2487- inferred_args = old_inferred_args
2488-
2489- # show me
2490- _1 = new_inferred_args
2491- _2 = old_inferred_args
2492- _3 = inferred_args
24932389
24942390 if 2 in arg_pass_nums :
24952391 # Second pass of type inference.
0 commit comments