Skip to content

Commit d7145d9

Browse files
checks pass
1 parent 40f8ca8 commit d7145d9

File tree

3 files changed

+107
-11
lines changed

3 files changed

+107
-11
lines changed

mypy/checkexpr.py

Lines changed: 104 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2134,6 +2134,14 @@ def _infer_constraints_from_context(
21342134
# variables in an expression are inferred at the same time.
21352135
# (And this is hard, also we need to be careful with lambdas that require
21362136
# two passes.)
2137+
# ret_as_union = make_simplified_union([ret_type])
2138+
# erased_ctx_as_union = make_simplified_union([ctx])
2139+
# if isinstance(ret_as_union, UnionType) and isinstance(erased_ctx_as_union, UnionType):
2140+
# new_ret = [val for val in ret_as_union.items if val not in erased_ctx_as_union.items]
2141+
# new_ctx = [val for val in erased_ctx_as_union.items if val not in ret_as_union.items]
2142+
# ret_type = make_simplified_union(new_ret)
2143+
# erased_ctx = make_simplified_union(new_ctx)
2144+
21372145
proper_ret = get_proper_type(ret_type)
21382146
if (
21392147
isinstance(proper_ret, TypeVarType)
@@ -2184,6 +2192,33 @@ def _filter_args(self, args: list[Type | None]) -> list[Type | None]:
21842192
# new_args.append(arg)
21852193
return new_args
21862194

2195+
def intersect_solutions(self, sol1: list[Type | None], sol2: list[Type | None]):
2196+
# first, ensure that the None-patterns agree
2197+
assert len(sol1) == len(sol2)
2198+
2199+
virtual_vars = []
2200+
constraints = []
2201+
2202+
for i, (tp1, tp2) in enumerate(zip(sol1, sol2)):
2203+
new_id = TypeVarId.new(-1)
2204+
name = f"V{i}"
2205+
new_tvar = TypeVarType(
2206+
name,
2207+
name,
2208+
new_id,
2209+
values=[],
2210+
upper_bound=self.object_type(),
2211+
default=AnyType(TypeOfAny.from_omitted_generics),
2212+
)
2213+
virtual_vars.append(new_tvar)
2214+
if tp1 is not None:
2215+
c1 = Constraint(new_tvar, SUBTYPE_OF, tp1)
2216+
constraints.append(c1)
2217+
if tp2 is not None:
2218+
c2 = Constraint(new_tvar, SUBTYPE_OF, tp2)
2219+
constraints.append(c2)
2220+
return virtual_vars, constraints
2221+
21872222
def infer_function_type_arguments(
21882223
self,
21892224
callee_type: CallableType,
@@ -2311,9 +2346,37 @@ def infer_function_type_arguments(
23112346
inner_solution = (self._filter_args(_inner_solution[0]), _inner_solution[1])
23122347
joint_solution = _joint_solution
23132348
reverse_joint_solution = _reverse_joint_solution
2314-
23152349
target_solution = _erased_reverse_joint_solution
23162350

2351+
if True: # compute the outer and target return types.
2352+
# Only substitute non-Uninhabited and non-erased types.
2353+
new_args: list[Type | None] = []
2354+
for arg in outer_solution[0]:
2355+
if has_uninhabited_component(arg) or has_erased_component(arg):
2356+
new_args.append(None)
2357+
else:
2358+
new_args.append(arg)
2359+
# Don't show errors after we have only used the outer context for inference.
2360+
# We will use argument context to infer more variables.
2361+
outer_callee = self.apply_generic_arguments(
2362+
callee_type, new_args, context, skip_unsatisfied=True
2363+
)
2364+
outer_ret_type = get_proper_type(outer_callee.ret_type)
2365+
2366+
# Only substitute non-Uninhabited and non-erased types.
2367+
new_args: list[Type | None] = []
2368+
for arg in target_solution[0]:
2369+
if has_uninhabited_component(arg) or has_erased_component(arg):
2370+
new_args.append(None)
2371+
else:
2372+
new_args.append(arg)
2373+
# Don't show errors after we have only used the outer context for inference.
2374+
# We will use argument context to infer more variables.
2375+
target_callee = self.apply_generic_arguments(
2376+
callee_type, new_args, context, skip_unsatisfied=True
2377+
)
2378+
target_ret_type = get_proper_type(target_callee.ret_type)
2379+
23172380
use_joint = True
23182381
use_outer = True
23192382
use_inner = True
@@ -2344,18 +2407,46 @@ def infer_function_type_arguments(
23442407
# Otherwise, we use joint.
23452408
combined_solution.append(joint_tp)
23462409

2410+
# new_vars, new_constraints = self.intersect_solutions(outer_solution[0], target_solution[0])
2411+
# intersected_solution = solve_constraints(
2412+
# new_vars,
2413+
# new_constraints,
2414+
# strict=self.chk.in_checked_function(),
2415+
# allow_polymorphic=False,
2416+
# )
2417+
23472418
# if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2348-
if all(
2349-
(joint_tp is None and outer_tp is None)
2350-
or (
2351-
(joint_tp is not None and outer_tp is not None)
2352-
and is_subtype(outer_tp, joint_tp)
2353-
)
2354-
for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0])
2419+
# if all(
2420+
# (joint_tp is None and outer_tp is None)
2421+
# or (
2422+
# (joint_tp is not None and outer_tp is not None)
2423+
# and (
2424+
# is_subtype(outer_tp, joint_tp)
2425+
# or (
2426+
# isinstance(outer_tp, UnionType)
2427+
# and any(is_subtype(val, joint_tp) for val in outer_tp.items)
2428+
# )
2429+
# )
2430+
# )
2431+
# for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0])
2432+
# ):
2433+
# use_joint = False
2434+
# use_outer = True
2435+
2436+
# if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2437+
if is_subtype(outer_ret_type, target_ret_type) or (
2438+
isinstance(outer_ret_type, UnionType)
2439+
and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items)
23552440
):
23562441
use_joint = False
23572442
use_outer = True
23582443

2444+
# what if the outer context is a union type?
2445+
# we may have a case like:
2446+
# outer : int | Literal["foo"]
2447+
# inner: Literal["foo"]? (which gets translated into str later)
2448+
# here, we would want `Literal["foo"]` to be used as the solution,
2449+
23592450
_num = arg_pass_nums
23602451
_c0 = constraints
23612452
_c1 = extra_constraints
@@ -2375,6 +2466,11 @@ def infer_function_type_arguments(
23752466
_s3 = joint_solution[0]
23762467
_s4 = reverse_joint_solution[0]
23772468
_t0 = target_solution[0]
2469+
2470+
_r1 = outer_ret_type
2471+
_r2 = target_ret_type
2472+
_y1 = outer_callee
2473+
_y2 = target_callee
23782474
_u0 = use_inner, use_outer, use_joint
23792475

23802476
if use_joint:

test-data/unit/check-functions.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ...
33843384
def g(x: T, y: S) -> Union[T, S]: ...
33853385

33863386
x = [f, g]
3387-
reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`16, y: S`17) -> Union[T`16, S`17]]"
3387+
reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`14, y: S`15) -> Union[T`14, S`15]]"
33883388
[builtins fixtures/list.pyi]
33893389

33903390
[case testTypeVariableClashErrorMessage]

test-data/unit/check-generics.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]:
29292929
def id(__x: U) -> U:
29302930
...
29312931
fs = [id, id, id]
2932-
reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`35) -> builtins.list[S`35]"
2933-
reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`37) -> builtins.list[S`37]"
2932+
reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`31) -> builtins.list[S`31]"
2933+
reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`33) -> builtins.list[S`33]"
29342934
[builtins fixtures/list.pyi]
29352935

29362936
[case testInferenceAgainstGenericCurry]

0 commit comments

Comments
 (0)