Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
from mypy.typeops import (
bind_self,
can_have_shared_disjoint_base,
coerce_to_literal,
custom_special_method,
erase_def_to_union_or_bound,
Expand Down Expand Up @@ -2658,6 +2659,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
for base in typ.mro[1:]:
if base.is_final:
self.fail(message_registry.CANNOT_INHERIT_FROM_FINAL.format(base.name), defn)
if not can_have_shared_disjoint_base(typ.bases):
self.fail(message_registry.INCOMPATIBLE_DISJOINT_BASES.format(typ.name), defn)
with self.tscope.class_scope(defn.info), self.enter_partial_types(is_class=True):
old_binder = self.binder
self.binder = ConditionalTypeBinder(self.options)
Expand Down Expand Up @@ -5826,6 +5829,10 @@ def _make_fake_typeinfo_and_full_name(
format_type_distinctly(*base_classes, options=self.options, bare=True), "and"
)

if not can_have_shared_disjoint_base(base_classes):
errors.append((pretty_names_list, "have distinct disjoint bases"))
return None

new_errors = []
for base in base_classes:
if base.type.is_final:
Expand Down
8 changes: 1 addition & 7 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,13 +1484,7 @@ def analyze_decorator_or_funcbase_access(
if isinstance(defn, Decorator):
return analyze_var(name, defn.var, itype, mx)
typ = function_type(defn, mx.chk.named_type("builtins.function"))
is_trivial_self = False
if isinstance(defn, Decorator):
Copy link
Member Author

@JelleZijlstra JelleZijlstra Aug 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition can never match because we already check for Decorator a few lines up.

# Use fast path if there are trivial decorators like @classmethod or @property
is_trivial_self = defn.func.is_trivial_self and not defn.decorators
elif isinstance(defn, (FuncDef, OverloadedFuncDef)):
is_trivial_self = defn.is_trivial_self
if is_trivial_self:
if isinstance(defn, (FuncDef, OverloadedFuncDef)) and defn.is_trivial_self:
return bind_self_fast(typ, mx.self_type)
typ = check_self_arg(typ, mx.self_type, defn.is_class, mx.context, name, mx.msg)
return bind_self(typ, original_type=mx.self_type, is_classmethod=defn.is_class)
Expand Down
3 changes: 3 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
)
CANNOT_MAKE_DELETABLE_FINAL: Final = ErrorMessage("Deletable attribute cannot be final")

# Disjoint bases
INCOMPATIBLE_DISJOINT_BASES: Final = ErrorMessage('Class "{}" has incompatible disjoint bases')

# Enum
ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDDEN: Final = ErrorMessage(
'Assigned "__members__" will be overridden by "Enum" internally'
Expand Down
4 changes: 4 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3004,6 +3004,7 @@ class is generic then it will be a type constructor of higher kind.
"_mro_refs",
"bad_mro",
"is_final",
"is_disjoint_base",
"declared_metaclass",
"metaclass_type",
"names",
Expand Down Expand Up @@ -3055,6 +3056,7 @@ class is generic then it will be a type constructor of higher kind.
_mro_refs: list[str] | None
bad_mro: bool # Could not construct full MRO
is_final: bool
is_disjoint_base: bool

declared_metaclass: mypy.types.Instance | None
metaclass_type: mypy.types.Instance | None
Expand Down Expand Up @@ -3209,6 +3211,7 @@ class is generic then it will be a type constructor of higher kind.
"is_protocol",
"runtime_protocol",
"is_final",
"is_disjoint_base",
"is_intersection",
]

Expand Down Expand Up @@ -3241,6 +3244,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
self.type_var_tuple_suffix: int | None = None
self.add_type_vars()
self.is_final = False
self.is_disjoint_base = False
self.is_enum = False
self.fallback_to_any = False
self.meta_fallback_to_any = False
Expand Down
3 changes: 3 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@
ASSERT_TYPE_NAMES,
DATACLASS_TRANSFORM_NAMES,
DEPRECATED_TYPE_NAMES,
DISJOINT_BASE_DECORATOR_NAMES,
FINAL_DECORATOR_NAMES,
FINAL_TYPE_NAMES,
IMPORTED_REVEAL_TYPE_NAMES,
Expand Down Expand Up @@ -2188,6 +2189,8 @@ def analyze_class_decorator_common(
"""
if refers_to_fullname(decorator, FINAL_DECORATOR_NAMES):
info.is_final = True
elif refers_to_fullname(decorator, DISJOINT_BASE_DECORATOR_NAMES):
info.is_disjoint_base = True
elif refers_to_fullname(decorator, TYPE_CHECK_ONLY_NAMES):
info.is_type_check_only = True
elif (deprecated := self.get_deprecated(decorator)) is not None:
Expand Down
63 changes: 63 additions & 0 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import pkgutil
import re
import struct
import symtable
import sys
import traceback
Expand Down Expand Up @@ -466,6 +467,67 @@ class SubClass(runtime): # type: ignore[misc]
)


SIZEOF_PYOBJECT = struct.calcsize("P")


def _shape_differs(t1: type[object], t2: type[object]) -> bool:
"""Check whether two types differ in shape.

Mirrors the shape_differs() function in typeobject.c in CPython."""
if sys.version_info >= (3, 12):
return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
else:
# CPython had more complicated logic before 3.12:
# https://github.com/python/cpython/blob/f3c6f882cddc8dc30320d2e73edf019e201394fc/Objects/typeobject.c#L2224
# We attempt to mirror it here well enough to support the most common cases.
if t1.__itemsize__ or t2.__itemsize__:
return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
t_size = t1.__basicsize__
if not t2.__weakrefoffset__ and t1.__weakrefoffset__ + SIZEOF_PYOBJECT == t_size:
t_size -= SIZEOF_PYOBJECT
if not t2.__dictoffset__ and t1.__dictoffset__ + SIZEOF_PYOBJECT == t_size:
t_size -= SIZEOF_PYOBJECT
if not t2.__weakrefoffset__ and t2.__weakrefoffset__ == t_size:
t_size -= SIZEOF_PYOBJECT
return t_size != t2.__basicsize__


def _is_disjoint_base(typ: type[object]) -> bool:
"""Return whether a type is a disjoint base at runtime, mirroring CPython's logic in typeobject.c.

See PEP 800."""
if typ is object:
return True
base = typ.__base__
assert base is not None, f"Type {typ} has no base"
return _shape_differs(typ, base)


def _verify_disjoint_base(
stub: nodes.TypeInfo, runtime: type[object], object_path: list[str]
) -> Iterator[Error]:
# If it's final, doesn't matter whether it's a disjoint base or not
if stub.is_final:
return
is_disjoint_runtime = _is_disjoint_base(runtime)
if is_disjoint_runtime and not stub.is_disjoint_base:
yield Error(
object_path,
"is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub",
stub,
runtime,
stub_desc=repr(stub),
)
elif not is_disjoint_runtime and stub.is_disjoint_base:
yield Error(
object_path,
"is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime",
stub,
runtime,
stub_desc=repr(stub),
)


def _verify_metaclass(
stub: nodes.TypeInfo, runtime: type[Any], object_path: list[str], *, is_runtime_typeddict: bool
) -> Iterator[Error]:
Expand Down Expand Up @@ -534,6 +596,7 @@ def verify_typeinfo(
return

yield from _verify_final(stub, runtime, object_path)
yield from _verify_disjoint_base(stub, runtime, object_path)
is_runtime_typeddict = stub.typeddict_type is not None and is_typeddict(runtime)
yield from _verify_metaclass(
stub, runtime, object_path, is_runtime_typeddict=is_runtime_typeddict
Expand Down
65 changes: 63 additions & 2 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,9 +1405,16 @@ def spam(x=Flags4(0)): pass
)
yield Case(
stub="""
import sys
from typing import Final, Literal
class BytesEnum(bytes, enum.Enum):
a = b'foo'
from typing_extensions import disjoint_base
if sys.version_info >= (3, 12):
class BytesEnum(bytes, enum.Enum):
a = b'foo'
else:
@disjoint_base
class BytesEnum(bytes, enum.Enum):
a = b'foo'
FOO: Literal[BytesEnum.a]
BAR: Final = BytesEnum.a
BAZ: BytesEnum
Expand Down Expand Up @@ -1613,6 +1620,60 @@ def test_not_subclassable(self) -> Iterator[Case]:
error="CannotBeSubclassed",
)

@collect_cases
def test_disjoint_base(self) -> Iterator[Case]:
yield Case(
stub="""
class A: pass
""",
runtime="""
class A: pass
""",
error=None,
)
yield Case(
stub="""
from typing_extensions import disjoint_base

@disjoint_base
class B: pass
""",
runtime="""
class B: pass
""",
error="test_module.B",
)
yield Case(
stub="""
from typing_extensions import Self

class mytakewhile:
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> object: ...
""",
runtime="""
from itertools import takewhile as mytakewhile
""",
# Should have @disjoint_base
error="test_module.mytakewhile",
)
yield Case(
stub="""
from typing_extensions import disjoint_base, Self

@disjoint_base
class mycorrecttakewhile:
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
def __iter__(self) -> Self: ...
def __next__(self) -> object: ...
""",
runtime="""
from itertools import takewhile as mycorrecttakewhile
""",
error=None,
)

@collect_cases
def test_has_runtime_final_decorator(self) -> Iterator[Case]:
yield Case(
Expand Down
51 changes: 51 additions & 0 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,3 +1253,54 @@ def named_type(fullname: str) -> Instance:
)
)
return subtype


