Skip to content

Commit 12837c6

Browse files
authored
pythongh-137530: generate an __annotate__ function for dataclasses __init__ (pythonGH-137711)
1 parent 9f51524 commit 12837c6

File tree

4 files changed

+227
-15
lines changed

4 files changed

+227
-15
lines changed

Doc/whatsnew/3.15.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,14 @@ collections.abc
368368
previously emitted if it was merely imported or accessed from the
369369
:mod:`!collections.abc` module.
370370

371+
372+
dataclasses
373+
-----------
374+
375+
* Annotations for generated ``__init__`` methods no longer include internal
376+
type names.
377+
378+
371379
dbm
372380
---
373381

Lib/dataclasses.py

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,11 @@ def __init__(self, globals):
441441
self.locals = {}
442442
self.overwrite_errors = {}
443443
self.unconditional_adds = {}
444+
self.method_annotations = {}
444445

445446
def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
446-
overwrite_error=False, unconditional_add=False, decorator=None):
447+
overwrite_error=False, unconditional_add=False, decorator=None,
448+
annotation_fields=None):
447449
if locals is not None:
448450
self.locals.update(locals)
449451

@@ -464,16 +466,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
464466

465467
self.names.append(name)
466468

467-
if return_type is not MISSING:
468-
self.locals[f'__dataclass_{name}_return_type__'] = return_type
469-
return_annotation = f'->__dataclass_{name}_return_type__'
470-
else:
471-
return_annotation = ''
469+
if annotation_fields is not None:
470+
self.method_annotations[name] = (annotation_fields, return_type)
471+
472472
args = ','.join(args)
473473
body = '\n'.join(body)
474474

475475
# Compute the text of the entire function, add it to the text we're generating.
476-
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}){return_annotation}:\n{body}')
476+
self.src.append(f'{f' {decorator}\n' if decorator else ''} def {name}({args}):\n{body}')
477477

478478
def add_fns_to_class(self, cls):
479479
# The source to all of the functions we're generating.
@@ -509,6 +509,15 @@ def add_fns_to_class(self, cls):
509509
# Now that we've generated the functions, assign them into cls.
510510
for name, fn in zip(self.names, fns):
511511
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
512+
513+
try:
514+
annotation_fields, return_type = self.method_annotations[name]
515+
except KeyError:
516+
pass
517+
else:
518+
annotate_fn = _make_annotate_function(cls, name, annotation_fields, return_type)
519+
fn.__annotate__ = annotate_fn
520+
512521
if self.unconditional_adds.get(name, False):
513522
setattr(cls, name, fn)
514523
else:
@@ -524,6 +533,44 @@ def add_fns_to_class(self, cls):
524533
raise TypeError(error_msg)
525534

526535

536+
def _make_annotate_function(__class__, method_name, annotation_fields, return_type):
537+
# Create an __annotate__ function for a dataclass
538+
# Try to return annotations in the same format as they would be
539+
# from a regular __init__ function
540+
541+
def __annotate__(format, /):
542+
Format = annotationlib.Format
543+
match format:
544+
case Format.VALUE | Format.FORWARDREF | Format.STRING:
545+
cls_annotations = {}
546+
for base in reversed(__class__.__mro__):
547+
cls_annotations.update(
548+
annotationlib.get_annotations(base, format=format)
549+
)
550+
551+
new_annotations = {}
552+
for k in annotation_fields:
553+
new_annotations[k] = cls_annotations[k]
554+
555+
if return_type is not MISSING:
556+
if format == Format.STRING:
557+
new_annotations["return"] = annotationlib.type_repr(return_type)
558+
else:
559+
new_annotations["return"] = return_type
560+
561+
return new_annotations
562+
563+
case _:
564+
raise NotImplementedError(format)
565+
566+
# This is a flag for _add_slots to know it needs to regenerate this method
567+
# In order to remove references to the original class when it is replaced
568+
__annotate__.__generated_by_dataclasses__ = True
569+
__annotate__.__qualname__ = f"{__class__.__qualname__}.{method_name}.__annotate__"
570+
571+
return __annotate__
572+
573+
527574
def _field_assign(frozen, name, value, self_name):
528575
# If we're a frozen class, then assign to our fields in __init__
529576
# via object.__setattr__. Otherwise, just use a simple
@@ -612,7 +659,7 @@ def _init_param(f):
612659
elif f.default_factory is not MISSING:
613660
# There's a factory function. Set a marker.
614661
default = '=__dataclass_HAS_DEFAULT_FACTORY__'
615-
return f'{f.name}:__dataclass_type_{f.name}__{default}'
662+
return f'{f.name}{default}'
616663

