Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
5820f94
replace source annotations in init with attached annotate function
DavidCEllis Aug 12, 2025
d0c5d3b
Include the new test for init annotations in all formats
DavidCEllis Aug 13, 2025
b340a2f
Add an extra check for init=False attributes
DavidCEllis Aug 13, 2025
c7a4c56
Test that defining a forwardref makes init annotations resolve
DavidCEllis Aug 13, 2025
3a3b6c0
move _make_annotate_function to top level
DavidCEllis Aug 13, 2025
13b11c1
stop unpacking a dictionary into another dictionary
DavidCEllis Aug 13, 2025
dca9a5b
Get better 'Source' annotations - break slots gc tests
DavidCEllis Aug 15, 2025
a4fe18a
Add an extra GC test for forwardrefs
DavidCEllis Aug 15, 2025
7795c9b
Remove references to original class from fields and annotate functions
DavidCEllis Aug 15, 2025
37d9b3c
Test we don't fix annotations on hand-written init functions
DavidCEllis Aug 15, 2025
864305d
*Actually* test we don't fix annotations on hand-written init functions
DavidCEllis Aug 15, 2025
9c6ed47
Use the same annotation logic for VALUE, FORWARDREF and STRING
DavidCEllis Aug 15, 2025
ab74435
Use a list of fields and return type instead of annotations dictionary.
DavidCEllis Aug 15, 2025
30512bd
remove outdated comment
DavidCEllis Aug 15, 2025
86da9b8
Include comment indicating the issue with VALUE annotations should id…
DavidCEllis Oct 21, 2025
d6c680d
Use a dunder name instead of a single underscore
DavidCEllis Oct 21, 2025
161f3e2
Use the dunder name in the tests too
DavidCEllis Oct 21, 2025
e48d6de
📜🤖 Added by blurb_it.
blurb-it[bot] Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 80 additions & 14 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,11 @@ def __init__(self, globals):
self.locals = {}
self.overwrite_errors = {}
self.unconditional_adds = {}
self.method_annotations = {}

def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
overwrite_error=False, unconditional_add=False, decorator=None):
overwrite_error=False, unconditional_add=False, decorator=None,
annotation_fields=None):
if locals is not None:
self.locals.update(locals)

Expand All @@ -464,16 +466,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,

self.names.append(name)

if return_type is not MISSING:
self.locals[f'__dataclass_{name}_return_type__'] = return_type
return_annotation = f'->__dataclass_{name}_return_type__'
else:
return_annotation = ''
if annotation_fields is not None:
self.method_annotations[name] = (annotation_fields, return_type)

args = ','.join(args)
body = '\n'.join(body)

# Compute the text of the entire function, add it to the text we're generating.
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')

def add_fns_to_class(self, cls):
# The source to all of the functions we're generating.
Expand Down Expand Up @@ -509,6 +509,15 @@ def add_fns_to_class(self, cls):
# Now that we've generated the functions, assign them into cls.
for name, fn in zip(self.names, fns):
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"

try:
annotation_fields, return_type = self.method_annotations[name]
except KeyError:
pass
else:
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
fn.__annotate__ = annotate_fn

if self.unconditional_adds.get(name, False):
setattr(cls, name, fn)
else:
Expand All @@ -524,6 +533,44 @@ def add_fns_to_class(self, cls):
raise TypeError(error_msg)


def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
# Create an __annotate__ function for a dataclass
# Try to return annotations in the same format as they would be
# from a regular __init__ function

def __annotate__(format, /):
Format = annotationlib.Format
match format:
case Format.VALUE | Format.FORWARDREF | Format.STRING:
cls_annotations = {}
for base in reversed(__class__.__mro__):
cls_annotations.update(
annotationlib.get_annotations(base, format=format)
)

new_annotations = {}
for k in annotation_fields:
new_annotations[k] = cls_annotations[k]

if return_type is not MISSING:
if format == Format.STRING:
new_annotations["return"] = annotationlib.type_repr(return_type)
else:
new_annotations["return"] = return_type

return new_annotations

case _:
raise NotImplementedError(format)

# This is a flag for _add_slots to know it needs to regenerate this method
# In order to remove references to the original class when it is replaced
__annotate__._generated_by_dataclasses = True
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"

return __annotate__