def _is_disjoint_base(info: TypeInfo) -> bool:
# It either has the @disjoint_base decorator or defines nonempty __slots__.
if info.is_disjoint_base:
return True
if not info.slots:
return False
own_slots = {
slot
for slot in info.slots
if not any(
base_info.type.slots is not None and slot in base_info.type.slots
for base_info in info.bases
)
}
return bool(own_slots)


def _get_disjoint_base_of(instance: Instance) -> TypeInfo | None:
"""Returns the disjoint base of the given instance, if it exists."""
if _is_disjoint_base(instance.type):
return instance.type
for base in instance.type.mro:
if _is_disjoint_base(base):
return base
return None


def can_have_shared_disjoint_base(instances: list[Instance]) -> bool:
"""Returns whether the given instances can share a disjoint base.

This means that a child class of these classes can exist at runtime.
"""
# Ignore None disjoint bases (which are `object`).
disjoint_bases = [
base for instance in instances if (base := _get_disjoint_base_of(instance)) is not None
]
if not disjoint_bases:
# All are `object`.
return True

candidate = disjoint_bases[0]
for base in disjoint_bases[1:]:
if candidate.has_base(base.fullname):
continue
elif base.has_base(candidate.fullname):
candidate = base
else:
return False
return True
5 changes: 4 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@
# Supported Unpack type names.
UNPACK_TYPE_NAMES: Final = ("typing.Unpack", "typing_extensions.Unpack")

# Supported @deprecated type names
# Supported @deprecated decorator names
DEPRECATED_TYPE_NAMES: Final = ("warnings.deprecated", "typing_extensions.deprecated")

# Supported @disjoint_base decorator names
DISJOINT_BASE_DECORATOR_NAMES: Final = ("typing.disjoint_base", "typing_extensions.disjoint_base")

# We use this constant in various places when checking `tuple` subtyping:
TUPLE_LIKE_INSTANCE_NAMES: Final = (
"builtins.tuple",
Expand Down
2 changes: 1 addition & 1 deletion test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -6084,7 +6084,7 @@ class B(A):
__slots__ = ('a', 'b')
class C:
__slots__ = ('x',)
class D(B, C):
class D(B, C): # E: Class "D" has incompatible disjoint bases
__slots__ = ('aa', 'bb', 'cc')
[builtins fixtures/tuple.pyi]

Expand Down
Loading
Loading