Skip to content

Commit 8bd9d50

Browse files
committed
[stubgen] Preserve dataclass_transform decorator
1 parent bac9984 commit 8bd9d50

File tree

2 files changed

+98
-11
lines changed

2 files changed

+98
-11
lines changed

mypy/stubgen.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@
113113
Var,
114114
)
115115
from mypy.options import Options as MypyOptions
116+
from mypy.semanal_shared import find_dataclass_transform_spec
116117
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
117118
from mypy.stubdoc import ArgSig, FunctionSig
118119
from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module
@@ -139,6 +140,7 @@
139140
has_yield_from_expression,
140141
)
141142
from mypy.types import (
143+
DATACLASS_TRANSFORM_NAMES,
142144
OVERLOAD_NAMES,
143145
TPDICT_NAMES,
144146
TYPED_NAMEDTUPLE_NAMES,
@@ -701,10 +703,13 @@ def process_decorator(self, o: Decorator) -> None:
701703
"""
702704
o.func.is_overload = False
703705
for decorator in o.original_decorators:
704-
if not isinstance(decorator, (NameExpr, MemberExpr)):
706+
d = decorator
707+
if isinstance(d, CallExpr):
708+
d = d.callee
709+
if not isinstance(d, (NameExpr, MemberExpr)):
705710
continue
706-
qualname = get_qualified_name(decorator)
707-
fullname = self.get_fullname(decorator)
711+
qualname = get_qualified_name(d)
712+
fullname = self.get_fullname(d)
708713
if fullname in (
709714
"builtins.property",
710715
"builtins.staticmethod",
@@ -739,6 +744,9 @@ def process_decorator(self, o: Decorator) -> None:
739744
o.func.is_overload = True
740745
elif qualname.endswith((".setter", ".deleter")):
741746
self.add_decorator(qualname, require_name=False)
747+
elif fullname in DATACLASS_TRANSFORM_NAMES:
748+
p = AliasPrinter(self)
749+
self._decorators.append(f"@{decorator.accept(p)}")
742750

743751
def get_fullname(self, expr: Expression) -> str:
744752
"""Return the expression's full name."""
@@ -785,6 +793,8 @@ def visit_class_def(self, o: ClassDef) -> None:
785793
self.add(f"{self._indent}{docstring}\n")
786794
n = len(self._output)
787795
self._vars.append([])
796+
if self.analyzed and find_dataclass_transform_spec(o):
797+
self.processing_dataclass = True
788798
super().visit_class_def(o)
789799
self.dedent()
790800
self._vars.pop()
@@ -854,13 +864,26 @@ def get_class_decorators(self, cdef: ClassDef) -> list[str]:
854864
decorators.append(d.accept(p))
855865
self.import_tracker.require_name(get_qualified_name(d))
856866
self.processing_dataclass = True
867+
if self.is_dataclass_transform(d):
868+
decorators.append(d.accept(p))
869+
self.import_tracker.require_name(get_qualified_name(d))
857870
return decorators
858871

859872
def is_dataclass(self, expr: Expression) -> bool:
860873
if isinstance(expr, CallExpr):
861874
expr = expr.callee
862875
return self.get_fullname(expr) == "dataclasses.dataclass"
863876

877+
def is_dataclass_transform(self, expr: Expression) -> bool:
878+
if isinstance(expr, CallExpr):
879+
expr = expr.callee
880+
if self.get_fullname(expr) in DATACLASS_TRANSFORM_NAMES:
881+
return True
882+
if find_dataclass_transform_spec(expr) is not None:
883+
self.processing_dataclass = True
884+
return True
885+
return False
886+
864887
def visit_block(self, o: Block) -> None:
865888
# Unreachable statements may be partially uninitialized and that may
866889
# cause trouble.

test-data/unit/stubgen.test

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3081,15 +3081,12 @@ class C:
30813081
x = attrs.field()
30823082

30833083
[out]
3084-
from _typeshed import Incomplete
3084+
import attrs
30853085

3086+
@attrs.define
30863087
class C:
3087-
x: Incomplete
3088+
x = ...
30883089
def __init__(self, x) -> None: ...
3089-
def __lt__(self, other): ...
3090-
def __le__(self, other): ...
3091-
def __gt__(self, other): ...
3092-
def __ge__(self, other): ...
30933090

30943091
[case testNamedTupleInClass]
30953092
from collections import namedtuple
@@ -4226,6 +4223,72 @@ class Y(missing.Base):
42264223
generated_kwargs_: float
42274224
def __init__(self, *generated_args__, generated_args, generated_args_, generated_kwargs, generated_kwargs_, **generated_kwargs__) -> None: ...
42284225

4226+
[case testDataclassTransformDecorator_semanal]
4227+
import typing_extensions
4228+
4229+
@typing_extensions.dataclass_transform(kw_only_default=True)
4230+
def create_model(cls):
4231+
return cls
4232+
4233+
@create_model
4234+
class X:
4235+
a: int
4236+
b: str = "hello"
4237+
4238+
[out]
4239+
import typing_extensions
4240+
4241+
@typing_extensions.dataclass_transform(kw_only_default=True)
4242+
def create_model(cls): ...
4243+
4244+
@create_model
4245+
class X:
4246+
a: int
4247+
b: str = ...
4248+
def __init__(self, *, a, b=...) -> None: ...
4249+
4250+
[case testDataclassTransformClass_semanal]
4251+
from typing_extensions import dataclass_transform
4252+
4253+
@dataclass_transform(kw_only_default=True)
4254+
class ModelBase: ...
4255+
4256+
class X(ModelBase):
4257+
a: int
4258+
b: str = "hello"
4259+
4260+
[out]
4261+
from typing_extensions import dataclass_transform
4262+
4263+
@dataclass_transform(kw_only_default=True)
4264+
class ModelBase: ...
4265+
4266+
class X(ModelBase):
4267+
a: int
4268+
b: str = ...
4269+
def __init__(self, *, a, b=...) -> None: ...
4270+
4271+
[case testDataclassTransformMetaclass_semanal]
4272+
from typing_extensions import dataclass_transform
4273+
4274+
@dataclass_transform(kw_only_default=True)
4275+
class DCMeta(type): ...
4276+
4277+
class X(metaclass=DCMeta):
4278+
a: int
4279+
b: str = "hello"
4280+
4281+
[out]
4282+
from typing_extensions import dataclass_transform
4283+
4284+
@dataclass_transform(kw_only_default=True)
4285+
class DCMeta(type): ...
4286+
4287+
class X(metaclass=DCMeta):
4288+
a: int
4289+
b: str = ...
4290+
def __init__(self, *, a, b=...) -> None: ...
4291+
42294292
[case testAlwaysUsePEP604Union]
42304293
import typing
42314294
import typing as t
@@ -4513,16 +4576,17 @@ def f5[T5 = int]() -> None: ...
45134576
# flags: --include-private --python-version=3.13
45144577
from typing_extensions import dataclass_transform
45154578

4516-
# TODO: preserve dataclass_transform decorator
45174579
@dataclass_transform()
45184580
class DCMeta(type): ...
45194581
class DC(metaclass=DCMeta):
45204582
x: str
45214583

45224584
[out]
4585+
from typing_extensions import dataclass_transform
4586+
4587+
@dataclass_transform()
45234588
class DCMeta(type): ...
45244589

45254590
class DC(metaclass=DCMeta):
45264591
x: str
45274592
def __init__(self, x) -> None: ...
4528-
def __replace__(self, *, x) -> None: ...

0 commit comments

Comments
 (0)