Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
from mypy.typeops import (
bind_self,
can_have_shared_disjoint_base,
coerce_to_literal,
custom_special_method,
erase_def_to_union_or_bound,
Expand Down Expand Up @@ -2658,6 +2659,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
for base in typ.mro[1:]:
if base.is_final:
self.fail(message_registry.CANNOT_INHERIT_FROM_FINAL.format(base.name), defn)
if not can_have_shared_disjoint_base(typ.bases):
self.fail(message_registry.INCOMPATIBLE_DISJOINT_BASES.format(typ.name), defn)
with self.tscope.class_scope(defn.info), self.enter_partial_types(is_class=True):
old_binder = self.binder
self.binder = ConditionalTypeBinder(self.options)
Expand Down Expand Up @@ -5826,6 +5829,10 @@ def _make_fake_typeinfo_and_full_name(
format_type_distinctly(*base_classes, options=self.options, bare=True), "and"
)

if not can_have_shared_disjoint_base(base_classes):
errors.append((pretty_names_list, "have distinct disjoint bases"))
return None

new_errors = []
for base in base_classes:
if base.type.is_final:
Expand Down
8 changes: 1 addition & 7 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,13 +1484,7 @@ def analyze_decorator_or_funcbase_access(
if isinstance(defn, Decorator):
return analyze_var(name, defn.var, itype, mx)
typ = function_type(defn, mx.chk.named_type("builtins.function"))
is_trivial_self = False
if isinstance(defn, Decorator):
Copy link
Member Author

@JelleZijlstra JelleZijlstra Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition can never match because we already check for Decorator a few lines up.

# Use fast path if there are trivial decorators like @classmethod or @property
is_trivial_self = defn.func.is_trivial_self and not defn.decorators
elif isinstance(defn, (FuncDef, OverloadedFuncDef)):
is_trivial_self = defn.is_trivial_self
if is_trivial_self:
if isinstance(defn, (FuncDef, OverloadedFuncDef)) and defn.is_trivial_self:
return bind_self_fast(typ, mx.self_type)
typ = check_self_arg(typ, mx.self_type, defn.is_class, mx.context, name, mx.msg)
return bind_self(typ, original_type=mx.self_type, is_classmethod=defn.is_class)
Expand Down
3 changes: 3 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
)
CANNOT_MAKE_DELETABLE_FINAL: Final = ErrorMessage("Deletable attribute cannot be final")

# Disjoint bases
INCOMPATIBLE_DISJOINT_BASES: Final = ErrorMessage('Class "{}" has incompatible disjoint bases')

# Enum
ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDDEN: Final = ErrorMessage(
'Assigned "__members__" will be overridden by "Enum" internally'
Expand Down
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,6 +3004,7 @@ class is generic then it will be a type constructor of higher kind.
"_mro_refs",
"bad_mro",
"is_final",
"is_disjoint_base",
"declared_metaclass",
"metaclass_type",
"names",
Expand Down Expand Up @@ -3055,6 +3056,7 @@ class is generic then it will be a type constructor of higher kind.
_mro_refs: list[str] | None
bad_mro: bool # Could not construct full MRO
is_final: bool
is_disjoint_base: bool

declared_metaclass: mypy.types.Instance | None
metaclass_type: mypy.types.Instance | None
Expand Down Expand Up @@ -3209,6 +3211,7 @@ class is generic then it will be a type constructor of higher kind.
"is_protocol",
"runtime_protocol",
"is_final",
"is_disjoint_base",
"is_intersection",
]

Expand Down Expand Up @@ -3241,6 +3244,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
self.type_var_tuple_suffix: int | None = None
self.add_type_vars()
self.is_final = False
self.is_disjoint_base = False
self.is_enum = False
self.fallback_to_any = False
self.meta_fallback_to_any = False
Expand Down
3 changes: 3 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@
ASSERT_TYPE_NAMES,
DATACLASS_TRANSFORM_NAMES,
DEPRECATED_TYPE_NAMES,
DISJOINT_BASE_DECORATOR_NAMES,
FINAL_DECORATOR_NAMES,
FINAL_TYPE_NAMES,
IMPORTED_REVEAL_TYPE_NAMES,
Expand Down Expand Up @@ -2188,6 +2189,8 @@ def analyze_class_decorator_common(
"""
if refers_to_fullname(decorator, FINAL_DECORATOR_NAMES):
info.is_final = True
elif refers_to_fullname(decorator, DISJOINT_BASE_DECORATOR_NAMES):
info.is_disjoint_base = True
elif refers_to_fullname(decorator, TYPE_CHECK_ONLY_NAMES):
info.is_type_check_only = True
elif (deprecated := self.get_deprecated(decorator)) is not None:
Expand Down
63 changes: 63 additions & 0 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import pkgutil
import re
import struct
import symtable
import sys
import traceback
Expand Down Expand Up @@ -466,6 +467,67 @@ class SubClass(runtime): # type: ignore[misc]
)


SIZEOF_PYOBJECT = struct.calcsize("P")


def _shape_differs(t1: type[object], t2: type[object]) -> bool:
"""Check whether two types differ in shape.

Mirrors the shape_differs() function in typeobject.c in CPython."""
if sys.version_info >= (3, 12):
return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
else:
# CPython had more complicated logic before 3.12:
# https://github.com/python/cpython/blob/f3c6f882cddc8dc30320d2e73edf019e201394fc/Objects/typeobject.c#L2224
# We attempt to mirror it here well enough to support the most common cases.
if t1.__itemsize__ or t2.__itemsize__:
return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
t_size = t1.__basicsize__
if not t2.__weakrefoffset__ and t1.__weakrefoffset__ + SIZEOF_PYOBJECT == t_size:
t_size -= SIZEOF_PYOBJECT
if not t2.__dictoffset__ and t1.__dictoffset__ + SIZEOF_PYOBJECT == t_size:
t_size -= SIZEOF_PYOBJECT
if not t2.__weakrefoffset__ and t2.__weakrefoffset__ == t_size:
t_size -= SIZEOF_PYOBJECT
return t_size != t2.__basicsize__


def _is_disjoint_base(typ: type[object]) -> bool:
"""Return whether a type is a disjoint base at runtime, mirroring CPython's logic in typeobject.c.

See PEP 800."""
if typ is object:
return True
base = typ.__base__
assert base is not None, f"Type {typ} has no base"
return _shape_differs(typ, base)


def _verify_disjoint_base(
stub: nodes.TypeInfo, runtime: type[object], object_path: list[str]
) -> Iterator[Error]:
# If it's final, doesn't matter whether it's a disjoint base or not
if stub.is_final:
return
is_disjoint_runtime = _is_disjoint_base(runtime)
if is_disjoint_runtime and not stub.is_disjoint_base:
yield Error(
object_path,
"is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub",
stub,
runtime,
stub_desc=repr(stub),
)
elif not is_disjoint_runtime and stub.is_disjoint_base:
yield Error(
object_path,
"is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime",
stub,
runtime,
stub_desc=repr(stub),
)


def _verify_metaclass(
stub: nodes.TypeInfo, runtime: type[Any], object_path: list[str], *, is_runtime_typeddict: bool
) -> Iterator[Error]:
Expand Down Expand Up @@ -534,6 +596,7 @@ def verify_typeinfo(
return

yield from _verify_final(stub, runtime, object_path)
yield from _verify_disjoint_base(stub, runtime, object_path)
is_runtime_typeddict = stub.typeddict_type is not None and is_typeddict(runtime)
yield from _verify_metaclass(
stub, runtime, object_path, is_runtime_typeddict=is_runtime_typeddict
Expand Down
65 changes: 63 additions & 2 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,9 +1405,16 @@ def spam(x=Flags4(0)): pass
)
yield Case(
stub="""
import sys
from typing import Final, Literal
class BytesEnum(bytes, enum.Enum):
a = b'foo'
from typing_extensions import disjoint_base
if sys.version_info >= (3, 12):
class BytesEnum(bytes, enum.Enum):
a = b'foo'
else:
@disjoint_base
class BytesEnum(bytes, enum.Enum):
a = b'foo'
FOO: Literal[BytesEnum.a]
BAR: Final = BytesEnum.a
BAZ: BytesEnum
Expand Down Expand Up @@ -1613,6 +1620,60 @@ def test_not_subclassable(self) -> Iterator[Case]:
error="CannotBeSubclassed",
)

@collect_cases
def test_disjoint_base(self) -> Iterator[Case]:
yield Case(
stub="""
class A: pass
""",
runtime="""
class A: pass
""",
error=None,
)
yield Case(
stub="""
from typing_extensions import disjoint_base

@disjoint_base
class B: pass
""",
runtime="""
class B: pass
""",
error="test_module.B",
)
yield Case(
stub="""
from typing_extensions import Self

class mytakewhile:
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> object: ...
""",
runtime="""
from itertools import takewhile as mytakewhile
""",
# Should have @disjoint_base
error="test_module.mytakewhile",
)
yield Case(
stub="""
from typing_extensions import disjoint_base, Self

@disjoint_base
class mycorrecttakewhile:
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> object: ...
""",
runtime="""
from itertools import takewhile as mycorrecttakewhile
""",
error=None,
)

@collect_cases
def test_has_runtime_final_decorator(self) -> Iterator[Case]:
yield Case(
Expand Down
51 changes: 51 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,3 +1253,54 @@ def named_type(fullname: str) -> Instance:
)
)
return subtype


