Skip to content

Commit 7cd577d

Browse files
committed
refactor(mypy): improve type inference for Union type comparisons
- Add special handling for Union[T, X] vs Union[Y, Z] case - Implement more precise subtype checking for Union types - Enhance type variable inference in Union type comparisons
1 parent 13fa6c3 commit 7cd577d

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

mypy/constraints.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,40 @@ def _infer_constraints(
383383
res.extend(infer_constraints(t_item, actual, direction))
384384
return res
385385
if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
386+
# Special handling for Union[T, X] vs Union[Y, Z] case
387+
template_proper = get_proper_type(orig_template)
388+
if isinstance(template_proper, UnionType) and len(template_proper.items) == 2:
389+
type_var_items = []
390+
non_type_var_items = []
391+
392+
for t_item in template_proper.items:
393+
t_item_proper = get_proper_type(t_item)
394+
if isinstance(t_item_proper, TypeVarType):
395+
type_var_items.append(t_item_proper)
396+
else:
397+
non_type_var_items.append(t_item_proper)
398+
399+
if len(type_var_items) == 1 and len(non_type_var_items) == 1:
400+
# This is Union[T, X] vs Union[Y, Z] case
401+
type_var = type_var_items[0]
402+
non_type_var = non_type_var_items[0]
403+
404+
# Check if any actual items are NOT subtypes of the non-type-var part
405+
compatible_items = []
406+
actual_proper = get_proper_type(actual)
407+
if isinstance(actual_proper, UnionType):
408+
for actual_item in actual_proper.items:
409+
if not mypy.subtypes.is_subtype(actual_item, non_type_var):
410+
compatible_items.append(actual_item)
411+
412+
# If we have compatible items, create constraint for the type variable
413+
if compatible_items:
414+
if len(compatible_items) == 1:
415+
return [Constraint(type_var, SUBTYPE_OF, compatible_items[0])]
416+
else:
417+
union_type = UnionType.make_union(compatible_items)
418+
return [Constraint(type_var, SUBTYPE_OF, union_type)]
419+
386420
res = []
387421
for a_item in actual.items:
388422
# `orig_template` has to be preserved intact in case it's recursive.
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
[case testUnionTypeVarInferenceBasic]
2+
from typing import TypeVar, Union
3+
4+
class A: pass
5+
class B: pass
6+
class C: pass
7+
8+
T = TypeVar('T')
9+
10+
def foo(x: Union[T, A]) -> T: ...
11+
12+
obj: Union[B, C]
13+
reveal_type(foo(obj)) # N: Revealed type is "__main__.B | __main__.C"
14+
15+
[builtins fixtures/tuple.pyi]
16+
17+
[case testUnionTypeVarInferenceSingle]
18+
from typing import TypeVar, Union
19+
20+
class A: pass
21+
class B: pass
22+
23+
T = TypeVar('T')
24+
25+
def foo(x: Union[T, A]) -> T: ...
26+
27+
obj: B
28+
reveal_type(foo(obj)) # N: Revealed type is "__main__.B"
29+
30+
[builtins fixtures/tuple.pyi]
31+
32+
[case testUnionTypeVarInferenceThreeWay]
33+
from typing import TypeVar, Union
34+
35+
class A: pass
36+
class B: pass
37+
class C: pass
38+
class D: pass
39+
40+
T = TypeVar('T')
41+
42+
def foo(x: Union[T, A]) -> T: ...
43+
44+
obj: Union[B, C, D]
45+
reveal_type(foo(obj)) # N: Revealed type is "__main__.B | __main__.C | __main__.D"
46+
47+
[builtins fixtures/tuple.pyi]
48+
49+
[case testUnionTypeVarInferenceOverlapping]
50+
from typing import TypeVar, Union
51+
52+
class A: pass
53+
class B: pass
54+
55+
T = TypeVar('T')
56+
57+
def foo(x: Union[T, A]) -> T: ...
58+
59+
obj: Union[A, B]
60+
reveal_type(foo(obj)) # N: Revealed type is "__main__.A | __main__.B"
61+
62+
[builtins fixtures/tuple.pyi]
63+
64+
[case testUnionTypeVarInferenceJustA]
65+
from typing import TypeVar, Union
66+
67+
class A: pass
68+
69+
T = TypeVar('T')
70+
71+
def foo(x: Union[T, A]) -> T: ...
72+
73+
obj: A
74+
reveal_type(foo(obj)) # N: Revealed type is "__main__.A"
75+
76+
[builtins fixtures/tuple.pyi]
77+
78+
[case testUnionTypeVarInferenceComplex]
79+
from typing import TypeVar, Union
80+
from dataclasses import dataclass
81+
import pathlib
82+
83+
class Cancelled: pass
84+
85+
T = TypeVar('T')
86+
87+
@dataclass
88+
class CreateProject:
89+
jsonFilePath: pathlib.Path
90+
91+
@dataclass
92+
class LoadProject:
93+
jsonFilePath: pathlib.Path
94+
95+
@dataclass
96+
class MigrateProject:
97+
oldJsonFilePath: pathlib.Path
98+
newProjectFolderPath: pathlib.Path
99+
100+
Project = Union[CreateProject, LoadProject, MigrateProject]
101+
102+
def getProject() -> Union[Project, Cancelled]: ...
103+
104+
def value(maybeCancelled: Union[T, Cancelled]) -> T: ...
105+
106+
def main() -> None:
107+
maybeCancelled = getProject()
108+
project: Project = reveal_type(value(maybeCancelled)) # N: Revealed type is "__main__.CreateProject | __main__.LoadProject | __main__.MigrateProject"
109+
110+
[builtins fixtures/tuple.pyi]
111+
[typing fixtures/typing-medium.pyi]
112+
113+
[case testUnionTypeVarInferenceComplex]
114+
from typing import TypeVar, Union
115+
import pathlib
116+
117+
class Cancelled: pass
118+
119+
T = TypeVar('T')
120+
121+
@dataclass
122+
class CreateProject:
123+
jsonFilePath: pathlib.Path
124+
125+
@dataclass
126+
class LoadProject:
127+
jsonFilePath: pathlib.Path
128+
129+
@dataclass
130+
class MigrateProject:
131+
oldJsonFilePath: pathlib.Path
132+
newProjectFolderPath: pathlib.Path
133+
134+
Project = Union[CreateProject, LoadProject, MigrateProject]
135+
136+
def getProject() -> Union[Project, Cancelled]: ...
137+
138+
def value(maybeCancelled: Union[T, Cancelled]) -> T: ...
139+
140+
def main() -> None:
141+
maybeCancelled = getProject()
142+
project: Project = reveal_type(value(maybeCancelled)) # N: Revealed type is "__main__.CreateProject | __main__.LoadProject | __main__.MigrateProject"
143+
144+
[builtins fixtures/tuple.pyi]
145+
[typing fixtures/typing-medium.pyi]

0 commit comments

Comments
 (0)