def _field_assign(frozen, name, value, self_name):
# If we're a frozen class, then assign to our fields in __init__
# via object.__setattr__. Otherwise, just use a simple
Expand Down Expand Up @@ -612,7 +659,7 @@ def _init_param(f):
elif f.default_factory is not MISSING:
# There's a factory function. Set a marker.
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
return f'{f.name}:__dataclass_type_{f.name}__{default}'
return f'{f.name}{default}'


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
Expand All @@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
raise TypeError(f'non-default argument {f.name!r} '
f'follows default argument {seen_default.name!r}')

locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object,
}
}
annotation_fields = [f.name for f in fields if f.init]

locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object}

body_lines = []
for f in fields:
Expand Down Expand Up @@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
[self_name] + _init_params,
body_lines,
locals=locals,
return_type=None)
return_type=None,
annotation_fields=annotation_fields)


def _frozen_get_del_attr(cls, fields, func_builder):
Expand Down Expand Up @@ -1336,6 +1383,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
break

# Get new annotations to remove references to the original class
# in forward references
newcls_ann = annotationlib.get_annotations(
newcls, format=annotationlib.Format.FORWARDREF)

# Fix references in dataclass Fields
for f in getattr(newcls, _FIELDS).values():
try:
ann = newcls_ann[f.name]
except KeyError:
pass
else:
f.type = ann

# Fix the class reference in the __annotate__ method
init_annotate = newcls.__init__.__annotate__
if getattr(init_annotate, "_generated_by_dataclasses", False):
_update_func_cell_for__class__(init_annotate, cls, newcls)

return newcls


Expand Down
136 changes: 135 additions & 1 deletion Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2471,6 +2471,132 @@ def __init__(self, a):
self.assertEqual(D(5).a, 10)


class TestInitAnnotate(unittest.TestCase):
# Tests for the generated __annotate__ function for __init__
# See: https://github.com/python/cpython/issues/137530

def test_annotate_function(self):
# No forward references
@dataclass
class A:
a: int

value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)

self.assertEqual(value_annos, {'a': int, 'return': None})
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})

self.assertTrue(getattr(A.__init__.__annotate__, "_generated_by_dataclasses"))

def test_annotate_function_forwardref(self):
# With forward references
@dataclass
class B:
b: undefined

# VALUE annotations should raise while unresolvable
with self.assertRaises(NameError):
_ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)

forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)

self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})

# Now VALUE and FORWARDREF should resolve, STRING should be unchanged
undefined = int

value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)

self.assertEqual(value_annos, {'b': int, 'return': None})
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})

def test_annotate_function_init_false(self):
# Check `init=False` attributes don't get into the annotations of the __init__ function
@dataclass
class C:
c: str = field(init=False)

self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})

def test_annotate_function_contains_forwardref(self):
# Check string annotations on objects containing a ForwardRef
@dataclass
class D:
d: list[undefined]

with self.assertRaises(NameError):
annotationlib.get_annotations(D.__init__)

self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
{"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
)

self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
{"d": "list[undefined]", "return": "None"}
)

# Now test when it is defined
undefined = str

# VALUE should now resolve
self.assertEqual(
annotationlib.get_annotations(D.__init__),
{"d": list[str], "return": None}
)

self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
{"d": list[str], "return": None}
)

self.assertEqual(
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
{"d": "list[undefined]", "return": "None"}
)

def test_annotate_function_not_replaced(self):
# Check that __annotate__ is not replaced on non-generated __init__ functions
@dataclass(slots=True)
class E:
x: str
def __init__(self, x: int) -> None:
self.x = x

self.assertEqual(
annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
)

self.assertFalse(hasattr(E.__init__.__annotate__, "_generated_by_dataclasses"))

def test_init_false_forwardref(self):
# Currently this raises a NameError even though the ForwardRef
# is not in the __init__ method

@dataclass
class F:
not_in_init: list[undefined] = field(init=False, default=None)
in_init: int

annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
self.assertEqual(
annos,
{"in_init": int, "return": None},
)

with self.assertRaises(NameError):
annos = annotationlib.get_annotations(F.__init__) # NameError on not_in_init


class TestRepr(unittest.TestCase):
def test_repr(self):
@dataclass
Expand Down Expand Up @@ -3831,7 +3957,15 @@ def method(self) -> int:

return SlotsTest

for make in (make_simple, make_with_annotations, make_with_annotations_and_method):
def make_with_forwardref():
@dataclass(slots=True)
class SlotsTest:
x: undefined
y: list[undefined]

return SlotsTest

for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
with self.subTest(make=make):
C = make()
support.gc_collect()
Expand Down
Loading