Skip to content

Commit b1d5b92

Browse files
committed
Fix union inference of generic class and its generic type
1 parent 77cfb98 commit b1d5b92

File tree

2 files changed

+38
-8
lines changed

2 files changed

+38
-8
lines changed

mypy/constraints.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,24 @@ def _infer_constraints(
390390
# When the template is a union, we are okay with leaving some
391391
# type variables indeterminate. This helps with some special
392392
# cases, though this isn't very principled.
393+
394+
def _is_item_being_overlaped_by_other(item: Type) -> bool:
395+
# It returns true if the item is an argument of other item
396+
# that is subtype of the actual type
397+
return any(
398+
isinstance(p_type := get_proper_type(item_to_compare), Instance)
399+
and mypy.subtypes.is_subtype(actual, erase_typevars(p_type))
400+
and item in p_type.args
401+
for item_to_compare in template.items
402+
if item is not item_to_compare
403+
)
404+
393405
result = any_constraints(
394406
[
395407
infer_constraints_if_possible(t_item, actual, direction)
396-
for t_item in template.items
408+
for t_item in [
409+
item for item in template.items if not _is_item_being_overlaped_by_other(item)
410+
]
397411
],
398412
eager=False,
399413
)

test-data/unit/check-inference.test

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -873,13 +873,7 @@ def g(x: Union[T, List[T]]) -> List[T]: pass
873873
def h(x: List[str]) -> None: pass
874874
g('a')() # E: "List[str]" not callable
875875

876-
# The next line is a case where there are multiple ways to satisfy a constraint
877-
# involving a Union. Either T = List[str] or T = str would turn out to be valid,
878-
# but mypy doesn't know how to branch on these two options (and potentially have
879-
# to backtrack later) and defaults to T = Never. The result is an
880-
# awkward error message. Either a better error message, or simply accepting the
881-
# call, would be preferable here.
882-
g(['a']) # E: Argument 1 to "g" has incompatible type "List[str]"; expected "List[Never]"
876+
g(['a'])
883877

884878
h(g(['a']))
885879

@@ -891,6 +885,28 @@ i(b, a, b)
891885
i(a, b, b) # E: Argument 1 to "i" has incompatible type "List[int]"; expected "List[str]"
892886
[builtins fixtures/list.pyi]
893887

888+
[case testUnionInferenceOfGenericClassAndItsGenericType]
889+
from typing import Generic, TypeVar, Union
890+
891+
T = TypeVar('T')
892+
893+
class GenericClass(Generic[T]):
894+
def __init__(self, value: T) -> None:
895+
self.value = value
896+
897+
def method_with_union(arg: Union[GenericClass[T], T]) -> GenericClass[T]:
898+
if not isinstance(arg, GenericClass):
899+
arg = GenericClass(arg)
900+
return arg
901+
902+
result_1 = method_with_union(GenericClass("test"))
903+
reveal_type(result_1) # N: Revealed type is "__main__.GenericClass[builtins.str]"
904+
905+
result_2 = method_with_union("test")
906+
reveal_type(result_2) # N: Revealed type is "__main__.GenericClass[builtins.str]"
907+
908+
[builtins fixtures/isinstance.pyi]
909+
894910
[case testCallableListJoinInference]
895911
from typing import Any, Callable
896912

0 commit comments

Comments
 (0)