Skip to content

Commit 412f63b

Browse files
committed
Apply union expansion when checking ops to typevars
1 parent 70d7881 commit 412f63b

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

mypy/checkexpr.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4155,10 +4155,9 @@ def check_op(
41554155
"""
41564156

41574157
if allow_reverse:
4158-
left_variants = [base_type]
4158+
left_variants = self._union_items_from_typevar(base_type)
41594159
base_type = get_proper_type(base_type)
4160-
if isinstance(base_type, UnionType):
4161-
left_variants = list(flatten_nested_unions(base_type.relevant_items()))
4160+
41624161
right_type = self.accept(arg)
41634162

41644163
# Step 1: We first try leaving the right arguments alone and destructure
@@ -4196,13 +4195,18 @@ def check_op(
41964195
# We don't do the same for the base expression because it could lead to weird
41974196
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
41984197
# TODO: Can we use `type_overrides_set()` here?
4199-
right_variants = [(right_type, arg)]
4200-
right_type = get_proper_type(right_type)
4201-
if isinstance(right_type, UnionType):
4198+
right_variants: list[tuple[Type, Expression]]
4199+
if isinstance(right_type, ProperType) and isinstance(
4200+
right_type, (UnionType, TypeVarType)
4201+
):
42024202
right_variants = [
42034203
(item, TempNode(item, context=context))
4204-
for item in flatten_nested_unions(right_type.relevant_items())
4204+
for item in self._union_items_from_typevar(right_type)
42054205
]
4206+
else:
4207+
# Preserve argument identity if we do not intend to modify it
4208+
right_variants = [(right_type, arg)]
4209+
right_type = get_proper_type(right_type)
42064210

42074211
all_results = []
42084212
all_inferred = []
@@ -4252,6 +4256,19 @@ def check_op(
42524256
context=context,
42534257
)
42544258

4259+
def _union_items_from_typevar(self, typ: Type) -> list[Type]:
4260+
variants = [typ]
4261+
typ = get_proper_type(typ)
4262+
base_type = typ
4263+
if unwrapped := (isinstance(typ, TypeVarType) and not typ.values):
4264+
typ = get_proper_type(typ.upper_bound)
4265+
if isinstance(typ, UnionType):
4266+
variants = list(flatten_nested_unions(typ.relevant_items()))
4267+
if unwrapped:
4268+
assert isinstance(base_type, TypeVarType)
4269+
variants = [base_type.copy_modified(upper_bound=item) for item in variants]
4270+
return variants
4271+
42554272
def check_boolean_op(self, e: OpExpr, context: Context) -> Type:
42564273
"""Type check a boolean operation ('and' or 'or')."""
42574274

test-data/unit/check-expressions.test

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,32 @@ if int():
706706
class C:
707707
def __lt__(self, o: object, x: str = "") -> int: ...
708708

709+
[case testReversibleOpOnTypeVarBound]
710+
from typing import TypeVar, Union
711+
712+
class A:
713+
def __lt__(self, a: A) -> bool: ...
714+
def __gt__(self, a: A) -> bool: ...
715+
716+
class B(A):
717+
def __lt__(self, b: B) -> bool: ... # type: ignore[override]
718+
def __gt__(self, b: B) -> bool: ... # type: ignore[override]
719+
720+
_T = TypeVar("_T", bound=Union[A, B])
721+
722+
def check(x: _T, y: _T) -> bool:
723+
return x < y
724+
725+
[case testReversibleOpOnTypeVarBoundPromotion]
726+
from typing import TypeVar, Union
727+
728+
_T = TypeVar("_T", bound=Union[int, float])
729+
730+
def check(x: _T, y: _T) -> bool:
731+
return x < y
732+
[builtins fixtures/ops.pyi]
733+
734+
709735
[case testErrorContextAndBinaryOperators]
710736
import typing
711737
class A:

test-data/unit/fixtures/ops.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ class float:
6161
def __rdiv__(self, x: 'float') -> 'float': pass
6262
def __truediv__(self, x: 'float') -> 'float': pass
6363
def __rtruediv__(self, x: 'float') -> 'float': pass
64+
def __eq__(self, x: object) -> bool: pass
65+
def __ne__(self, x: object) -> bool: pass
66+
def __lt__(self, x: 'float') -> bool: pass
67+
def __le__(self, x: 'float') -> bool: pass
68+
def __gt__(self, x: 'float') -> bool: pass
69+
def __ge__(self, x: 'float') -> bool: pass
6470

6571
class complex:
6672
def __add__(self, x: complex) -> complex: pass

0 commit comments

Comments
 (0)