Skip to content

Commit 3ef3702

Browse files
committed
Implement PEP 800 (@disjoint_base)
https://peps.python.org/pep-0800/ - Recognize the @disjoint_base decorator - Error if a class definition has incompatible disjoint bases - Recognize that classes with incompatible disjoint bases cannot exist - Check in stubtest that @disjoint_base is correctly applied
1 parent 3fcfcb8 commit 3ef3702

File tree

11 files changed

+246
-1
lines changed

11 files changed

+246
-1
lines changed

mypy/checker.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@
171171
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type
172172
from mypy.typeops import (
173173
bind_self,
174+
can_have_shared_disjoint_base,
174175
coerce_to_literal,
175176
custom_special_method,
176177
erase_def_to_union_or_bound,
@@ -2658,6 +2659,8 @@ def visit_class_def(self, defn: ClassDef) -> None:
26582659
for base in typ.mro[1:]:
26592660
if base.is_final:
26602661
self.fail(message_registry.CANNOT_INHERIT_FROM_FINAL.format(base.name), defn)
2662+
if not can_have_shared_disjoint_base(typ.bases):
2663+
self.fail(message_registry.INCOMPATIBLE_DISJOINT_BASES.format(typ.name), defn)
26612664
with self.tscope.class_scope(defn.info), self.enter_partial_types(is_class=True):
26622665
old_binder = self.binder
26632666
self.binder = ConditionalTypeBinder(self.options)
@@ -5826,6 +5829,10 @@ def _make_fake_typeinfo_and_full_name(
58265829
format_type_distinctly(*base_classes, options=self.options, bare=True), "and"
58275830
)
58285831

5832+
if not can_have_shared_disjoint_base(base_classes):
5833+
errors.append((pretty_names_list, "have distinct disjoint bases"))
5834+
return None
5835+
58295836
new_errors = []
58305837
for base in base_classes:
58315838
if base.type.is_final:

mypy/message_registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,11 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
239239
)
240240
CANNOT_MAKE_DELETABLE_FINAL: Final = ErrorMessage("Deletable attribute cannot be final")
241241

242+
# Disjoint bases
243+
INCOMPATIBLE_DISJOINT_BASES: Final = ErrorMessage(
244+
'Class "{}" has incompatible disjoint bases'
245+
)
246+
242247
# Enum
243248
ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDDEN: Final = ErrorMessage(
244249
'Assigned "__members__" will be overridden by "Enum" internally'

mypy/nodes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,6 +3004,7 @@ class is generic then it will be a type constructor of higher kind.
30043004
"_mro_refs",
30053005
"bad_mro",
30063006
"is_final",
3007+
"is_disjoint_base",
30073008
"declared_metaclass",
30083009
"metaclass_type",
30093010
"names",
@@ -3055,6 +3056,7 @@ class is generic then it will be a type constructor of higher kind.
30553056
_mro_refs: list[str] | None
30563057
bad_mro: bool # Could not construct full MRO
30573058
is_final: bool
3059+
is_disjoint_base: bool
30583060

30593061
declared_metaclass: mypy.types.Instance | None
30603062
metaclass_type: mypy.types.Instance | None
@@ -3209,6 +3211,7 @@ class is generic then it will be a type constructor of higher kind.
32093211
"is_protocol",
32103212
"runtime_protocol",
32113213
"is_final",
3214+
"is_disjoint_base",
32123215
"is_intersection",
32133216
]
32143217

@@ -3241,6 +3244,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None
32413244
self.type_var_tuple_suffix: int | None = None
32423245
self.add_type_vars()
32433246
self.is_final = False
3247+
self.is_disjoint_base = False
32443248
self.is_enum = False
32453249
self.fallback_to_any = False
32463250
self.meta_fallback_to_any = False

