Skip to content

Commit 20355d5

Browse files
authored
[stubgen] Preserve dataclass_transform decorator (#18418)
Ref: #18081
1 parent 306c1af commit 20355d5

File tree

2 files changed

+148
-11
lines changed

2 files changed

+148
-11
lines changed

mypy/stubgen.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
Var,
114114
)
115115
from mypy.options import Options as MypyOptions
116+
from mypy.semanal_shared import find_dataclass_transform_spec
116117
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
117118
from mypy.stubdoc import ArgSig, FunctionSig
118119
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
@@ -139,6 +140,7 @@
139140
has_yield_from_expression,
140141
)
141142
from mypy.types import (
143+
DATACLASS_TRANSFORM_NAMES,
142144
OVERLOAD_NAMES,
143145
TPDICT_NAMES,
144146
TYPED_NAMEDTUPLE_NAMES,
@@ -701,10 +703,13 @@ def process_decorator(self, o: Decorator) -> None:
701703
"""
702704
o.func.is_overload = False
703705
for decorator in o.original_decorators:
704-
if not isinstance(decorator, (NameExpr, MemberExpr)):
706+
d = decorator
707+
if isinstance(d, CallExpr):
708+
d = d.callee
709+
if not isinstance(d, (NameExpr, MemberExpr)):
705710
continue
706-
qualname = get_qualified_name(decorator)
707-
fullname = self.get_fullname(decorator)
711+
qualname = get_qualified_name(d)
712+
fullname = self.get_fullname(d)
708713
if fullname in (
709714
"builtins.property",
710715
"builtins.staticmethod",
@@ -739,6 +744,9 @@ def process_decorator(self, o: Decorator) -> None:
739744
o.func.is_overload = True
740745
elif qualname.endswith((".setter", ".deleter")):
741746
self.add_decorator(qualname, require_name=False)
747+
elif fullname in DATACLASS_TRANSFORM_NAMES:
748+
p = AliasPrinter(self)
749+
self._decorators.append(f"@{decorator.accept(p)}")
742750

743751
def get_fullname(self, expr: Expression) -> str:
744752
"""Return the expression's full name."""
@@ -785,6 +793,8 @@ def visit_class_def(self, o: ClassDef) -> None:
785793
self.add(f"{self._indent}{docstring}\n")
786794
n = len(self._output)
787795
self._vars.append([])
796+
if self.analyzed and find_dataclass_transform_spec(o):
797+
self.processing_dataclass = True
788798
super().visit_class_def(o)
789799
self.dedent()
790800
self._vars.pop()
@@ -854,13 +864,26 @@ def get_class_decorators(self, cdef: ClassDef) -> list[str]:
854864
decorators.append(d.accept(p))
855865
self.import_tracker.require_name(get_qualified_name(d))
856866
self.processing_dataclass = True
867+
if self.is_dataclass_transform(d):
868+
decorators.append(d.accept(p))
869+
self.import_tracker.require_name(get_qualified_name(d))
857870
return decorators
858871

859872
def is_dataclass(self, expr: Expression) -> bool:
860873
if isinstance(expr, CallExpr):
861874
expr = expr.callee
862875
return self.get_fullname(expr) == "dataclasses.dataclass"
863876

877+
def is_dataclass_transform(self, expr: Expression) -> bool:
878+
if isinstance(expr, CallExpr):
879+
expr = expr.callee
880+
if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES:
881+
return True
882+
if find_dataclass_transform_spec(expr) is not None:
883+
self.processing_dataclass = True
884+
return True
885+
return False
886+
864887
def visit_block(self, o: Block) -> None:
865888
# Unreachable statements may be partially uninitialized and that may
866889
# cause trouble.

test-data/unit/stubgen.test

Lines changed: 122 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3104,15 +3104,12 @@ class C:
31043104
x = attrs.field()
31053105

31063106
[out]
3107-
from _typeshed import Incomplete
3107+
import attrs
31083108

3109+
@attrs.define
31093110
class C:
3110-
x: Incomplete
3111+
x = ...
31113112
def __init__(self, x) -> None: ...
3112-
def __lt__(self, other): ...
3113-
def __le__(self, other): ...
3114-
def __gt__(self, other): ...
3115-
def __ge__(self, other): ...
31163113

31173114
[case testNamedTupleInClass]
31183115
from collections import namedtuple
@@ -4249,6 +4246,122 @@ class Y(missing.Base):
42494246
generated_kwargs_: float
42504247
def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ...
42514248

4249+
[case testDataclassTransform]
4250+
# dataclass_transform detection only works with sementic analysis.
4251+
# Test stubgen doesn't break too badly without it.
4252+
from typing_extensions import dataclass_transform
4253+
4254+
@typing_extensions.dataclass_transform(kw_only_default=True)
4255+
def create_model(cls):
4256+
return cls
4257+
4258+
@create_model
4259+
class X:
4260+
a: int
4261+
b: str = "hello"
4262+
4263+
@typing_extensions.dataclass_transform(kw_only_default=True)
4264+
class ModelBase: ...
4265+
4266+
class Y(ModelBase):
4267+
a: int
4268+
b: str = "hello"
4269+
4270+
@typing_extensions.dataclass_transform(kw_only_default=True)
4271+
class DCMeta(type): ...
4272+
4273+
class Z(metaclass=DCMeta):
4274+
a: int
4275+
b: str = "hello"
4276+
4277+
[out]
4278+
@typing_extensions.dataclass_transform(kw_only_default=True)
4279+
def create_model(cls): ...
4280+
4281+
class X:
4282+
a: int
4283+
b: str
4284+
4285+
@typing_extensions.dataclass_transform(kw_only_default=True)
4286+
class ModelBase: ...
4287+
4288+
class Y(ModelBase):
4289+
a: int
4290+
b: str
4291+
4292+
@typing_extensions.dataclass_transform(kw_only_default=True)
4293+
class DCMeta(type): ...
4294+
4295+
class Z(metaclass=DCMeta):
4296+
a: int
4297+
b: str
4298+
4299+
[case testDataclassTransformDecorator_semanal]
4300+
import typing_extensions
4301+
4302+
@typing_extensions.dataclass_transform(kw_only_default=True)
4303+
def create_model(cls):
4304+
return cls
4305+
4306+
@create_model
4307+
class X:
4308+
a: int
4309+
b: str = "hello"
4310+
4311+
[out]
4312+
import typing_extensions
4313+
4314+
@typing_extensions.dataclass_transform(kw_only_default=True)
4315+
def create_model(cls): ...
4316+
4317+
@create_model
4318+
class X:
4319+
a: int
4320+
b: str = ...
4321+
def __init__(self, *, a, b=...) -> None: ...
4322+
4323+
[case testDataclassTransformClass_semanal]
4324+
from typing_extensions import dataclass_transform
4325+
4326+
@dataclass_transform(kw_only_default=True)
4327+
class ModelBase: ...
4328+
4329+
class X(ModelBase):
4330+
a: int
4331+
b: str = "hello"
4332+
4333+
[out]
4334+
from typing_extensions import dataclass_transform
4335+
4336+
@dataclass_transform(kw_only_default=True)
4337+
class ModelBase: ...
4338+
4339+
class X(ModelBase):
4340+
a: int
4341+
b: str = ...
4342+
def __init__(self, *, a, b=...) -> None: ...
4343+
4344+
[case testDataclassTransformMetaclass_semanal]
4345+
from typing_extensions import dataclass_transform
4346+
4347+
@dataclass_transform(kw_only_default=True)
4348+
class DCMeta(type): ...
4349+
4350+
class X(metaclass=DCMeta):
4351+
a: int
4352+
b: str = "hello"
4353+
4354+
[out]
4355+
from typing_extensions import dataclass_transform
4356+
4357+
@dataclass_transform(kw_only_default=True)
4358+
class DCMeta(type): ...
4359+
4360+
class X(metaclass=DCMeta):
4361+
a: int
4362+
b: str = ...
4363+
def __init__(self, *, a, b=...) -> None: ...
4364+
42524365
[case testAlwaysUsePEP604Union]
42534366
import typing
42544367
import typing as t
@@ -4536,16 +4649,17 @@ def f5[T5 = int]() -> None: ...
45364649
# flags: --include-private --python-version=3.13
45374650
from typing_extensions import dataclass_transform
45384651

4539-
# TODO: preserve dataclass_transform decorator
45404652
@dataclass_transform()
45414653
class DCMeta(type): ...
45424654
class DC(metaclass=DCMeta):
45434655
x: str
45444656

45454657
[out]
4658+
from typing_extensions import dataclass_transform
4659+
4660+
@dataclass_transform()
45464661
class DCMeta(type): ...
45474662

45484663
class DC(metaclass=DCMeta):
45494664
x: str
45504665
def __init__(self, x) -> None: ...
4551-
def __replace__(self, *, x) -> None: ...

0 commit comments

Comments
 (0)