Skip to content

Commit 0d23c61

Browse files
Implement PEP 800 (@disjoint_base) (#19678)
https://peps.python.org/pep-0800/ - Recognize the @disjoint_base decorator - Error if a class definition has incompatible disjoint bases - Recognize that classes with incompatible disjoint bases cannot exist - Check in stubtest that @disjoint_base is correctly applied - The self check found a line of dead code in mypy itself, due to classes that are disjoint bases from `__slots__`.
1 parent 722f4dd commit 0d23c61

15 files changed

+296
-14
lines changed

mypy/checker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
172172
from mypy.typeops import (
173173
bind_self,
174+
can_have_shared_disjoint_base,
174175
coerce_to_literal,
175176
custom_special_method,
176177
erase_def_to_union_or_bound,
@@ -2658,6 +2659,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
26582659
for base in typ.mro[1:]:
26592660
if base.is_final:
26602661
self.fail(message_registry.CANNOT_INHERIT_FROM_FINAL.format(base.name), defn)
2662+
if not can_have_shared_disjoint_base(typ.bases):
2663+
self.fail(message_registry.INCOMPATIBLE_DISJOINT_BASES.format(typ.name), defn)
26612664
with self.tscope.class_scope(defn.info), self.enter_partial_types(is_class=True):
26622665
old_binder = self.binder
26632666
self.binder = ConditionalTypeBinder(self.options)
@@ -5826,6 +5829,10 @@ def _make_fake_typeinfo_and_full_name(
58265829
format_type_distinctly(*base_classes, options=self.options, bare=True), "and"
58275830
)
58285831

5832+
if not can_have_shared_disjoint_base(base_classes):
5833+
errors.append((pretty_names_list, "have distinct disjoint bases"))
5834+
return None
5835+
58295836
new_errors = []
58305837
for base in base_classes:
58315838
if base.type.is_final:

mypy/checkmember.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1484,13 +1484,7 @@ def analyze_decorator_or_funcbase_access(
14841484
if isinstance(defn, Decorator):
14851485
return analyze_var(name, defn.var, itype, mx)
14861486
typ = function_type(defn, mx.chk.named_type("builtins.function"))
1487-
is_trivial_self = False
1488-
if isinstance(defn, Decorator):
1489-
# Use fast path if there are trivial decorators like @classmethod or @property
1490-
is_trivial_self = defn.func.is_trivial_self and not defn.decorators
1491-
elif isinstance(defn, (FuncDef, OverloadedFuncDef)):
1492-
is_trivial_self = defn.is_trivial_self
1493-
if is_trivial_self:
1487+
if isinstance(defn, (FuncDef, OverloadedFuncDef)) and defn.is_trivial_self:
14941488
return bind_self_fast(typ, mx.self_type)
14951489
typ = check_self_arg(typ, mx.self_type, defn.is_class, mx.context, name, mx.msg)
14961490
return bind_self(typ, original_type=mx.self_type, is_classmethod=defn.is_class)

mypy/message_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
239239
)
240240
CANNOT_MAKE_DELETABLE_FINAL: Final = ErrorMessage("Deletable attribute cannot be final")
241241

242+
# Disjoint bases
243+
INCOMPATIBLE_DISJOINT_BASES: Final = ErrorMessage('Class "{}" has incompatible disjoint bases')
244+
242245
# Enum
243246
ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDDEN: Final = ErrorMessage(
244247
'Assigned "__members__" will be overridden by "Enum" internally'

mypy/nodes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,6 +3004,7 @@ class is generic then it will be a type constructor of higher kind.
30043004
"_mro_refs",
30053005
"bad_mro",
30063006
"is_final",
3007+
"is_disjoint_base",
30073008
"declared_metaclass",
30083009
"metaclass_type",
30093010
"names",
@@ -3055,6 +3056,7 @@ class is generic then it will be a type constructor of higher kind.
30553056
_mro_refs: list[str] | None
30563057
bad_mro: bool # Could not construct full MRO
30573058
is_final: bool
3059+
is_disjoint_base: bool
30583060

30593061
declared_metaclass: mypy.types.Instance | None
30603062
metaclass_type: mypy.types.Instance | None
@@ -3209,6 +3211,7 @@ class is generic then it will be a type constructor of higher kind.
32093211
"is_protocol",
32103212
"runtime_protocol",
32113213
"is_final",
3214+
"is_disjoint_base",
32123215
"is_intersection",
32133216
]
32143217

@@ -3241,6 +3244,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
32413244
self.type_var_tuple_suffix: int | None = None
32423245
self.add_type_vars()
32433246
self.is_final = False
3247+
self.is_disjoint_base = False
32443248
self.is_enum = False
32453249
self.fallback_to_any = False
32463250
self.meta_fallback_to_any = False

mypy/semanal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@
254254
ASSERT_TYPE_NAMES,
255255
DATACLASS_TRANSFORM_NAMES,
256256
DEPRECATED_TYPE_NAMES,
257+
DISJOINT_BASE_DECORATOR_NAMES,
257258
FINAL_DECORATOR_NAMES,
258259
FINAL_TYPE_NAMES,
259260
IMPORTED_REVEAL_TYPE_NAMES,
@@ -2188,6 +2189,8 @@ def analyze_class_decorator_common(
21882189
"""
21892190
if refers_to_fullname(decorator, FINAL_DECORATOR_NAMES):
21902191
info.is_final = True
2192+
elif refers_to_fullname(decorator, DISJOINT_BASE_DECORATOR_NAMES):
2193+
info.is_disjoint_base = True
21912194
elif refers_to_fullname(decorator, TYPE_CHECK_ONLY_NAMES):
21922195
info.is_type_check_only = True
21932196
elif (deprecated := self.get_deprecated(decorator)) is not None:

mypy/stubtest.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import pkgutil
1919
import re
20+
import struct
2021
import symtable
2122
import sys
2223
import traceback
@@ -466,6 +467,67 @@ class SubClass(runtime): # type: ignore[misc]
466467
)
467468

468469

470+
SIZEOF_PYOBJECT = struct.calcsize("P")
471+
472+
473+
def _shape_differs(t1: type[object], t2: type[object]) -> bool:
474+
"""Check whether two types differ in shape.
475+
476+
Mirrors the shape_differs() function in typeobject.c in CPython."""
477+
if sys.version_info >= (3, 12):
478+
return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
479+
else:
480+
# CPython had more complicated logic before 3.12:
481+
# https://github.com/python/cpython/blob/f3c6f882cddc8dc30320d2e73edf019e201394fc/Objects/typeobject.c#L2224
482+
# We attempt to mirror it here well enough to support the most common cases.
483+
if t1.__itemsize__ or t2.__itemsize__:
484+
return t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
485+
t_size = t1.__basicsize__
486+
if not t2.__weakrefoffset__ and t1.__weakrefoffset__ + SIZEOF_PYOBJECT == t_size:
487+
t_size -= SIZEOF_PYOBJECT
488+
if not t2.__dictoffset__ and t1.__dictoffset__ + SIZEOF_PYOBJECT == t_size:
489+
t_size -= SIZEOF_PYOBJECT
490+
if not t2.__weakrefoffset__ and t2.__weakrefoffset__ == t_size:
491+
t_size -= SIZEOF_PYOBJECT
492+
return t_size != t2.__basicsize__
493+
494+
495+
def _is_disjoint_base(typ: type[object]) -> bool:
496+
"""Return whether a type is a disjoint base at runtime, mirroring CPython's logic in typeobject.c.
497+
498+
See PEP 800."""
499+
if typ is object:
500+
return True
501+
base = typ.__base__
502+
assert base is not None, f"Type {typ} has no base"
503+
return _shape_differs(typ, base)
504+
505+
506+
def _verify_disjoint_base(
507+
stub: nodes.TypeInfo, runtime: type[object], object_path: list[str]
508+
) -> Iterator[Error]:
509+
# If it's final, doesn't matter whether it's a disjoint base or not
510+
if stub.is_final:
511+
return
512+
is_disjoint_runtime = _is_disjoint_base(runtime)
513+
if is_disjoint_runtime and not stub.is_disjoint_base:
514+
yield Error(
515+
object_path,
516+
"is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub",
517+
stub,
518+
runtime,
519+
stub_desc=repr(stub),
520+
)
521+
elif not is_disjoint_runtime and stub.is_disjoint_base:
522+
yield Error(
523+
object_path,
524+
"is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime",
525+
stub,
526+
runtime,
527+
stub_desc=repr(stub),
528+
)
529+
530+
469531
def _verify_metaclass(
470532
stub: nodes.TypeInfo, runtime: type[Any], object_path: list[str], *, is_runtime_typeddict: bool
471533
) -> Iterator[Error]:
@@ -534,6 +596,7 @@ def verify_typeinfo(
534596
return
535597

536598
yield from _verify_final(stub, runtime, object_path)
599+
yield from _verify_disjoint_base(stub, runtime, object_path)
537600
is_runtime_typeddict = stub.typeddict_type is not None and is_typeddict(runtime)
538601
yield from _verify_metaclass(
539602
stub, runtime, object_path, is_runtime_typeddict=is_runtime_typeddict

mypy/test/teststubtest.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,9 +1405,16 @@ def spam(x=Flags4(0)): pass
14051405
)
14061406
yield Case(
14071407
stub="""
1408+
import sys
14081409
from typing import Final, Literal
1409-
class BytesEnum(bytes, enum.Enum):
1410-
a = b'foo'
1410+
from typing_extensions import disjoint_base
1411+
if sys.version_info >= (3, 12):
1412+
class BytesEnum(bytes, enum.Enum):
1413+
a = b'foo'
1414+
else:
1415+
@disjoint_base
1416+
class BytesEnum(bytes, enum.Enum):
1417+
a = b'foo'
14111418
FOO: Literal[BytesEnum.a]
14121419
BAR: Final = BytesEnum.a
14131420
BAZ: BytesEnum
@@ -1613,6 +1620,60 @@ def test_not_subclassable(self) -> Iterator[Case]:
16131620
error="CannotBeSubclassed",
16141621
)
16151622

1623+
@collect_cases
1624+
def test_disjoint_base(self) -> Iterator[Case]:
1625+
yield Case(
1626+
stub="""
1627+
class A: pass
1628+
""",
1629+
runtime="""
1630+
class A: pass
1631+
""",
1632+
error=None,
1633+
)
1634+
yield Case(
1635+
stub="""
1636+
from typing_extensions import disjoint_base
1637+
1638+
@disjoint_base
1639+
class B: pass
1640+
""",
1641+
runtime="""
1642+
class B: pass
1643+
""",
1644+
error="test_module.B",
1645+
)
1646+
yield Case(
1647+
stub="""
1648+
from typing_extensions import Self
1649+
1650+
class mytakewhile:
1651+
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
1652+
def __iter__(self) -> Self: ...
1653+
def __next__(self) -> object: ...
1654+
""",
1655+
runtime="""
1656+
from itertools import takewhile as mytakewhile
1657+
""",
1658+
# Should have @disjoint_base
1659+
error="test_module.mytakewhile",
1660+
)
1661+
yield Case(
1662+
stub="""
1663+
from typing_extensions import disjoint_base, Self
1664+
1665+
@disjoint_base
1666+
class mycorrecttakewhile:
1667+
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
1668+
def __iter__(self) -> Self: ...
1669+
def __next__(self) -> object: ...
1670+
""",
1671+
runtime="""
1672+
from itertools import takewhile as mycorrecttakewhile
1673+
""",
1674+
error=None,
1675+
)
1676+
16161677
@collect_cases
16171678
def test_has_runtime_final_decorator(self) -> Iterator[Case]:
16181679
yield Case(

mypy/typeops.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,3 +1253,54 @@ def named_type(fullname: str) -> Instance:
12531253
)
12541254
)
12551255
return subtype
1256+
1257+
1258+
def _is_disjoint_base(info: TypeInfo) -> bool:
1259+
# It either has the @disjoint_base decorator or defines nonempty __slots__.
1260+
if info.is_disjoint_base:
1261+
return True
1262+
if not info.slots:
1263+
return False
1264+
own_slots = {
1265+
slot
1266+
for slot in info.slots
1267+
if not any(
1268+
base_info.type.slots is not None and slot in base_info.type.slots
1269+
for base_info in info.bases
1270+
)
1271+
}
1272+
return bool(own_slots)
1273+
1274+
1275+
def _get_disjoint_base_of(instance: Instance) -> TypeInfo | None:
1276+
"""Returns the disjoint base of the given instance, if it exists."""
1277+
if _is_disjoint_base(instance.type):
1278+
return instance.type
1279+
for base in instance.type.mro:
1280+
if _is_disjoint_base(base):
1281+
return base
1282+
return None
1283+
1284+
1285+
def can_have_shared_disjoint_base(instances: list[Instance]) -> bool:
1286+
"""Returns whether the given instances can share a disjoint base.
1287+
1288+
This means that a child class of these classes can exist at runtime.
1289+
"""
1290+
# Ignore None disjoint bases (which are `object`).
1291+
disjoint_bases = [
1292+
base for instance in instances if (base := _get_disjoint_base_of(instance)) is not None
1293+
]
1294+
if not disjoint_bases:
1295+
# All are `object`.
1296+
return True
1297+
1298+
candidate = disjoint_bases[0]
1299+
for base in disjoint_bases[1:]:
1300+
if candidate.has_base(base.fullname):
1301+
continue
1302+
elif base.has_base(candidate.fullname):
1303+
candidate = base
1304+
else:
1305+
return False
1306+
return True

mypy/types.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,12 @@
119119
# Supported Unpack type names.
120120
UNPACK_TYPE_NAMES: Final = ("typing.Unpack", "typing_extensions.Unpack")
121121

122-
# Supported @deprecated type names
122+
# Supported @deprecated decorator names
123123
DEPRECATED_TYPE_NAMES: Final = ("warnings.deprecated", "typing_extensions.deprecated")
124124

125+
# Supported @disjoint_base decorator names
126+
DISJOINT_BASE_DECORATOR_NAMES: Final = ("typing.disjoint_base", "typing_extensions.disjoint_base")
127+
125128
# We use this constant in various places when checking `tuple` subtyping:
126129
TUPLE_LIKE_INSTANCE_NAMES: Final = (
127130
"builtins.tuple",

test-data/unit/check-classes.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6084,7 +6084,7 @@ class B(A):
60846084
__slots__ = ('a', 'b')
60856085
class C:
60866086
__slots__ = ('x',)
6087-
class D(B, C):
6087+
class D(B, C): # E: Class "D" has incompatible disjoint bases
60886088
__slots__ = ('aa', 'bb', 'cc')
60896089
[builtins fixtures/tuple.pyi]
60906090

0 commit comments

Comments
 (0)