Skip to content

Commit ce4956d

Browse files
some simplifications
1 parent d7145d9 commit ce4956d

File tree

3 files changed

+118
-161
lines changed

3 files changed

+118
-161
lines changed

mypy/checkexpr.py

Lines changed: 115 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@
121121
from mypy.subtypes import (
122122
find_member,
123123
is_equivalent,
124-
is_proper_subtype,
125124
is_same_type,
126125
is_subtype,
127126
non_method_protocol_members,
@@ -2094,7 +2093,7 @@ def infer_function_type_arguments_using_context(
20942093
)
20952094

20962095
def _infer_constraints_from_context(
2097-
self, callee: CallableType, error_context: Context, erase: bool = True
2096+
self, callee: CallableType, error_context: Context
20982097
) -> list[Constraint]:
20992098
"""Unify callable return type to type context to infer type vars.
21002099
@@ -2105,47 +2104,35 @@ def _infer_constraints_from_context(
21052104
ctx = self.type_context[-1]
21062105
if not ctx:
21072106
return []
2108-
# The return type may have references to type metavariables that
2109-
# we are inferring right now. We must consider them as indeterminate
2110-
# and they are not potential results; thus we replace them with the
2111-
# special ErasedType type. On the other hand, class type variables are
2112-
# valid results.
2113-
if erase:
2114-
erased_ctx = replace_meta_vars(ctx, ErasedType())
2115-
else:
2116-
erased_ctx = ctx
2107+
# if is_overlapping_none(ret_type) and is_overlapping_none(ctx):
2108+
# # If both the context and the return type are optional, unwrap the optional,
2109+
# # since in 99% cases this is what a user expects. In other words, we replace
2110+
# # Optional[T] <: Optional[int]
2111+
# # with
2112+
# # T <: int
2113+
# # while the former would infer T <: Optional[int].
2114+
# ret_type = remove_optional(ret_type)
2115+
# erased_ctx = remove_optional(erased_ctx)
2116+
# #
2117+
# # TODO: Instead of this hack and the one below, we need to use outer and
2118+
# # inner contexts at the same time. This is however not easy because of two
2119+
# # reasons:
2120+
# # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables
2121+
# # on both sides. (This is not too hard.)
2122+
# # * We need to update all the inference "infrastructure", so that all
2123+
# # variables in an expression are inferred at the same time.
2124+
# # (And this is hard, also we need to be careful with lambdas that require
2125+
# # two passes.)
21172126
ret_type = callee.ret_type
2118-
if is_overlapping_none(ret_type) and is_overlapping_none(ctx):
2119-
# If both the context and the return type are optional, unwrap the optional,
2120-
# since in 99% cases this is what a user expects. In other words, we replace
2121-
# Optional[T] <: Optional[int]
2122-
# with
2123-
# T <: int
2124-
# while the former would infer T <: Optional[int].
2125-
ret_type = remove_optional(ret_type)
2126-
erased_ctx = remove_optional(erased_ctx)
2127-
#
2128-
# TODO: Instead of this hack and the one below, we need to use outer and
2129-
# inner contexts at the same time. This is however not easy because of two
2130-
# reasons:
2131-
# * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables
2132-
# on both sides. (This is not too hard.)
2133-
# * We need to update all the inference "infrastructure", so that all
2134-
# variables in an expression are inferred at the same time.
2135-
# (And this is hard, also we need to be careful with lambdas that require
2136-
# two passes.)
2137-
# ret_as_union = make_simplified_union([ret_type])
2138-
# erased_ctx_as_union = make_simplified_union([ctx])
2139-
# if isinstance(ret_as_union, UnionType) and isinstance(erased_ctx_as_union, UnionType):
2140-
# new_ret = [val for val in ret_as_union.items if val not in erased_ctx_as_union.items]
2141-
# new_ctx = [val for val in erased_ctx_as_union.items if val not in ret_as_union.items]
2142-
# ret_type = make_simplified_union(new_ret)
2143-
# erased_ctx = make_simplified_union(new_ctx)
2127+
if isinstance(ret_type, UnionType) and isinstance(ctx, UnionType):
2128+
new_ret = [val for val in ret_type.items if val not in ctx.items]
2129+
new_ctx = [val for val in ctx.items if val not in ret_type.items]
2130+
ret_type = make_simplified_union(new_ret)
2131+
ctx = make_simplified_union(new_ctx)
21442132

21452133
proper_ret = get_proper_type(ret_type)
2146-
if (
2147-
isinstance(proper_ret, TypeVarType)
2148-
or isinstance(proper_ret, UnionType)
2134+
if isinstance(proper_ret, TypeVarType) or (
2135+
isinstance(proper_ret, UnionType)
21492136
and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items)
21502137
):
21512138
# Another special case: the return type is a type variable. If it's unrestricted,
@@ -2174,6 +2161,12 @@ def _infer_constraints_from_context(
21742161
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
21752162
return []
21762163

2164+
# The return type may have references to type metavariables that
2165+
# we are inferring right now. We must consider them as indeterminate
2166+
# and they are not potential results; thus we replace them with the
2167+
# special ErasedType type. On the other hand, class type variables are
2168+
# valid results.
2169+
erased_ctx = replace_meta_vars(ctx, ErasedType())
21772170
constraints = infer_constraints(ret_type, erased_ctx, SUBTYPE_OF)
21782171
return constraints
21792172

@@ -2278,12 +2271,7 @@ def infer_function_type_arguments(
22782271
context=self.argument_infer_context(),
22792272
)
22802273

2281-
extra_constraints = self._infer_constraints_from_context(
2282-
callee_type, context, erase=False
2283-
)
2284-
erased_constraints = self._infer_constraints_from_context(
2285-
callee_type, context, erase=True
2286-
)
2274+
extra_constraints = self._infer_constraints_from_context(callee_type, context)
22872275

22882276
_outer_solution = solve_constraints(
22892277
callee_type.variables,
@@ -2314,132 +2302,106 @@ def infer_function_type_arguments(
23142302
allow_polymorphic=False,
23152303
)
23162304

2317-
_erased_outer_solution = solve_constraints(
2318-
callee_type.variables,
2319-
erased_constraints,
2320-
strict=self.chk.in_checked_function(),
2321-
allow_polymorphic=False,
2322-
)
2323-
2324-
_erased_joint_solution = solve_constraints(
2325-
callee_type.variables,
2326-
constraints + erased_constraints,
2327-
strict=self.chk.in_checked_function(),
2328-
allow_polymorphic=False,
2329-
)
2330-
2331-
_erased_reverse_joint_solution = solve_constraints(
2332-
callee_type.variables,
2333-
erased_constraints + constraints,
2334-
strict=self.chk.in_checked_function(),
2335-
allow_polymorphic=False,
2336-
)
2337-
23382305
# Now, we select the solution to use.
23392306
# Note: Since joint uses both outer and inner constraints,
23402307
# and solution discovered by joint is also a solution for outer and inner.
23412308
# therefore, we can pick either inner or outer as a substitute for joint,
23422309
# and then try to solve again using only the inner constraints.
23432310
# joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1])
23442311
# reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1])
2345-
outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1])
2346-
inner_solution = (self._filter_args(_inner_solution[0]), _inner_solution[1])
2312+
outer_solution = _outer_solution
2313+
inner_solution = _inner_solution
23472314
joint_solution = _joint_solution
23482315
reverse_joint_solution = _reverse_joint_solution
2349-
target_solution = _erased_reverse_joint_solution
2316+
target_solution = _reverse_joint_solution
23502317

23512318
if True: # compute the outer and target return types.
2352-
# Only substitute non-Uninhabited and non-erased types.
2353-
new_args: list[Type | None] = []
2354-
for arg in outer_solution[0]:
2355-
if has_uninhabited_component(arg) or has_erased_component(arg):
2356-
new_args.append(None)
2357-
else:
2358-
new_args.append(arg)
2359-
# Don't show errors after we have only used the outer context for inference.
2360-
# We will use argument context to infer more variables.
2361-
outer_callee = self.apply_generic_arguments(
2362-
callee_type, new_args, context, skip_unsatisfied=True
2363-
)
2364-
outer_ret_type = get_proper_type(outer_callee.ret_type)
2319+
if True:
2320+
outer_callee = self.apply_generic_arguments(
2321+
callee_type, outer_solution[0], context, skip_unsatisfied=True
2322+
)
2323+
outer_ret_type = get_proper_type(outer_callee.ret_type)
23652324

2366-
# Only substitute non-Uninhabited and non-erased types.
2367-
new_args: list[Type | None] = []
2368-
for arg in target_solution[0]:
2369-
if has_uninhabited_component(arg) or has_erased_component(arg):
2370-
new_args.append(None)
2371-
else:
2372-
new_args.append(arg)
2373-
# Don't show errors after we have only used the outer context for inference.
2374-
# We will use argument context to infer more variables.
2375-
target_callee = self.apply_generic_arguments(
2376-
callee_type, new_args, context, skip_unsatisfied=True
2377-
)
2378-
target_ret_type = get_proper_type(target_callee.ret_type)
2325+
target_callee = self.apply_generic_arguments(
2326+
callee_type, target_solution[0], context, skip_unsatisfied=True
2327+
)
2328+
target_ret_type = get_proper_type(target_callee.ret_type)
2329+
else:
2330+
# Only substitute non-Uninhabited and non-erased types.
2331+
new_args: list[Type | None] = []
2332+
for arg in outer_solution[0]:
2333+
if has_uninhabited_component(arg) or has_erased_component(arg):
2334+
new_args.append(None)
2335+
else:
2336+
new_args.append(arg)
2337+
# Don't show errors after we have only used the outer context for inference.
2338+
# We will use argument context to infer more variables.
2339+
outer_callee = self.apply_generic_arguments(
2340+
callee_type, new_args, context, skip_unsatisfied=True
2341+
)
2342+
outer_ret_type = get_proper_type(outer_callee.ret_type)
2343+
2344+
# Only substitute non-Uninhabited and non-erased types.
2345+
new_args: list[Type | None] = []
2346+
for arg in target_solution[0]:
2347+
if has_uninhabited_component(arg) or has_erased_component(arg):
2348+
new_args.append(None)
2349+
else:
2350+
new_args.append(arg)
2351+
# Don't show errors after we have only used the outer context for inference.
2352+
# We will use argument context to infer more variables.
2353+
target_callee = self.apply_generic_arguments(
2354+
callee_type, new_args, context, skip_unsatisfied=True
2355+
)
2356+
target_ret_type = get_proper_type(target_callee.ret_type)
23792357

23802358
use_joint = True
23812359
use_outer = True
23822360
use_inner = True
23832361
# check if we can use the joint solution, otherwise fallback to outer_solution
2384-
for outer_tp, inner_tp, joint_tp in zip(
2385-
outer_solution[0], inner_solution[0], target_solution[0]
2386-
):
2387-
if joint_tp is None and outer_tp is not None:
2388-
use_joint = False
2389-
if has_erased_component(joint_tp) and not has_erased_component(inner_tp):
2390-
# If the joint solution is erased, but outer is not, we use outer.
2391-
use_joint = False
2392-
if has_erased_component(outer_tp) and not has_erased_component(inner_tp):
2393-
use_outer = False
2394-
if has_erased_component(inner_tp):
2395-
use_inner = False
2396-
2397-
combined_solution = []
2398-
for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]):
2399-
if (
2400-
outer_tp is not None
2401-
and joint_tp is not None
2402-
and is_proper_subtype(outer_tp, joint_tp)
2403-
):
2404-
# If outer is a subtype of joint, we can use joint.
2405-
combined_solution.append(outer_tp)
2406-
else:
2407-
# Otherwise, we use joint.
2408-
combined_solution.append(joint_tp)
2409-
2410-
# new_vars, new_constraints = self.intersect_solutions(outer_solution[0], target_solution[0])
2411-
# intersected_solution = solve_constraints(
2412-
# new_vars,
2413-
# new_constraints,
2414-
# strict=self.chk.in_checked_function(),
2415-
# allow_polymorphic=False,
2416-
# )
2417-
2418-
# if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2419-
# if all(
2420-
# (joint_tp is None and outer_tp is None)
2421-
# or (
2422-
# (joint_tp is not None and outer_tp is not None)
2423-
# and (
2424-
# is_subtype(outer_tp, joint_tp)
2425-
# or (
2426-
# isinstance(outer_tp, UnionType)
2427-
# and any(is_subtype(val, joint_tp) for val in outer_tp.items)
2428-
# )
2429-
# )
2430-
# )
2431-
# for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0])
2362+
# for outer_tp, inner_tp, joint_tp in zip(
2363+
# outer_solution[0], inner_solution[0], target_solution[0]
24322364
# ):
2433-
# use_joint = False
2434-
# use_outer = True
2365+
# if joint_tp is None and outer_tp is not None:
2366+
# use_joint = False
2367+
# if has_erased_component(joint_tp) and not has_erased_component(inner_tp):
2368+
# # If the joint solution is erased, but outer is not, we use outer.
2369+
# use_joint = False
2370+
# if has_erased_component(outer_tp) and not has_erased_component(inner_tp):
2371+
# use_outer = False
2372+
# if has_erased_component(inner_tp):
2373+
# use_inner = False
2374+
2375+
if any(tp is None for tp in inner_solution[0]):
2376+
use_inner = False
2377+
if any(tp is None for tp in outer_solution[0]):
2378+
use_outer = False
2379+
if any(tp is None for tp in joint_solution[0]):
2380+
use_joint = False
24352381

2436-
# if the outer solution is more concrete than the joint solution, use the outer solution (2 step)
2437-
if is_subtype(outer_ret_type, target_ret_type) or (
2438-
isinstance(outer_ret_type, UnionType)
2439-
and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items)
2382+
if (
2383+
# if the joint failed to solve use the outer solution instead.
2384+
# any(joint_tp is None and outer_tp is not None for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0]))
2385+
# If the outer solution is more concrete than the joint solution, use the outer solution.
2386+
# This also applies if the outer solution is a union type where at least one member
2387+
# is a subtype of the target return type.
2388+
is_subtype(outer_ret_type, target_ret_type)
2389+
or (
2390+
isinstance(outer_ret_type, UnionType)
2391+
and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items)
2392+
)
24402393
):
24412394
use_joint = False
2442-
use_outer = True
2395+
# use_outer = True
2396+
2397+
if use_joint:
2398+
target_solution = reverse_joint_solution
2399+
elif use_outer:
2400+
target_solution = outer_solution
2401+
elif use_inner:
2402+
target_solution = inner_solution
2403+
else:
2404+
raise AssertionError
24432405

24442406
# what if the outer context is a union type?
24452407
# we may have a case like:
@@ -2450,17 +2412,12 @@ def infer_function_type_arguments(
24502412
_num = arg_pass_nums
24512413
_c0 = constraints
24522414
_c1 = extra_constraints
2453-
_c2 = erased_constraints
24542415

24552416
_x0 = _outer_solution[0]
24562417
_x2 = _inner_solution[0]
24572418
_x3 = _joint_solution[0]
24582419
_x4 = _reverse_joint_solution[0]
24592420

2460-
_e0 = _erased_outer_solution[0]
2461-
_e2 = _erased_joint_solution[0]
2462-
_e3 = _erased_reverse_joint_solution[0]
2463-
24642421
_s1 = outer_solution[0]
24652422
_s2 = inner_solution[0]
24662423
_s3 = joint_solution[0]
@@ -2471,7 +2428,7 @@ def infer_function_type_arguments(
24712428
_r2 = target_ret_type
24722429
_y1 = outer_callee
24732430
_y2 = target_callee
2474-
_u0 = use_inner, use_outer, use_joint
2431+
_u0 = use_outer, use_joint
24752432

24762433
if use_joint:
24772434
new_inferred_args = target_solution[0]
@@ -2483,7 +2440,7 @@ def infer_function_type_arguments(
24832440
# ]
24842441
elif use_outer:
24852442
# If we cannot use the joint solution, fallback to outer_solution
2486-
new_inferred_args = outer_solution[0]
2443+
new_inferred_args = target_solution[0]
24872444

24882445
# Only substitute non-Uninhabited and non-erased types.
24892446
new_args: list[Type | None] = []

test-data/unit/check-functions.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ...
33843384
def g(x: T, y: S) -> Union[T, S]: ...
33853385

33863386
x = [f, g]
3387-
reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`14, y: S`15) -> Union[T`14, S`15]]"
3387+
reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`12, y: S`13) -> Union[T`12, S`13]]"
33883388
[builtins fixtures/list.pyi]
33893389

33903390
[case testTypeVariableClashErrorMessage]

test-data/unit/check-generics.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]:
29292929
def id(__x: U) -> U:
29302930
...
29312931
fs = [id, id, id]
2932-
reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`31) -> builtins.list[S`31]"
2933-
reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`33) -> builtins.list[S`33]"
2932+
reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`27) -> builtins.list[S`27]"
2933+
reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`29) -> builtins.list[S`29]"
29342934
[builtins fixtures/list.pyi]
29352935

29362936
[case testInferenceAgainstGenericCurry]

0 commit comments

Comments
 (0)