diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b8f9bf087467..d9e0d4364916 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -113,6 +113,7 @@ from mypy.semanal_enum import ENUM_BASES from mypy.state import state from mypy.subtypes import ( + covers_at_runtime, find_member, is_equivalent, is_same_type, @@ -4037,14 +4038,21 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None: variants_raw = [(op_name, left_op, left_type, right_expr)] elif ( - is_subtype(right_type, left_type) - and isinstance(left_type, Instance) - and isinstance(right_type, Instance) - and not ( - left_type.type.alt_promote is not None - and left_type.type.alt_promote.type is right_type.type + ( + # Checking (A implies B) using the logically equivalent (not A or B), where + # A: left and right are both `Instance` objects + # B: right's __rop__ method is different from left's __op__ method + not (isinstance(left_type, Instance) and isinstance(right_type, Instance)) + or ( + lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name) + and ( + left_type.type.alt_promote is None + or left_type.type.alt_promote.type is not right_type.type + ) + ) ) - and lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name) + # Note: use `covers_at_runtime` instead of `is_subtype` (#19006) + and covers_at_runtime(right_type, left_type) ): # When we do "A() + B()" where B is a subclass of A, we'll actually try calling # B's __radd__ method first, but ONLY if B explicitly defines or overrides the diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index ea6eac9a39b3..85cd18050a34 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -681,6 +681,29 @@ class B: s: str s = A() + B() # E: Unsupported operand types for + ("A" and "B") + +[case testReverseBinaryOperator4] +from typing import assert_type, Never + +class Size(tuple[int, ...]): + def __add__(self, other: tuple[int, ...], /) -> "Size": return Size() # type: ignore[override] + def __radd__(self, other: tuple[int, ...], /) -> "Size": return Size() + +size: Size = Size([3, 4]) +tup0: tuple[()] = () +tup1: tuple[int] = (1,) +tup2: tuple[int, int] = (1, 2) +tupN: tuple[int, ...] = (1, 2, 3) +tupX: tuple[Never, ...] = () + +assert_type(tup0 + size, Size) +assert_type(tup1 + size, Size) +assert_type(tup2 + size, Size) +assert_type(tupN + size, Size) +assert_type(tupX + size, Size) + +[builtins fixtures/tuple-typeshed.pyi] + [case testBinaryOperatorWithAnyRightOperand] from typing import Any, cast class A: pass diff --git a/test-data/unit/fixtures/tuple-typeshed.pyi b/test-data/unit/fixtures/tuple-typeshed.pyi new file mode 100644 index 000000000000..57a1a29a87ad --- /dev/null +++ b/test-data/unit/fixtures/tuple-typeshed.pyi @@ -0,0 +1,59 @@ +# tuple definition from typeshed, +from typing import ( + Generic, + Sequence, + TypeVar, + Iterable, + Iterator, + Any, + overload, + Self, + Protocol, +) +from types import GenericAlias + +_T = TypeVar("_T") +_T_co = TypeVar('_T_co', covariant=True) + +class tuple(Sequence[_T_co], Generic[_T_co]): + def __new__(cls, iterable: Iterable[_T_co] = ..., /) -> Self: ... + def __len__(self) -> int: ... + def __contains__(self, key: object, /) -> bool: ... + @overload + def __getitem__(self, key: SupportsIndex, /) -> _T_co: ... + @overload + def __getitem__(self, key: slice, /) -> tuple[_T_co, ...]: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __lt__(self, value: tuple[_T_co, ...], /) -> bool: ... + def __le__(self, value: tuple[_T_co, ...], /) -> bool: ... + def __gt__(self, value: tuple[_T_co, ...], /) -> bool: ... + def __ge__(self, value: tuple[_T_co, ...], /) -> bool: ... + def __eq__(self, value: object, /) -> bool: ... + def __hash__(self) -> int: ... + @overload + def __add__(self, value: tuple[_T_co, ...], /) -> tuple[_T_co, ...]: ... + @overload + def __add__(self, value: tuple[_T, ...], /) -> tuple[_T_co | _T, ...]: ... + def __mul__(self, value: SupportsIndex, /) -> tuple[_T_co, ...]: ... + def __rmul__(self, value: SupportsIndex, /) -> tuple[_T_co, ...]: ... + def count(self, value: Any, /) -> int: ... + def index(self, value: Any, start: SupportsIndex = ..., stop: SupportsIndex = ..., /) -> int: ... + def __class_getitem__(cls, item: Any, /) -> GenericAlias: ... + +class dict: pass +class int: pass +class slice: pass +class bool(int): pass +class str: pass # For convenience +class object: pass +class type: pass +class ellipsis: pass +class SupportsIndex(Protocol): + def __index__(self) -> int: pass +class list(Sequence[_T], Generic[_T]): + @overload + def __getitem__(self, i: int) -> _T: ... + @overload + def __getitem__(self, s: slice) -> list[_T]: ... + def __contains__(self, item: object) -> bool: ... + def __iter__(self) -> Iterator[_T]: ...