def _is_disjoint_base(info: TypeInfo) -> bool:
# It either has the @disjoint_base decorator or defines nonempty __slots__.
if info.is_disjoint_base:
return True
if not info.slots:
return False
own_slots = {
slot
for slot in info.slots
if not any(
base_info.type.slots is not None and slot in base_info.type.slots
for base_info in info.bases
)
}
return bool(own_slots)


def _get_disjoint_base_of(instance: Instance) -> TypeInfo | None:
"""Returns the disjoint base of the given instance, if it exists."""
if _is_disjoint_base(instance.type):
return instance.type
for base in instance.type.mro:
if _is_disjoint_base(base):
return base
return None


def can_have_shared_disjoint_base(instances: list[Instance]) -> bool:
"""Returns whether the given instances can share a disjoint base.

This means that a child class of these classes can exist at runtime.
"""
# Ignore None disjoint bases (which are `object`).
disjoint_bases = [
base for instance in instances if (base := _get_disjoint_base_of(instance)) is not None
]
if not disjoint_bases:
# All are `object`.
return True

candidate = disjoint_bases[0]
for base in disjoint_bases[1:]:
if candidate.has_base(base.fullname):
continue
elif base.has_base(candidate.fullname):
candidate = base
else:
return False
return True
5 changes: 4 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@
# Supported Unpack type names.
UNPACK_TYPE_NAMES: Final = ("typing.Unpack", "typing_extensions.Unpack")