mypy/semanal.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@
254254
ASSERT_TYPE_NAMES,
255255
DATACLASS_TRANSFORM_NAMES,
256256
DEPRECATED_TYPE_NAMES,
257+
DISJOINT_BASE_DECORATOR_NAMES,
257258
FINAL_DECORATOR_NAMES,
258259
FINAL_TYPE_NAMES,
259260
IMPORTED_REVEAL_TYPE_NAMES,
@@ -2188,6 +2189,8 @@ def analyze_class_decorator_common(
21882189
"""
21892190
if refers_to_fullname(decorator, FINAL_DECORATOR_NAMES):
21902191
info.is_final = True
2192+
elif refers_to_fullname(decorator, DISJOINT_BASE_DECORATOR_NAMES):
2193+
info.is_disjoint_base = True
21912194
elif refers_to_fullname(decorator, TYPE_CHECK_ONLY_NAMES):
21922195
info.is_type_check_only = True
21932196
elif (deprecated := self.get_deprecated(decorator)) is not None:

mypy/stubtest.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import pkgutil
1919
import re
20+
import struct
2021
import symtable
2122
import sys
2223
import traceback
@@ -465,6 +466,82 @@ class SubClass(runtime): # type: ignore[misc]
465466
stub_desc=repr(stub),
466467
)
467468

469+
SIZEOF_PYOBJECT = struct.calcsize("P")
470+
471+
472+
def _shape_differs(t1: type[object], t2: type[object]) -> bool:
473+
"""Check whether two types differ in shape.
474+
475+
Mirrors the shape_differs() function in typeobject.c in CPython."""
476+
if sys.version_info >= (3, 12):
477+
return (
478+
t1.__basicsize__ != t2.__basicsize__ or t1.__itemsize__ != t2.__itemsize__
479+
)
480+
else:
481+
# CPython had more complicated logic before 3.12:
482+
# https://github.com/python/cpython/blob/f3c6f882cddc8dc30320d2e73edf019e201394fc/Objects/typeobject.c#L2224
483+
# We attempt to mirror it here well enough to support the most common cases.
484+
if t1.__itemsize__ or t2.__itemsize__:
485+
return (
486+
t1.__basicsize__ != t2.__basicsize__
487+
or t1.__itemsize__ != t2.__itemsize__
488+
)
489+
t_size = t1.__basicsize__
490+
if (
491+
# TODO: better understanding of attributes of type objects
492+
not t2.__weakrefoffset__ # static analysis: ignore[value_always_true]
493+
and t1.__weakrefoffset__ + SIZEOF_PYOBJECT == t_size
494+
):
495+
t_size -= SIZEOF_PYOBJECT
496+
if (
497+
not t2.__dictoffset__ # static analysis: ignore[value_always_true]
498+
and t1.__dictoffset__ + SIZEOF_PYOBJECT == t_size
499+
):
500+
t_size -= SIZEOF_PYOBJECT
501+
if (
502+
not t2.__weakrefoffset__ # static analysis: ignore[value_always_true]
503+
and t2.__weakrefoffset__ == t_size
504+
):
505+
t_size -= SIZEOF_PYOBJECT
506+
return t_size != t2.__basicsize__
507+
508+
509+
def _is_disjoint_base(typ: type[object]) -> bool:
510+
"""Return whether a type is a disjoint base at runtime, mirroring CPython's logic in typeobject.c.
511+
512+
See PEP 800."""
513+
if typ is object:
514+
return True
515+
base = typ.__base__
516+
assert base is not None, f"Type {typ} has no base"
517+
return _shape_differs(typ, base)
518+
519+
520+
521+
def _verify_disjoint_base(
522+
stub: nodes.TypeInfo, runtime: type[object], object_path: list[str]
523+
) -> Iterator[Error]:
524+
# If it's final, doesn't matter whether it's a disjoint base or not
525+
if stub.is_final:
526+
return
527+
is_disjoint_runtime = _is_disjoint_base(runtime)
528+
if is_disjoint_runtime and not stub.is_disjoint_base:
529+
yield Error(
530+
object_path,
531+
"is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub",
532+
stub,
533+
runtime,
534+
stub_desc=repr(stub),
535+
)
536+
elif not is_disjoint_runtime and stub.is_disjoint_base:
537+
yield Error(
538+
object_path,
539+
"is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime",
540+
stub,
541+
runtime,
542+
stub_desc=repr(stub),
543+
)
544+
468545

469546
def _verify_metaclass(
470547
stub: nodes.TypeInfo, runtime: type[Any], object_path: list[str], *, is_runtime_typeddict: bool
@@ -534,6 +611,7 @@ def verify_typeinfo(
534611
return
535612

536613
yield from _verify_final(stub, runtime, object_path)
614+
yield from _verify_disjoint_base(stub, runtime, object_path)
537615
is_runtime_typeddict = stub.typeddict_type is not None and is_typeddict(runtime)
538616
yield from _verify_metaclass(
539617
stub, runtime, object_path, is_runtime_typeddict=is_runtime_typeddict

mypy/test/teststubtest.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1613,6 +1613,60 @@ def test_not_subclassable(self) -> Iterator[Case]:
16131613
error="CannotBeSubclassed",
16141614
)
16151615

1616+
@collect_cases
1617+
def test_disjoint_base(self) -> Iterator[Case]:
1618+
yield Case(
1619+
stub="""
1620+
class A: pass
1621+
""",
1622+
runtime="""
1623+
class A: pass
1624+
""",
1625+
error=None,
1626+
)
1627+
yield Case(
1628+
stub="""
1629+
from typing_extensions import disjoint_base
1630+
1631+
@disjoint_base
1632+
class B: pass
1633+
""",
1634+
runtime="""
1635+
class B: pass
1636+
""",
1637+
error="test_module.B",
1638+
)
1639+
yield Case(
1640+
stub="""
1641+
from typing_extensions import Self
1642+
1643+
class mytakewhile:
1644+
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
1645+
def __iter__(self) -> Self: ...
1646+
def __next__(self) -> object: ...
1647+
""",
1648+
runtime="""
1649+
from itertools import takewhile as mytakewhile
1650+
""",
1651+
# Should have @disjoint_base
1652+
error="test_module.mytakewhile",
1653+
)
1654+
yield Case(
1655+
stub="""
1656+
from typing_extensions import disjoint_base, Self
1657+
1658+
@disjoint_base
1659+
class mycorrecttakewhile:
1660+
def __new__(cls, predicate: object, iterable: object, /) -> Self: ...
1661+
def __iter__(self) -> Self: ...
1662+
def __next__(self) -> object: ...
1663+
""",
1664+
runtime="""
1665+
from itertools import takewhile as mycorrecttakewhile
1666+
""",
1667+
error=None,
1668+
)
1669+
16161670
@collect_cases
16171671
def test_has_runtime_final_decorator(self) -> Iterator[Case]:
16181672
yield Case(

mypy/typeops.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,3 +1253,42 @@ def named_type(fullname: str) -> Instance:
12531253
)
12541254
)
12551255
return subtype
1256+
1257+
1258+
def _is_disjoint_base(info: TypeInfo) -> bool:
1259+
# It either has the @disjoint_base decorator or defines nonempty __slots__.
1260+
return info.is_disjoint_base or bool(info.slots)
1261+
1262+
1263+
def _get_disjoint_base_of(instance: Instance) -> TypeInfo | None:
1264+
"""Returns the disjoint base of the given instance, if it exists."""
1265+
if _is_disjoint_base(instance.type):
1266+
return instance.type
1267+
for base in instance.type.mro:
1268+
if _is_disjoint_base(base):
1269+
return base
1270+
return None
1271+
1272+
1273+
def can_have_shared_disjoint_base(instances: list[Instance]) -> bool:
1274+
"""Returns whether the given instances can share a disjoint base.
1275+
1276+
This means that a child class of these classes can exist at runtime.
1277+
"""
1278+
# Ignore None disjoint bases (which are `object`).
1279+
disjoint_bases = [
1280+
base for instance in instances if (base := _get_disjoint_base_of(instance)) is not None
1281+
]
1282+
if not disjoint_bases:
1283+
# All are `object`.
1284+
return True
1285+
1286+
candidate = disjoint_bases[0]
1287+
for base in disjoint_bases[1:]:
1288+
if candidate.has_base(base.fullname):
1289+
continue
1290+
elif base.has_base(candidate.fullname):
1291+
candidate = base
1292+
else:
1293+
return False
1294+
return True

mypy/types.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,15 @@
119119
# Supported Unpack type names.
120120
UNPACK_TYPE_NAMES: Final = ("typing.Unpack", "typing_extensions.Unpack")
121121

122-
# Supported @deprecated type names
122+
# Supported @deprecated decorator names
123123
DEPRECATED_TYPE_NAMES: Final = ("warnings.deprecated", "typing_extensions.deprecated")
124124

125+
# Supported @disjoint_base decorator names
126+
DISJOINT_BASE_DECORATOR_NAMES: Final = (
127+
"typing.disjoint_base",
128+
"typing_extensions.disjoint_base",
129+
)
130+
125131
# We use this constant in various places when checking `tuple` subtyping:
126132
TUPLE_LIKE_INSTANCE_NAMES: Final = (
127133
"builtins.tuple",

test-data/unit/check-final.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,3 +1272,24 @@ if FOO is not None:
12721272

12731273
def func() -> int:
12741274
return FOO
1275+
1276+
[case testDisjointBase]
1277+
from typing_extensions import disjoint_base
1278+
1279+
@disjoint_base
1280+
class Disjoint1: pass
1281+
1282+
@disjoint_base
1283+
class Disjoint2: pass
1284+
1285+
@disjoint_base
1286+
class DisjointChild(Disjoint1): pass
1287+
1288+
class C1: pass
1289+
class C2(Disjoint1, C1): pass
1290+
class C3(DisjointChild, Disjoint1): pass
1291+
1292+
class C4(Disjoint1, Disjoint2): # E: Class "C4" has incompatible disjoint bases
1293+
pass
1294+
1295+
[builtins fixtures/tuple.pyi]

test-data/unit/check-isinstance.test

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,6 +2551,33 @@ def f2(x: T2) -> T2:
25512551
return C()
25522552
[builtins fixtures/isinstance.pyi]
25532553

2554+
[case testIsInstanceDisjointBase]
2555+
# flags: --warn-unreachable
2556+
from typing_extensions import disjoint_base
2557+
2558+
@disjoint_base
2559+
class Disjoint1: pass
2560+
@disjoint_base
2561+
class Disjoint2: pass
2562+
@disjoint_base
2563+
class Disjoint3(Disjoint1): pass
2564+
class Child(Disjoint1): pass
2565+
class Unrelated: pass
2566+
2567+
def f(d1: Disjoint1, u: Unrelated, c: Child) -> object:
2568+
if isinstance(d1, Disjoint2): # E: Subclass of "Disjoint1" and "Disjoint2" cannot exist: have distinct disjoint bases
2569+
return u # E: Statement is unreachable
2570+
if isinstance(u, Disjoint1): # OK
2571+
return d1
2572+
if isinstance(c, Disjoint3): # OK
2573+
return c
2574+
if isinstance(c, Disjoint2): # E: Subclass of "Child" and "Disjoint2" cannot exist: have distinct disjoint bases
2575+
return c # E: Statement is unreachable
2576+
return d1
2577+
2578+
[builtins fixtures/isinstance.pyi]
2579+
2580+
25542581
[case testIsInstanceAdHocIntersectionUsage]
25552582
# flags: --warn-unreachable
25562583
class A: pass

0 commit comments

Comments
 (0)