617664

618665
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
@@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
635682
raise TypeError(f'non-default argument {f.name!r} '
636683
f'follows default argument {seen_default.name!r}')
637684

638-
locals = {**{f'__dataclass_type_{f.name}__': f.type for f in fields},
639-
**{'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
640-
'__dataclass_builtins_object__': object,
641-
}
642-
}
685+
annotation_fields = [f.name for f in fields if f.init]
686+
687+
locals = {'__dataclass_HAS_DEFAULT_FACTORY__': _HAS_DEFAULT_FACTORY,
688+
'__dataclass_builtins_object__': object}
643689

644690
body_lines = []
645691
for f in fields:
@@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
670716
[self_name] + _init_params,
671717
body_lines,
672718
locals=locals,
673-
return_type=None)
719+
return_type=None,
720+
annotation_fields=annotation_fields)
674721

675722

676723
def _frozen_get_del_attr(cls, fields, func_builder):
@@ -1337,6 +1384,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13371384
or _update_func_cell_for__class__(member.fdel, cls, newcls)):
13381385
break
13391386

1387+
# Get new annotations to remove references to the original class
1388+
# in forward references
1389+
newcls_ann = annotationlib.get_annotations(
1390+
newcls, format=annotationlib.Format.FORWARDREF)
1391+
1392+
# Fix references in dataclass Fields
1393+
for f in getattr(newcls, _FIELDS).values():
1394+
try:
1395+
ann = newcls_ann[f.name]
1396+
except KeyError:
1397+
pass
1398+
else:
1399+
f.type = ann
1400+
1401+
# Fix the class reference in the __annotate__ method
1402+
init_annotate = newcls.__init__.__annotate__
1403+
if getattr(init_annotate, "__generated_by_dataclasses__", False):
1404+
_update_func_cell_for__class__(init_annotate, cls, newcls)
1405+
13401406
return newcls
13411407

13421408

Lib/test/test_dataclasses/__init__.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2471,6 +2471,135 @@ def __init__(self, a):
24712471
self.assertEqual(D(5).a, 10)
24722472

24732473

