diff --git a/mypy/stubtest.py b/mypy/stubtest.py index db902bae08c9..482a14984950 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -506,13 +506,16 @@ def _is_disjoint_base(typ: type[object]) -> bool: def _verify_disjoint_base( stub: nodes.TypeInfo, runtime: type[object], object_path: list[str] ) -> Iterator[Error]: - # If it's final, doesn't matter whether it's a disjoint base or not - if stub.is_final: - return is_disjoint_runtime = _is_disjoint_base(runtime) # Don't complain about missing @disjoint_base if there are __slots__, because # in that case we can infer that it's a disjoint base. - if is_disjoint_runtime and not stub.is_disjoint_base and not runtime.__dict__.get("__slots__"): + if ( + is_disjoint_runtime + and not stub.is_disjoint_base + and not runtime.__dict__.get("__slots__") + and not stub.is_final + and not (stub.is_enum and stub.enum_members) + ): yield Error( object_path, "is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub", @@ -520,14 +523,40 @@ def _verify_disjoint_base( runtime, stub_desc=repr(stub), ) - elif not is_disjoint_runtime and stub.is_disjoint_base: - yield Error( - object_path, - "is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime", - stub, - runtime, - stub_desc=repr(stub), - ) + elif stub.is_disjoint_base: + if not is_disjoint_runtime: + yield Error( + object_path, + "is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime", + stub, + runtime, + stub_desc=repr(stub), + ) + if runtime.__dict__.get("__slots__"): + yield Error( + object_path, + "is marked as @disjoint_base, but also has slots; add __slots__ instead", + stub, + runtime, + stub_desc=repr(stub), + ) + elif stub.is_final: + yield Error( + object_path, + "is marked as @disjoint_base, but also marked as @final; remove @disjoint_base", + stub, + runtime, + stub_desc=repr(stub), + ) + elif stub.is_enum and stub.enum_members: + yield Error( + object_path, + "is marked as @disjoint_base, but is an enum with members, which is implicitly final; " + "remove @disjoint_base", + stub, + runtime, + stub_desc=repr(stub), + ) def _verify_metaclass( diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index c9404e206e4f..2bf071d34d48 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -1407,14 +1407,9 @@ def spam(x=Flags4(0)): pass stub=""" import sys from typing import Final, Literal - from typing_extensions import disjoint_base - if sys.version_info >= (3, 12): - class BytesEnum(bytes, enum.Enum): - a = b'foo' - else: - @disjoint_base - class BytesEnum(bytes, enum.Enum): - a = b'foo' + class BytesEnum(bytes, enum.Enum): + a = b'foo' + FOO: Literal[BytesEnum.a] BAR: Final = BytesEnum.a BAZ: BytesEnum @@ -1698,6 +1693,53 @@ def __next__(self) -> object: ... """, error=None, ) + yield Case( + runtime=""" + class IsDisjointBaseBecauseItHasSlots: + __slots__ = ("a",) + a: int + """, + stub=""" + from typing_extensions import disjoint_base + + @disjoint_base + class IsDisjointBaseBecauseItHasSlots: + a: int + """, + error="test_module.IsDisjointBaseBecauseItHasSlots", + ) + yield Case( + runtime=""" + class IsFinalSoDisjointBaseIsRedundant: ... + """, + stub=""" + from typing_extensions import disjoint_base, final + + @final + @disjoint_base + class IsFinalSoDisjointBaseIsRedundant: ... + """, + error="test_module.IsFinalSoDisjointBaseIsRedundant", + ) + yield Case( + runtime=""" + import enum + + class IsEnumWithMembersSoDisjointBaseIsRedundant(enum.Enum): + A = 1 + B = 2 + """, + stub=""" + from typing_extensions import disjoint_base + import enum + + @disjoint_base + class IsEnumWithMembersSoDisjointBaseIsRedundant(enum.Enum): + A = 1 + B = 2 + """, + error="test_module.IsEnumWithMembersSoDisjointBaseIsRedundant", + ) @collect_cases def test_has_runtime_final_decorator(self) -> Iterator[Case]: