diff --git a/mypy/checker.py b/mypy/checker.py index 47c72924bf3c..0fe77e953d06 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -171,6 +171,7 @@ from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type from mypy.typeops import ( bind_self, + can_have_shared_disjoint_base, coerce_to_literal, custom_special_method, erase_def_to_union_or_bound, @@ -2658,6 +2659,8 @@ def visit_class_def(self, defn: ClassDef) -> None: for base in typ.mro[1:]: if base.is_final: self.fail(message_registry.CANNOT_INHERIT_FROM_FINAL.format(base.name), defn) + if not can_have_shared_disjoint_base(typ.bases): + self.fail(message_registry.INCOMPATIBLE_DISJOINT_BASES.format(typ.name), defn) with self.tscope.class_scope(defn.info), self.enter_partial_types(is_class=True): old_binder = self.binder self.binder = ConditionalTypeBinder(self.options) @@ -5826,6 +5829,10 @@ def _make_fake_typeinfo_and_full_name( format_type_distinctly(*base_classes, options=self.options, bare=True), "and" ) + if not can_have_shared_disjoint_base(base_classes): + errors.append((pretty_names_list, "have distinct disjoint bases")) + return None + new_errors = [] for base in base_classes: if base.type.is_final: diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 2c41f2e273cc..e7de1b7a304f 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1484,13 +1484,7 @@ def analyze_decorator_or_funcbase_access( if isinstance(defn, Decorator): return analyze_var(name, defn.var, itype, mx) typ = function_type(defn, mx.chk.named_type("builtins.function")) - is_trivial_self = False - if isinstance(defn, Decorator): - # Use fast path if there are trivial decorators like @classmethod or @property - is_trivial_self = defn.func.is_trivial_self and not defn.decorators - elif isinstance(defn, (FuncDef, OverloadedFuncDef)): - is_trivial_self = defn.is_trivial_self - if is_trivial_self: + if isinstance(defn, (FuncDef, OverloadedFuncDef)) and defn.is_trivial_self: return bind_self_fast(typ, mx.self_type) typ = check_self_arg(typ, mx.self_type, defn.is_class, mx.context, name, mx.msg) return bind_self(typ, original_type=mx.self_type, is_classmethod=defn.is_class) diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 381aedfca059..09004322aee9 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -239,6 +239,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage: ) CANNOT_MAKE_DELETABLE_FINAL: Final = ErrorMessage("Deletable attribute cannot be final") +# Disjoint bases +INCOMPATIBLE_DISJOINT_BASES: Final = ErrorMessage('Class "{}" has incompatible disjoint bases') + # Enum ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDDEN: Final = ErrorMessage( 'Assigned "__members__" will be overridden by "Enum" internally' diff --git a/mypy/nodes.py b/mypy/nodes.py index 99b9bf72c948..8c2110b156f1 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -3004,6 +3004,7 @@ class is generic then it will be a type constructor of higher kind. "_mro_refs", "bad_mro", "is_final", + "is_disjoint_base", "declared_metaclass", "metaclass_type", "names", @@ -3055,6 +3056,7 @@ class is generic then it will be a type constructor of higher kind. _mro_refs: list[str] | None bad_mro: bool # Could not construct full MRO is_final: bool + is_disjoint_base: bool declared_metaclass: mypy.types.Instance | None metaclass_type: mypy.types.Instance | None @@ -3209,6 +3211,7 @@ class is generic then it will be a type constructor of higher kind. "is_protocol", "runtime_protocol", "is_final", + "is_disjoint_base", "is_intersection", ] @@ -3241,6 +3244,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None self.type_var_tuple_suffix: int | None = None self.add_type_vars() self.is_final = False + self.is_disjoint_base = False self.is_enum = False self.fallback_to_any = False self.meta_fallback_to_any = False diff --git a/mypy/semanal.py b/mypy/semanal.py index eef658d9300b..e8426a4e4885 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -254,6 +254,7 @@ ASSERT_TYPE_NAMES, DATACLASS_TRANSFORM_NAMES, DEPRECATED_TYPE_NAMES, + DISJOINT_BASE_DECORATOR_NAMES, FINAL_DECORATOR_NAMES, FINAL_TYPE_NAMES, IMPORTED_REVEAL_TYPE_NAMES, @@ -2188,6 +2189,8 @@ def analyze_class_decorator_common( """ if refers_to_fullname(decorator, FINAL_DECORATOR_NAMES): info.is_final = True + elif refers_to_fullname(decorator, DISJOINT_BASE_DECORATOR_NAMES): + info.is_disjoint_base = True elif refers_to_fullname(decorator, TYPE_CHECK_ONLY_NAMES): info.is_type_check_only = True elif (deprecated := self.get_deprecated(decorator)) is not None: diff --git a/mypy/stubtest.py b/mypy/stubtest.py index ef8c8dc318e1..43da0518b3f9 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -17,6 +17,7 @@ import os import pkgutil import re +import struct import symtable import sys import traceback @@ -466,6 +467,67 @@ class SubClass(runtime): # type: ignore[misc] ) +SIZEOF_PYOBJECT = struct.calcsize("P") + + +def _shape_differs(t1: type[object], t2: type[object]) -> bool: + """Check whether two types differ in shape. + + Mirrors the shape_differs() function in typeobject.c in CPython.""" + if sys.version_info >= (3, 12): + return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__ + else: + # CPython had more complicated logic before 3.12: + # https://github.com/python/cpython/blob/f3c6f882cddc8dc30320d2e73edf019e201394fc/Objects/typeobject.c#L2224 + # We attempt to mirror it here well enough to support the most common cases. + if t1.__itemsize__ or t2.__itemsize__: + return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__ + t_size = t1.__basicsize__ + if not t2.__weakrefoffset__ and t1.__weakrefoffset__ + SIZEOF_PYOBJECT == t_size: + t_size -= SIZEOF_PYOBJECT + if not t2.__dictoffset__ and t1.__dictoffset__ + SIZEOF_PYOBJECT == t_size: + t_size -= SIZEOF_PYOBJECT + if not t2.__weakrefoffset__ and t2.__weakrefoffset__ == t_size: + t_size -= SIZEOF_PYOBJECT + return t_size != t2.__basicsize__ + + +def _is_disjoint_base(typ: type[object]) -> bool: + """Return whether a type is a disjoint base at runtime, mirroring CPython's logic in typeobject.c. + + See PEP 800.""" + if typ is object: + return True + base = typ.__base__ + assert base is not None, f"Type {typ} has no base" + return _shape_differs(typ, base) + + +def _verify_disjoint_base( + stub: nodes.TypeInfo, runtime: type[object], object_path: list[str] +) -> Iterator[Error]: + # If it's final, doesn't matter whether it's a disjoint base or not + if stub.is_final: + return + is_disjoint_runtime = _is_disjoint_base(runtime) + if is_disjoint_runtime and not stub.is_disjoint_base: + yield Error( + object_path, + "is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub", + stub, + runtime, + stub_desc=repr(stub), + ) + elif not is_disjoint_runtime and stub.is_disjoint_base: + yield Error( + object_path, + "is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime", + stub, + runtime, + stub_desc=repr(stub), + ) + + def _verify_metaclass( stub: nodes.TypeInfo, runtime: type[Any], object_path: list[str], *, is_runtime_typeddict: bool ) -> Iterator[Error]: @@ -534,6 +596,7 @@ def verify_typeinfo( return yield from _verify_final(stub, runtime, object_path) + yield from _verify_disjoint_base(stub, runtime, object_path) is_runtime_typeddict = stub.typeddict_type is not None and is_typeddict(runtime) yield from _verify_metaclass( stub, runtime, object_path, is_runtime_typeddict=is_runtime_typeddict diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index b071c0ee8ab6..69e2abe62f85 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -1405,9 +1405,16 @@ def spam(x=Flags4(0)): pass ) yield Case( stub=""" + import sys from typing import Final, Literal - class BytesEnum(bytes, enum.Enum): - a = b'foo' + from typing_extensions import disjoint_base + if sys.version_info >= (3, 12): + class BytesEnum(bytes, enum.Enum): + a = b'foo' + else: + @disjoint_base + class BytesEnum(bytes, enum.Enum): + a = b'foo' FOO: Literal[BytesEnum.a] BAR: Final = BytesEnum.a BAZ: BytesEnum @@ -1613,6 +1620,60 @@ def test_not_subclassable(self) -> Iterator[Case]: error="CannotBeSubclassed", ) + @collect_cases + def test_disjoint_base(self) -> Iterator[Case]: + yield Case( + stub=""" + class A: pass + """, + runtime=""" + class A: pass + """, + error=None, + ) + yield Case( + stub=""" + from typing_extensions import disjoint_base + + @disjoint_base + class B: pass + """, + runtime=""" + class B: pass + """, + error="test_module.B", + ) + yield Case( + stub=""" + from typing_extensions import Self + + class mytakewhile: + def __new__(cls, predicate: object, iterable: object, /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> object: ... + """, + runtime=""" + from itertools import takewhile as mytakewhile + """, + # Should have @disjoint_base + error="test_module.mytakewhile", + ) + yield Case( + stub=""" + from typing_extensions import disjoint_base, Self + + @disjoint_base + class mycorrecttakewhile: + def __new__(cls, predicate: object, iterable: object, /) -> Self: ... + def __iter__(self) -> Self: ... + def __next__(self) -> object: ... + """, + runtime=""" + from itertools import takewhile as mycorrecttakewhile + """, + error=None, + ) + @collect_cases def test_has_runtime_final_decorator(self) -> Iterator[Case]: yield Case( diff --git a/mypy/typeops.py b/mypy/typeops.py index 88b3c5da48ce..0cb6018d01fd 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -1253,3 +1253,54 @@ def named_type(fullname: str) -> Instance: ) ) return subtype + + +def _is_disjoint_base(info: TypeInfo) -> bool: + # It either has the @disjoint_base decorator or defines nonempty __slots__. + if info.is_disjoint_base: + return True + if not info.slots: + return False + own_slots = { + slot + for slot in info.slots + if not any( + base_info.type.slots is not None and slot in base_info.type.slots + for base_info in info.bases + ) + } + return bool(own_slots) + + +def _get_disjoint_base_of(instance: Instance) -> TypeInfo | None: + """Returns the disjoint base of the given instance, if it exists.""" + if _is_disjoint_base(instance.type): + return instance.type + for base in instance.type.mro: + if _is_disjoint_base(base): + return base + return None + + +def can_have_shared_disjoint_base(instances: list[Instance]) -> bool: + """Returns whether the given instances can share a disjoint base. + + This means that a child class of these classes can exist at runtime. + """ + # Ignore None disjoint bases (which are `object`). + disjoint_bases = [ + base for instance in instances if (base := _get_disjoint_base_of(instance)) is not None + ] + if not disjoint_bases: + # All are `object`. + return True + + candidate = disjoint_bases[0] + for base in disjoint_bases[1:]: + if candidate.has_base(base.fullname): + continue + elif base.has_base(candidate.fullname): + candidate = base + else: + return False + return True diff --git a/mypy/types.py b/mypy/types.py index 26c5b474ba6c..4fa8b8e64703 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -119,9 +119,12 @@ # Supported Unpack type names. UNPACK_TYPE_NAMES: Final = ("typing.Unpack", "typing_extensions.Unpack") -# Supported @deprecated type names +# Supported @deprecated decorator names DEPRECATED_TYPE_NAMES: Final = ("warnings.deprecated", "typing_extensions.deprecated") +# Supported @disjoint_base decorator names +DISJOINT_BASE_DECORATOR_NAMES: Final = ("typing.disjoint_base", "typing_extensions.disjoint_base") + # We use this constant in various places when checking `tuple` subtyping: TUPLE_LIKE_INSTANCE_NAMES: Final = ( "builtins.tuple", diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 23dbe2bc07af..5cc4910fb265 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -6084,7 +6084,7 @@ class B(A): __slots__ = ('a', 'b') class C: __slots__ = ('x',) -class D(B, C): +class D(B, C): # E: Class "D" has incompatible disjoint bases __slots__ = ('aa', 'bb', 'cc') [builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-final.test b/test-data/unit/check-final.test index d23199dc8b33..e3fc4614fc06 100644 --- a/test-data/unit/check-final.test +++ b/test-data/unit/check-final.test @@ -1272,3 +1272,67 @@ if FOO is not None: def func() -> int: return FOO + +[case testDisjointBase] +from typing_extensions import disjoint_base + +@disjoint_base +class Disjoint1: pass + +@disjoint_base +class Disjoint2: pass + +@disjoint_base +class DisjointChild(Disjoint1): pass + +class C1: pass +class C2(Disjoint1, C1): pass +class C3(DisjointChild, Disjoint1): pass + +class C4(Disjoint1, Disjoint2): # E: Class "C4" has incompatible disjoint bases + pass + +class C5(Disjoint2, Disjoint1): # E: Class "C5" has incompatible disjoint bases + pass + +class C6(Disjoint2, DisjointChild): # E: Class "C6" has incompatible disjoint bases + pass + +class C7(DisjointChild, Disjoint2): # E: Class "C7" has incompatible disjoint bases + pass + +class C8(DisjointChild, Disjoint1, Disjoint2): # E: Class "C8" has incompatible disjoint bases + pass + +class C9(C2, Disjoint2): # E: Class "C9" has incompatible disjoint bases + pass + +class C10(C3, Disjoint2): # E: Class "C10" has incompatible disjoint bases + pass + +[builtins fixtures/tuple.pyi] +[case testDisjointBaseSlots] +class S1: + __slots__ = ("a",) + +class S2: + __slots__ = ("b",) + +class S3: + __slots__ = () + +class S4(S1): + __slots__ = ("c",) + +class S5(S1, S2): # E: Class "S5" has incompatible disjoint bases + pass + +class S6(S1, S3): pass # OK +class S7(S3, S1): pass # OK + +class S8(S4, S1): pass # OK + +class S9(S2, S4): # E: Class "S9" has incompatible disjoint bases + pass + +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 7d791319537f..c3fe98e69d95 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -3170,13 +3170,13 @@ C(5, 'foo', True) [file a.py] import attrs -@attrs.define +@attrs.define(slots=False) class A: a: int [file b.py] import attrs -@attrs.define +@attrs.define(slots=False) class B: b: str @@ -3184,7 +3184,7 @@ class B: from a import A from b import B import attrs -@attrs.define +@attrs.define(slots=False) class C(A, B): c: bool diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 640fc10915d1..5043d5422108 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -2551,6 +2551,33 @@ def f2(x: T2) -> T2: return C() [builtins fixtures/isinstance.pyi] +[case testIsInstanceDisjointBase] +# flags: --warn-unreachable +from typing_extensions import disjoint_base + +@disjoint_base +class Disjoint1: pass +@disjoint_base +class Disjoint2: pass +@disjoint_base +class Disjoint3(Disjoint1): pass +class Child(Disjoint1): pass +class Unrelated: pass + +def f(d1: Disjoint1, u: Unrelated, c: Child) -> object: + if isinstance(d1, Disjoint2): # E: Subclass of "Disjoint1" and "Disjoint2" cannot exist: have distinct disjoint bases + return u # E: Statement is unreachable + if isinstance(u, Disjoint1): # OK + return d1 + if isinstance(c, Disjoint3): # OK + return c + if isinstance(c, Disjoint2): # E: Subclass of "Child" and "Disjoint2" cannot exist: have distinct disjoint bases + return c # E: Statement is unreachable + return d1 + +[builtins fixtures/isinstance.pyi] + + [case testIsInstanceAdHocIntersectionUsage] # flags: --warn-unreachable class A: pass diff --git a/test-data/unit/check-slots.test b/test-data/unit/check-slots.test index 10b664bffb11..25dd630e1cbe 100644 --- a/test-data/unit/check-slots.test +++ b/test-data/unit/check-slots.test @@ -180,6 +180,7 @@ b.m = 2 b.b = 2 b._two = 2 [out] +main:5: error: Class "B" has incompatible disjoint bases main:11: error: Trying to assign name "_one" that is not in "__slots__" of type "__main__.B" main:16: error: "B" has no attribute "b" main:17: error: "B" has no attribute "_two" diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index cb054b0e6b4f..6158a0c9ebbc 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -93,5 +93,6 @@ def dataclass_transform( def override(__arg: _T) -> _T: ... def deprecated(__msg: str) -> Callable[[_T], _T]: ... +def disjoint_base(__arg: _T) -> _T: ... _FutureFeatureFixture = 0