Skip to content

Commit db7ed35

Browse files
3 test remain
1 parent e4238c7 commit db7ed35

File tree

3 files changed

+125
-34
lines changed

3 files changed

+125
-34
lines changed

mypy/checkexpr.py

Lines changed: 122 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
from mypy.subtypes import (
122122
find_member,
123123
is_equivalent,
124+
is_proper_subtype,
124125
is_same_type,
125126
is_subtype,
126127
non_method_protocol_members,
@@ -2093,7 +2094,7 @@ def infer_function_type_arguments_using_context(
20932094
)
20942095

20952096
def _infer_constraints_from_context(
2096-
self, callee: CallableType, error_context: Context
2097+
self, callee: CallableType, error_context: Context, erase: bool = True
20972098
) -> list[Constraint]:
20982099
"""Unify callable return type to type context to infer type vars.
20992100
@@ -2109,7 +2110,10 @@ def _infer_constraints_from_context(
21092110
# and they are not potential results; thus we replace them with the
21102111
# special ErasedType type. On the other hand, class type variables are
21112112
# valid results.
2112-
erased_ctx = replace_meta_vars(ctx, ErasedType())
2113+
if erase:
2114+
erased_ctx = replace_meta_vars(ctx, ErasedType())
2115+
else:
2116+
erased_ctx = ctx
21132117
ret_type = callee.ret_type
21142118
if is_overlapping_none(ret_type) and is_overlapping_none(ctx):
21152119
# If both the context and the return type are optional, unwrap the optional,
@@ -2168,10 +2172,16 @@ def _infer_constraints_from_context(
21682172
def _filter_args(self, args: list[Type | None]) -> list[Type | None]:
21692173
new_args: list[Type | None] = []
21702174
for arg in args:
2171-
if has_uninhabited_component(arg) or has_erased_component(arg):
2175+
if arg is None:
21722176
new_args.append(None)
2177+
continue
21732178
else:
2179+
arg = replace_meta_vars(arg, ErasedType())
21742180
new_args.append(arg)
2181+
# if has_erased_component(arg) or has_uninhabited_component(arg):
2182+
# new_args.append(None)
2183+
# else:
2184+
# new_args.append(arg)
21752185
return new_args
21762186

21772187
def infer_function_type_arguments(
@@ -2224,18 +2234,6 @@ def infer_function_type_arguments(
22242234
new_inferred_args = None
22252235

22262236
if True: # NEW CODE
2227-
extra_constraints = self._infer_constraints_from_context(callee_type, context)
2228-
2229-
# outer_solution
2230-
_outer_solution = solve_constraints(
2231-
callee_type.variables,
2232-
extra_constraints,
2233-
strict=self.chk.in_checked_function(),
2234-
allow_polymorphic=False,
2235-
)
2236-
outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1])
2237-
2238-
# inner solution
22392237
constraints = infer_constraints_for_callable(
22402238
callee_type,
22412239
pass1_args,
@@ -2244,38 +2242,84 @@ def infer_function_type_arguments(
22442242
formal_to_actual,
22452243
context=self.argument_infer_context(),
22462244
)
2247-
inner_solution = solve_constraints(
2245+
2246+
extra_constraints = self._infer_constraints_from_context(
2247+
callee_type, context, erase=False
2248+
)
2249+
erased_constraints = self._infer_constraints_from_context(
2250+
callee_type, context, erase=True
2251+
)
2252+
2253+
_outer_solution = solve_constraints(
2254+
callee_type.variables,
2255+
extra_constraints,
2256+
strict=self.chk.in_checked_function(),
2257+
allow_polymorphic=False,
2258+
)
2259+
2260+
_inner_solution = solve_constraints(
22482261
callee_type.variables,
22492262
constraints,
22502263
strict=self.chk.in_checked_function(),
22512264
allow_polymorphic=False,
22522265
)
2266+
# NOTE: The order of constraints is important here!
2267+
# solve(outer + inner) and solve(inner + outer) may yield different results.
2268+
_joint_solution = solve_constraints(
2269+
callee_type.variables,
2270+
constraints + extra_constraints,
2271+
strict=self.chk.in_checked_function(),
2272+
allow_polymorphic=False,
2273+
)
22532274

2254-
joint_solution = solve_constraints(
2275+
_reverse_joint_solution = solve_constraints(
22552276
callee_type.variables,
22562277
extra_constraints + constraints,
22572278
strict=self.chk.in_checked_function(),
22582279
allow_polymorphic=False,
22592280
)
22602281

2261-
# check if we can use the joint solution, otherwise fallback to outer_solution
2262-
for var1, var2 in zip(
2263-
outer_solution[0], joint_solution[0]
2264-
): # tuple[Type | None, Type | None]
2265-
if var2 is None and var1 is not None:
2266-
# using both constraints did not find a solution for this variable
2267-
# so we fallback to outer_solution, apply the solution, and then recompute the inner part
2268-
use_joint = False
2269-
break
2270-
else:
2271-
use_joint = True
2282+
_erased_outer_solution = solve_constraints(
2283+
callee_type.variables,
2284+
erased_constraints,
2285+
strict=self.chk.in_checked_function(),
2286+
allow_polymorphic=False,
2287+
)
2288+
2289+
_erased_joint_solution = solve_constraints(
2290+
callee_type.variables,
2291+
constraints + erased_constraints,
2292+
strict=self.chk.in_checked_function(),
2293+
allow_polymorphic=False,
2294+
)
2295+
2296+
_erased_reverse_joint_solution = solve_constraints(
2297+
callee_type.variables,
2298+
erased_constraints + constraints,
2299+
strict=self.chk.in_checked_function(),
2300+
allow_polymorphic=False,
2301+
)
2302+
2303+
# Now, we select the solution to use.
2304+
# Note: Since joint uses both outer and inner constraints,
2305+
# and solution discovered by joint is also a solution for outer and inner.
2306+
# therefore, we can pick either inner or outer as a substitute for joint,
2307+
# and then try to solve again using only the inner constraints.
2308+
# joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1])
2309+
# reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1])
2310+
outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1])
2311+
inner_solution = (self._filter_args(_inner_solution[0]), _inner_solution[1])
2312+
joint_solution = _joint_solution
2313+
reverse_joint_solution = _reverse_joint_solution
2314+
2315+
target_solution = _erased_reverse_joint_solution
22722316

22732317
use_joint = True
22742318
use_outer = True
22752319
use_inner = True
2276-
2320+
# check if we can use the joint solution, otherwise fallback to outer_solution
22772321
for outer_tp, inner_tp, joint_tp in zip(
2278-
outer_solution[0], inner_solution[0], joint_solution[0]
2322+
outer_solution[0], inner_solution[0], target_solution[0]
22792323
):
22802324
if joint_tp is None and outer_tp is not None:
22812325
use_joint = False
@@ -2287,8 +2331,55 @@ def infer_function_type_arguments(
22872331
if has_erased_component(inner_tp):
22882332
use_inner = False
22892333

2334+
combined_solution = []
2335+
for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]):
2336+
if (
2337+
outer_tp is not None
2338+
and joint_tp is not None
2339+
and is_proper_subtype(outer_tp, joint_tp)
2340+
):
2341+
# If outer is a subtype of joint, we can use joint.
2342+
combined_solution.append(outer_tp)
2343+
else:
2344+
# Otherwise, we use joint.
2345+
combined_solution.append(joint_tp)
2346+
2347+
_num = arg_pass_nums
2348+
_c0 = constraints
2349+
_c1 = extra_constraints
2350+
_c2 = erased_constraints
2351+
2352+
_x0 = _outer_solution[0]
2353+
_x2 = _inner_solution[0]
2354+
_x3 = _joint_solution[0]
2355+
_x4 = _reverse_joint_solution[0]
2356+
2357+
_e0 = _erased_outer_solution[0]
2358+
_e2 = _erased_joint_solution[0]
2359+
_e3 = _erased_reverse_joint_solution[0]
2360+
2361+
_s1 = outer_solution[0]
2362+
_s2 = inner_solution[0]
2363+
_s3 = joint_solution[0]
2364+
_s4 = reverse_joint_solution[0]
2365+
_s5 = combined_solution
2366+
2367+
_u0 = use_inner, use_outer, use_joint
2368+
2369+
# if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2370+
if all(
2371+
(joint_tp is None and outer_tp is None)
2372+
or (
2373+
(joint_tp is not None and outer_tp is not None)
2374+
and is_subtype(outer_tp, joint_tp)
2375+
)
2376+
for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0])
2377+
):
2378+
use_joint = False
2379+
use_outer = True
2380+
22902381
if use_joint:
2291-
new_inferred_args = joint_solution[0]
2382+
new_inferred_args = target_solution[0]
22922383
# inferred_args = [
22932384
# # Usually, joint_tp <: outer_tp (since superset of constraints),
22942385
# # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]`

test-data/unit/check-functions.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ...
33843384
def g(x: T, y: S) -> Union[T, S]: ...
33853385

33863386
x = [f, g]
3387-
reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`4, y: S`5) -> Union[T`4, S`5]]"
3387+
reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`16, y: S`17) -> Union[T`16, S`17]]"
33883388
[builtins fixtures/list.pyi]
33893389

33903390
[case testTypeVariableClashErrorMessage]

test-data/unit/check-generics.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]:
29292929
def id(__x: U) -> U:
29302930
...
29312931
fs = [id, id, id]
2932-
reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`11) -> builtins.list[S`11]"
2933-
reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`13) -> builtins.list[S`13]"
2932+
reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`23) -> builtins.list[S`23]"
2933+
reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`25) -> builtins.list[S`25]"
29342934
[builtins fixtures/list.pyi]
29352935

29362936
[case testInferenceAgainstGenericCurry]

0 commit comments

Comments
 (0)