Skip to content

Commit dca9a5b

Browse files
committed
Get better 'Source' annotations - break slots gc tests
1 parent 13b11c1 commit dca9a5b

File tree

2 files changed

+34
-7
lines changed

2 files changed

+34
-7
lines changed

Lib/dataclasses.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def add_fns_to_class(self, cls):
509509
for name, fn in zip(self.names, fns):
510510
fn.__qualname__ = f"{cls.__qualname__}.{fn.__name__}"
511511
if annotations := self.method_annotations.get(name):
512-
fn.__annotate__ = _make_annotate_function(annotations)
512+
fn.__annotate__ = _make_annotate_function(cls, annotations)
513513

514514
if self.unconditional_adds.get(name, False):
515515
setattr(cls, name, fn)
@@ -526,10 +526,13 @@ def add_fns_to_class(self, cls):
526526
raise TypeError(error_msg)
527527

528528

529-
def _make_annotate_function(annotations):
529+
def _make_annotate_function(cls, annotations):
530530
# Create an __annotate__ function for a dataclass
531531
# Try to return annotations in the same format as they would be
532532
# from a regular __init__ function
533+
534+
# annotations should be in FORWARDREF format at this stage
535+
533536
def __annotate__(format):
534537
match format:
535538
case annotationlib.Format.VALUE | annotationlib.Format.FORWARDREF:
@@ -538,19 +541,30 @@ def __annotate__(format):
538541
if isinstance(v, annotationlib.ForwardRef) else v
539542
for k, v in annotations.items()
540543
}
544+
541545
case annotationlib.Format.STRING:
546+
cls_annotations = {}
547+
for base in reversed(cls.__mro__):
548+
cls_annotations.update(
549+
annotationlib.get_annotations(base, format=format)
550+
)
551+
542552
string_annos = {}
543553
for k, v in annotations.items():
544-
if isinstance(v, str):
545-
string_annos[k] = v
546-
elif isinstance(v, annotationlib.ForwardRef):
547-
string_annos[k] = v.evaluate(format=annotationlib.Format.STRING)
548-
else:
554+
try:
555+
string_annos[k] = cls_annotations[k]
556+
except KeyError:
557+
# This should be the return value
549558
string_annos[k] = annotationlib.type_repr(v)
550559
return string_annos
560+
551561
case _:
552562
raise NotImplementedError(format)
553563

564+
# This is a flag for _add_slots to know it needs to regenerate this method
565+
# In order to remove references to the original class when it is replaced
566+
__annotate__.__generated_by_dataclasses = True
567+
554568
return __annotate__
555569

556570

Lib/test/test_dataclasses/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2512,6 +2512,8 @@ class B:
25122512
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
25132513
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
25142514

2515+
del undefined # Remove so we can use the name in later examples
2516+
25152517
# Check `init=False` attributes don't get into the annotations of the __init__ function
25162518
@dataclass
25172519
class C:
@@ -2520,6 +2522,17 @@ class C:
25202522
self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
25212523

25222524

2525+
# Check string annotations on objects containing a ForwardRef
2526+
@dataclass
2527+
class D:
2528+
d: list[undefined]
2529+
2530+
self.assertEqual(
2531+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
2532+
{"d": "list[undefined]", "return": "None"}
2533+
)
2534+
2535+
25232536
class TestRepr(unittest.TestCase):
25242537
def test_repr(self):
25252538
@dataclass

0 commit comments

Comments
 (0)