2474+
class TestInitAnnotate(unittest.TestCase):
2475+
# Tests for the generated __annotate__ function for __init__
2476+
# See: https://github.com/python/cpython/issues/137530
2477+
2478+
def test_annotate_function(self):
2479+
# No forward references
2480+
@dataclass
2481+
class A:
2482+
a: int
2483+
2484+
value_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.VALUE)
2485+
forwardref_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.FORWARDREF)
2486+
string_annos = annotationlib.get_annotations(A.__init__, format=annotationlib.Format.STRING)
2487+
2488+
self.assertEqual(value_annos, {'a': int, 'return': None})
2489+
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
2490+
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})
2491+
2492+
self.assertTrue(getattr(A.__init__.__annotate__, "__generated_by_dataclasses__"))
2493+
2494+
def test_annotate_function_forwardref(self):
2495+
# With forward references
2496+
@dataclass
2497+
class B:
2498+
b: undefined
2499+
2500+
# VALUE annotations should raise while unresolvable
2501+
with self.assertRaises(NameError):
2502+
_ = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
2503+
2504+
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
2505+
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
2506+
2507+
self.assertEqual(forwardref_annos, {'b': support.EqualToForwardRef('undefined', owner=B, is_class=True), 'return': None})
2508+
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
2509+
2510+
# Now VALUE and FORWARDREF should resolve, STRING should be unchanged
2511+
undefined = int
2512+
2513+
value_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.VALUE)
2514+
forwardref_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.FORWARDREF)
2515+
string_annos = annotationlib.get_annotations(B.__init__, format=annotationlib.Format.STRING)
2516+
2517+
self.assertEqual(value_annos, {'b': int, 'return': None})
2518+
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
2519+
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
2520+
2521+
def test_annotate_function_init_false(self):
2522+
# Check `init=False` attributes don't get into the annotations of the __init__ function
2523+
@dataclass
2524+
class C:
2525+
c: str = field(init=False)
2526+
2527+
self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
2528+
2529+
def test_annotate_function_contains_forwardref(self):
2530+
# Check string annotations on objects containing a ForwardRef
2531+
@dataclass
2532+
class D:
2533+
d: list[undefined]
2534+
2535+
with self.assertRaises(NameError):
2536+
annotationlib.get_annotations(D.__init__)
2537+
2538+
self.assertEqual(
2539+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
2540+
{"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
2541+
)
2542+
2543+
self.assertEqual(
2544+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
2545+
{"d": "list[undefined]", "return": "None"}
2546+
)
2547+
2548+
# Now test when it is defined
2549+
undefined = str
2550+
2551+
# VALUE should now resolve
2552+
self.assertEqual(
2553+
annotationlib.get_annotations(D.__init__),
2554+
{"d": list[str], "return": None}
2555+
)
2556+
2557+
self.assertEqual(
2558+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
2559+
{"d": list[str], "return": None}
2560+
)
2561+
2562+
self.assertEqual(
2563+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
2564+
{"d": "list[undefined]", "return": "None"}
2565+
)
2566+
2567+
def test_annotate_function_not_replaced(self):
2568+
# Check that __annotate__ is not replaced on non-generated __init__ functions
2569+
@dataclass(slots=True)
2570+
class E:
2571+
x: str
2572+
def __init__(self, x: int) -> None:
2573+
self.x = x
2574+
2575+
self.assertEqual(
2576+
annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
2577+
)
2578+
2579+
self.assertFalse(hasattr(E.__init__.__annotate__, "__generated_by_dataclasses__"))
2580+
2581+
def test_init_false_forwardref(self):
2582+
# Test forward references in fields not required for __init__ annotations.
2583+
2584+
# At the moment this raises a NameError for VALUE annotations even though the
2585+
# undefined annotation is not required for the __init__ annotations.
2586+
# Ideally this will be fixed but currently there is no good way to resolve this
2587+
2588+
@dataclass
2589+
class F:
2590+
not_in_init: list[undefined] = field(init=False, default=None)
2591+
in_init: int
2592+
2593+
annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
2594+
self.assertEqual(
2595+
annos,
2596+
{"in_init": int, "return": None},
2597+
)
2598+
2599+
with self.assertRaises(NameError):
2600+
annos = annotationlib.get_annotations(F.__init__) # NameError on not_in_init
2601+
2602+
24742603
class TestRepr(unittest.TestCase):
24752604
def test_repr(self):
24762605
@dataclass
@@ -3831,7 +3960,15 @@ def method(self) -> int:
38313960

38323961
return SlotsTest
38333962

3834-
for make in (make_simple, make_with_annotations, make_with_annotations_and_method):
3963+
def make_with_forwardref():
3964+
@dataclass(slots=True)
3965+
class SlotsTest:
3966+
x: undefined
3967+
y: list[undefined]
3968+
3969+
return SlotsTest
3970+
3971+
for make in (make_simple, make_with_annotations, make_with_annotations_and_method, make_with_forwardref):
38353972
with self.subTest(make=make):
38363973
C = make()
38373974
support.gc_collect()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:mod:`dataclasses` Fix annotations for generated ``__init__`` methods by replacing the annotations that were in-line in the generated source code with ``__annotate__`` functions attached to the methods.

0 commit comments

Comments
 (0)