Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,28 +506,57 @@ def _is_disjoint_base(typ: type[object]) -> bool:
def _verify_disjoint_base(
stub: nodes.TypeInfo, runtime: type[object], object_path: list[str]
) -> Iterator[Error]:
# If it's final, doesn't matter whether it's a disjoint base or not
if stub.is_final:
return
is_disjoint_runtime = _is_disjoint_base(runtime)
# Don't complain about missing @disjoint_base if there are __slots__, because
# in that case we can infer that it's a disjoint base.
if is_disjoint_runtime and not stub.is_disjoint_base and not runtime.__dict__.get("__slots__"):
if (
is_disjoint_runtime
and not stub.is_disjoint_base
and not runtime.__dict__.get("__slots__")
and not stub.is_final
and not (stub.is_enum and stub.enum_members)
):
yield Error(
object_path,
"is a disjoint base at runtime, but isn't marked with @disjoint_base in the stub",
stub,
runtime,
stub_desc=repr(stub),
)
elif not is_disjoint_runtime and stub.is_disjoint_base:
yield Error(
object_path,
"is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime",
stub,
runtime,
stub_desc=repr(stub),
)
elif stub.is_disjoint_base:
if not is_disjoint_runtime:
yield Error(
object_path,
"is marked with @disjoint_base in the stub, but isn't a disjoint base at runtime",
stub,
runtime,
stub_desc=repr(stub),
)
if runtime.__dict__.get("__slots__"):
yield Error(
object_path,
"is marked as @disjoint_base, but also has slots; add __slots__ instead",
stub,
runtime,
stub_desc=repr(stub),
)
elif stub.is_final:
yield Error(
object_path,
"is marked as @disjoint_base, but also marked as @final; remove @disjoint_base",
stub,
runtime,
stub_desc=repr(stub),
)
elif stub.is_enum and stub.enum_members:
yield Error(
object_path,
"is marked as @disjoint_base, but is an enum with members, which is implicitly final; "
"remove @disjoint_base",
stub,
runtime,
stub_desc=repr(stub),
)
Comment on lines +543 to +559
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these feel like linter issues more than correctness issues (so sort-of outside of stubtest's normal purview?), but it's obviously very easy to check for them here, so I think it makes sense

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we could also add these checks to flake8-pyi/ruff/whatever.



def _verify_metaclass(
Expand Down
58 changes: 50 additions & 8 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,14 +1407,9 @@ def spam(x=Flags4(0)): pass
stub="""
import sys
from typing import Final, Literal
from typing_extensions import disjoint_base
if sys.version_info >= (3, 12):
class BytesEnum(bytes, enum.Enum):
a = b'foo'
else:
@disjoint_base
class BytesEnum(bytes, enum.Enum):
a = b'foo'
class BytesEnum(bytes, enum.Enum):
a = b'foo'
Comment on lines -1410 to +1411
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one way to keep the test as it was before would be to just have it be an "abstract enum" that doesn't have any members

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I complicated this test in a previous PR for disjoint_base, so I don't feel bad about going back to the previous less complicated version.


FOO: Literal[BytesEnum.a]
BAR: Final = BytesEnum.a
BAZ: BytesEnum
Expand Down Expand Up @@ -1698,6 +1693,53 @@ def __next__(self) -> object: ...
""",
error=None,
)
yield Case(
runtime="""
class IsDisjointBaseBecauseItHasSlots:
__slots__ = ("a",)
a: int
""",
stub="""
from typing_extensions import disjoint_base

@disjoint_base
class IsDisjointBaseBecauseItHasSlots:
a: int
""",
error="test_module.IsDisjointBaseBecauseItHasSlots",
)
yield Case(
runtime="""
class IsFinalSoDisjointBaseIsRedundant: ...
""",
stub="""
from typing_extensions import disjoint_base, final

@final
@disjoint_base
class IsFinalSoDisjointBaseIsRedundant: ...
""",
error="test_module.IsFinalSoDisjointBaseIsRedundant",
)
yield Case(
runtime="""
import enum

class IsEnumSoDisjointBaseIsRedundant(enum.Enum):
A = 1
B = 2
""",
stub="""
from typing_extensions import disjoint_base
import enum

@disjoint_base
class IsEnumSoDisjointBaseIsRedundant(enum.Enum):
A = 1
B = 2
""",
error="test_module.IsEnumSoDisjointBaseIsRedundant",
)

@collect_cases
def test_has_runtime_final_decorator(self) -> Iterator[Case]:
Expand Down
Loading