# Supported @deprecated type names
# Supported @deprecated decorator names
DEPRECATED_TYPE_NAMES: Final = ("warnings.deprecated", "typing_extensions.deprecated")

# Supported @disjoint_base decorator names
DISJOINT_BASE_DECORATOR_NAMES: Final = ("typing.disjoint_base", "typing_extensions.disjoint_base")

# We use this constant in various places when checking `tuple` subtyping:
TUPLE_LIKE_INSTANCE_NAMES: Final = (
"builtins.tuple",
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -6084,7 +6084,7 @@ class B(A):
__slots__ = ('a', 'b')
class C:
__slots__ = ('x',)
class D(B, C):
class D(B, C): # E: Class "D" has incompatible disjoint bases
__slots__ = ('aa', 'bb', 'cc')
[builtins fixtures/tuple.pyi]

Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-final.test
Original file line number Diff line number Diff line change
Expand Up @@ -1272,3 +1272,24 @@ if FOO is not None:

def func() -> int:
return FOO

[case testDisjointBase]
from typing_extensions import disjoint_base

@disjoint_base
class Disjoint1: pass

@disjoint_base
class Disjoint2: pass

@disjoint_base
class DisjointChild(Disjoint1): pass

class C1: pass
class C2(Disjoint1, C1): pass
class C3(DisjointChild, Disjoint1): pass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some ideas for additional test cases:

  • Test reversed order of bases.
  • Test with all three Disjoint* bases.
  • Test a subclass of C3/C2.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, also added some more tests.


class C4(Disjoint1, Disjoint2): # E: Class "C4" has incompatible disjoint bases
pass

[builtins fixtures/tuple.pyi]
6 changes: 3 additions & 3 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -3170,21 +3170,21 @@ C(5, 'foo', True)

[file a.py]
import attrs
@attrs.define
@attrs.define(slots=False)
class A:
a: int

[file b.py]
import attrs
@attrs.define
@attrs.define(slots=False)
class B:
b: str

[file c.py]
from a import A
from b import B
import attrs
@attrs.define
@attrs.define(slots=False)
class C(A, B):
c: bool

Expand Down
Loading
Loading