diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 88b3005b1376..7a29f0054047 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4054,7 +4054,21 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None: # We store the determined order inside the 'variants_raw' variable, # which records tuples containing the method, base type, and the argument. - if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type): + if ( + op_name in operators.op_methods_that_shortcut + and is_same_type(left_type, right_type) + and not ( + # We consider typevars with equal IDs "same types" even if some narrowing + # has been applied. However, different bounds here might come from union + # expansion applied earlier, so we are not supposed to check them as + # being same types here. For plain union items `is_same_type` will + # return false, but not for typevars having these items as bounds. + # See testReversibleOpOnTypeVarProtocol. + isinstance(left_type, TypeVarType) + and isinstance(right_type, TypeVarType) + and not is_same_type(left_type.upper_bound, right_type.upper_bound) + ) + ): # When we do "A() + A()", for example, Python will only call the __add__ method, # never the __radd__ method. # @@ -4167,10 +4181,9 @@ def check_op( """ if allow_reverse: - left_variants = [base_type] + left_variants = self._union_items_from_typevar(base_type) base_type = get_proper_type(base_type) - if isinstance(base_type, UnionType): - left_variants = list(flatten_nested_unions(base_type.relevant_items())) + right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -4208,13 +4221,17 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants = [(right_type, arg)] - right_type = get_proper_type(right_type) - if isinstance(right_type, UnionType): + right_variants: list[tuple[Type, Expression]] + p_right = get_proper_type(right_type) + if isinstance(p_right, (UnionType, TypeVarType)): right_variants = [ (item, TempNode(item, context=context)) - for item in flatten_nested_unions(right_type.relevant_items()) + for item in self._union_items_from_typevar(right_type) ] + else: + # Preserve argument identity if we do not intend to modify it + right_variants = [(right_type, arg)] + right_type = p_right all_results = [] all_inferred = [] @@ -4264,6 +4281,20 @@ def check_op( context=context, ) + def _union_items_from_typevar(self, typ: Type) -> list[Type]: + variants = [typ] + typ = get_proper_type(typ) + base_type = typ + if unwrapped := (isinstance(typ, TypeVarType) and not typ.values): + typ = get_proper_type(typ.upper_bound) + if is_union := isinstance(typ, UnionType): + variants = list(flatten_nested_unions(typ.relevant_items())) + if is_union and unwrapped: + # If not a union, keep the original type + assert isinstance(base_type, TypeVarType) + variants = [base_type.copy_modified(upper_bound=item) for item in variants] + return variants + def check_boolean_op(self, e: OpExpr) -> Type: """Type check a boolean operation ('and' or 'or').""" diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..955a4d273e19 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -273,6 +273,10 @@ def is_same_type( ): return all(is_same_type(x, y) for x, y in zip(a.args, b.args)) elif isinstance(a, TypeVarType) and isinstance(b, TypeVarType) and a.id == b.id: + # This is not only a performance optimization. Deeper check will compare upper + # bounds, but we want to consider copies of the same type variable "same type". + # This makes sense semantically: even we have narrowed the upper bound somehow, + # it's still the same object it used to be before. return True # Note that using ignore_promotions=True (default) makes types like int and int64 diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..fb92305a9818 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -706,6 +706,86 @@ if int(): class C: def __lt__(self, o: object, x: str = "") -> int: ... +[case testReversibleOpOnTypeVarBound] +from typing import TypeVar, Union + +class A: + def __lt__(self, a: A) -> bool: ... + def __gt__(self, a: A) -> bool: ... + +class B(A): + def __lt__(self, b: B) -> bool: ... # type: ignore[override] + def __gt__(self, b: B) -> bool: ... # type: ignore[override] + +_T = TypeVar("_T", bound=Union[A, B]) + +def check(x: _T, y: _T) -> bool: + return x < y + +[case testReversibleOpOnTypeVarBoundPromotion] +from typing import TypeVar, Union + +_T = TypeVar("_T", bound=Union[int, float]) + +def check(x: _T, y: _T) -> bool: + return x < y +[builtins fixtures/ops.pyi] + +[case testReversibleOpOnTypeVarProtocol] +# https://github.com/python/mypy/issues/18203 +from typing import Protocol, TypeVar, Union +from typing_extensions import Self, runtime_checkable + +class A(Protocol): + def __add__(self, other: Union[int, Self]) -> Self: ... + def __radd__(self, other: Union[int, Self]) -> Self: ... + +AT = TypeVar("AT", bound=Union[int, A]) + +def f(a: AT, b: AT) -> None: + reveal_type(a + a) # N: Revealed type is "Union[builtins.int, AT`-1]" + reveal_type(a + b) # N: Revealed type is "Union[builtins.int, AT`-1]" + if isinstance(a, int): + reveal_type(a) # N: Revealed type is "AT`-1" + reveal_type(a + a) # N: Revealed type is "builtins.int" + reveal_type(a + b) # N: Revealed type is "Union[builtins.int, AT`-1]" + reveal_type(b + a) # N: Revealed type is "Union[builtins.int, AT`-1]" + +@runtime_checkable +class B(Protocol): + def __radd__(self, other: Union[int, Self]) -> Self: ... + +BT = TypeVar("BT", bound=Union[int, B]) + +def g(a: BT, b: BT) -> None: + reveal_type(a + a) # E: Unsupported left operand type for + ("BT") \ + # N: Both left and right operands are unions \ + # N: Revealed type is "Union[builtins.int, BT`-1, Any]" + reveal_type(a + b) # E: Unsupported left operand type for + ("BT") \ + # N: Both left and right operands are unions \ + # N: Revealed type is "Union[builtins.int, BT`-1, Any]" + if isinstance(a, int): + reveal_type(a) # N: Revealed type is "BT`-1" + reveal_type(0 + a) # N: Revealed type is "builtins.int" + reveal_type(a + 0) # N: Revealed type is "builtins.int" + reveal_type(a + a) # N: Revealed type is "builtins.int" + reveal_type(a + b) # N: Revealed type is "Union[builtins.int, BT`-1]" + reveal_type(b + a) # E: Unsupported left operand type for + ("BT") \ + # N: Left operand is of type "BT" \ + # N: Revealed type is "Union[builtins.int, Any]" + if isinstance(a, B): + reveal_type(a) # N: Revealed type is "BT`-1" + reveal_type(0 + a) # N: Revealed type is "BT`-1" + reveal_type(a + 0) # E: Unsupported left operand type for + ("BT") \ + # N: Revealed type is "Any" + reveal_type(a + a) # E: Unsupported left operand type for + ("BT") \ + # N: Revealed type is "Any" + reveal_type(a + b) # E: Unsupported left operand type for + ("BT") \ + # N: Right operand is of type "BT" \ + # N: Revealed type is "Any" +[builtins fixtures/isinstance.pyi] + + [case testErrorContextAndBinaryOperators] import typing class A: diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 67bc74b35c51..34e512b34984 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -61,6 +61,12 @@ class float: def __rdiv__(self, x: 'float') -> 'float': pass def __truediv__(self, x: 'float') -> 'float': pass def __rtruediv__(self, x: 'float') -> 'float': pass + def __eq__(self, x: object) -> bool: pass + def __ne__(self, x: object) -> bool: pass + def __lt__(self, x: 'float') -> bool: pass + def __le__(self, x: 'float') -> bool: pass + def __gt__(self, x: 'float') -> bool: pass + def __ge__(self, x: 'float') -> bool: pass class complex: def __add__(self, x: complex) -> complex: pass