Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
5820f94
replace source annotations in init with attached annotate function
DavidCEllis Aug 12, 2025
d0c5d3b
Include the new test for init annotations in all formats
DavidCEllis Aug 13, 2025
b340a2f
Add an extra check for init=False attributes
DavidCEllis Aug 13, 2025
c7a4c56
Test that defining a forwardref makes init annotations resolve
DavidCEllis Aug 13, 2025
3a3b6c0
move _make_annotate_function to top level
DavidCEllis Aug 13, 2025
13b11c1
stop unpacking a dictionary into another dictionary
DavidCEllis Aug 13, 2025
dca9a5b
Get better 'Source' annotations - break slots gc tests
DavidCEllis Aug 15, 2025
a4fe18a
Add an extra GC test for forwardrefs
DavidCEllis Aug 15, 2025
7795c9b
Remove references to original class from fields and annotate functions
DavidCEllis Aug 15, 2025
37d9b3c
Test we don't fix annotations on hand-written init functions
DavidCEllis Aug 15, 2025
864305d
*Actually* test we don't fix annotations on hand-written init functions
DavidCEllis Aug 15, 2025
9c6ed47
Use the same annotation logic for VALUE, FORWARDREF and STRING
DavidCEllis Aug 15, 2025
ab74435
Use a list of fields and return type instead of annotations dictionary.
DavidCEllis Aug 15, 2025
30512bd
remove outdated comment
DavidCEllis Aug 15, 2025
86da9b8
Include comment indicating the issue with VALUE annotations should id…
DavidCEllis Oct 21, 2025
d6c680d
Use a dunder name instead of a single underscore
DavidCEllis Oct 21, 2025
161f3e2
Use the dunder name in the tests too
DavidCEllis Oct 21, 2025
e48d6de
📜🤖 Added by blurb_it.
blurb-it[bot] Oct 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions Lib/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,9 @@ def __init__(self, globals):
self.locals = {}
self.overwrite_errors = {}
self.unconditional_adds = {}
self.method_annotations = {}

def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
def add_fn(self, name, args, body, *, locals=None, annotations=None,
overwrite_error=False, unconditional_add=False, decorator=None):
if locals is not None:
self.locals.update(locals)
Expand All @@ -464,16 +465,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,

self.names.append(name)

if return_type is not MISSING:
self.locals[f'__dataclass_{name}_return_type__'] = return_type
return_annotation = f'->__dataclass_{name}_return_type__'
else:
return_annotation = ''
if annotations is not None:
self.method_annotations[name] = annotations

args = ','.join(args)
body = '\n'.join(body)

# Compute the text of the entire function, add it to the text we're generating.
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')

def add_fns_to_class(self, cls):
# The source to all of the functions we're generating.
Expand Down Expand Up @@ -509,6 +508,9 @@ def add_fns_to_class(self, cls):
# Now that we've generated the functions, assign them into cls.
for name, fn in zip(self.names, fns):
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
if annotations := self.method_annotations.get(name):
fn.__annotate__ = _make_annotate_function(annotations)

if self.unconditional_adds.get(name, False):
setattr(cls, name, fn)
else:
Expand All @@ -524,6 +526,34 @@ def add_fns_to_class(self, cls):
raise TypeError(error_msg)


def _make_annotate_function(annotations):
# Create an __annotate__ function for a dataclass
# Try to return annotations in the same format as they would be
# from a regular __init__ function
def __annotate__(format):
match format:
case annotationlib.Format.VALUE | annotationlib.Format.FORWARDREF:
return {
k: v.evaluate(format=format)
if isinstance(v, annotationlib.ForwardRef) else v
for k, v in annotations.items()
}
case annotationlib.Format.STRING:
string_annos = {}
for k, v in annotations.items():
if isinstance(v, str):
string_annos[k] = v
elif isinstance(v, annotationlib.ForwardRef):
string_annos[k] = v.evaluate(format=annotationlib.Format.STRING)
else:
string_annos[k] = annotationlib.type_repr(v)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit unfortunate because we could get better STRING annotations by calling with the STRING format on the original class (and its base classes).

An approach to do this would be to make __init__.__annotate__ delegate to calling get_annotations() on the class itself, and on any base classes that are contributing fields, with the appropriate format, then processing the result to add "return": None and filter out any fields that don't correspond to __init__ parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did think about that, but wasn't certain about the extra complexity, but I've since noticed that if there's a hint that contains a ForwardRef (like list[undefined]) the result leaks all the ForwardRef details.

I've added a test for that example to demonstrate why this is necessary.

I've taken a slightly different approach, in that I'm gathering all inherited annotations and using them to update the values from the fields. The original source annotation logic this replaces wasn't __init__ specific so I've tried to keep the __annotate__ generator non-specific too.


There are additional commits as this change meant that __annotate__ now (always) has a reference to the original class which broke #135228 - this now also includes an additional test and fix for the issue I brought up there as I already needed to update the fields for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also unintentionally discovered that the __annotate__ functions CPython generates for methods have an incorrect __qualname__ while fixing the __qualname__ for the function generated here.

return string_annos
case _:
raise NotImplementedError(format)

return __annotate__


def _field_assign(frozen, name, value, self_name):
# If we're a frozen class, then assign to our fields in __init__
# via object.__setattr__. Otherwise, just use a simple
Expand Down Expand Up @@ -612,7 +642,7 @@ def _init_param(f):
elif f.default_factory is not MISSING:
# There's a factory function. Set a marker.
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
return f'{f.name}:__dataclass_type_{f.name}__{default}'
return f'{f.name}{default}'


def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
Expand All @@ -635,11 +665,11 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
raise TypeError(f'non-default argument {f.name!r} '
f'follows default argument {seen_default.name!r}')

locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object,
}
}
annotations = {f.name: f.type for f in fields if f.init}
annotations["return"] = None

locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
'__dataclass_builtins_object__': object}

body_lines = []
for f in fields:
Expand Down Expand Up @@ -670,7 +700,7 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
[self_name] + _init_params,
body_lines,
locals=locals,
return_type=None)
annotations=annotations)


def _frozen_get_del_attr(cls, fields, func_builder):
Expand Down
49 changes: 49 additions & 0 deletions Lib/test/test_dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2470,6 +2470,55 @@ def __init__(self, a):

self.assertEqual(D(5).a, 10)

def test_annotate_function(self):
# Test that the __init__ function has correct annotate function
# See: https://github.com/python/cpython/issues/137530
# With no forward references
@dataclass
class A:
a: int

value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)

self.assertEqual(value_annos, {'a': int, 'return': None})
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})

# With forward references
@dataclass
class B:
b: undefined

# VALUE annotations should raise while unresolvable
with self.assertRaises(NameError):
_ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)

forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)

self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})

# Now VALUE and FORWARDREF should resolve, STRING should be unchanged
undefined = int

value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)

self.assertEqual(value_annos, {'b': int, 'return': None})
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})

# Check `init=False` attributes don't get into the annotations of the __init__ function
@dataclass
class C:
c: str = field(init=False)

self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})


class TestRepr(unittest.TestCase):
def test_repr(self):
Expand Down
Loading