@@ -441,9 +441,11 @@ def __init__(self, globals):
441441 self .locals = {}
442442 self .overwrite_errors = {}
443443 self .unconditional_adds = {}
444+ self .method_annotations = {}
444445
445446 def add_fn (self , name , args , body , * , locals = None , return_type = MISSING ,
446- overwrite_error = False , unconditional_add = False , decorator = None ):
447+ overwrite_error = False , unconditional_add = False , decorator = None ,
448+ annotation_fields = None ):
447449 if locals is not None :
448450 self .locals .update (locals )
449451
@@ -464,16 +466,14 @@ def add_fn(self, name, args, body, *, locals=None, return_type=MISSING,
464466
465467 self .names .append (name )
466468
467- if return_type is not MISSING :
468- self .locals [f'__dataclass_{ name } _return_type__' ] = return_type
469- return_annotation = f'->__dataclass_{ name } _return_type__'
470- else :
471- return_annotation = ''
469+ if annotation_fields is not None :
470+ self .method_annotations [name ] = (annotation_fields , return_type )
471+
472472 args = ',' .join (args )
473473 body = '\n ' .join (body )
474474
475475 # Compute the text of the entire function, add it to the text we're generating.
476- self .src .append (f'{ f' { decorator } \n ' if decorator else '' } def { name } ({ args } ){ return_annotation } :\n { body } ' )
476+ self .src .append (f'{ f' { decorator } \n ' if decorator else '' } def { name } ({ args } ):\n { body } ' )
477477
478478 def add_fns_to_class (self , cls ):
479479 # The source to all of the functions we're generating.
@@ -509,6 +509,15 @@ def add_fns_to_class(self, cls):
509509 # Now that we've generated the functions, assign them into cls.
510510 for name , fn in zip (self .names , fns ):
511511 fn .__qualname__ = f"{ cls .__qualname__ } .{ fn .__name__ } "
512+
513+ try :
514+ annotation_fields , return_type = self .method_annotations [name ]
515+ except KeyError :
516+ pass
517+ else :
518+ annotate_fn = _make_annotate_function (cls , name , annotation_fields , return_type )
519+ fn .__annotate__ = annotate_fn
520+
512521 if self .unconditional_adds .get (name , False ):
513522 setattr (cls , name , fn )
514523 else :
@@ -524,6 +533,44 @@ def add_fns_to_class(self, cls):
524533 raise TypeError (error_msg )
525534
526535
536+ def _make_annotate_function (__class__ , method_name , annotation_fields , return_type ):
537+ # Create an __annotate__ function for a dataclass
538+ # Try to return annotations in the same format as they would be
539+ # from a regular __init__ function
540+
541+ def __annotate__ (format , / ):
542+ Format = annotationlib .Format
543+ match format :
544+ case Format .VALUE | Format .FORWARDREF | Format .STRING :
545+ cls_annotations = {}
546+ for base in reversed (__class__ .__mro__ ):
547+ cls_annotations .update (
548+ annotationlib .get_annotations (base , format = format )
549+ )
550+
551+ new_annotations = {}
552+ for k in annotation_fields :
553+ new_annotations [k ] = cls_annotations [k ]
554+
555+ if return_type is not MISSING :
556+ if format == Format .STRING :
557+ new_annotations ["return" ] = annotationlib .type_repr (return_type )
558+ else :
559+ new_annotations ["return" ] = return_type
560+
561+ return new_annotations
562+
563+ case _:
564+ raise NotImplementedError (format )
565+
566+ # This is a flag for _add_slots to know it needs to regenerate this method
567+ # In order to remove references to the original class when it is replaced
568+ __annotate__ .__generated_by_dataclasses__ = True
569+ __annotate__ .__qualname__ = f"{ __class__ .__qualname__ } .{ method_name } .__annotate__"
570+
571+ return __annotate__
572+
573+
527574def _field_assign (frozen , name , value , self_name ):
528575 # If we're a frozen class, then assign to our fields in __init__
529576 # via object.__setattr__. Otherwise, just use a simple
@@ -612,7 +659,7 @@ def _init_param(f):
612659 elif f .default_factory is not MISSING :
613660 # There's a factory function. Set a marker.
614661 default = '=__dataclass_HAS_DEFAULT_FACTORY__'
615- return f'{ f .name } :__dataclass_type_ { f . name } __ { default } '
662+ return f'{ f .name } { default } '
616663
617664
618665def _init_fn (fields , std_fields , kw_only_fields , frozen , has_post_init ,
@@ -635,11 +682,10 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
635682 raise TypeError (f'non-default argument { f .name !r} '
636683 f'follows default argument { seen_default .name !r} ' )
637684
638- locals = {** {f'__dataclass_type_{ f .name } __' : f .type for f in fields },
639- ** {'__dataclass_HAS_DEFAULT_FACTORY__' : _HAS_DEFAULT_FACTORY ,
640- '__dataclass_builtins_object__' : object ,
641- }
642- }
685+ annotation_fields = [f .name for f in fields if f .init ]
686+
687+ locals = {'__dataclass_HAS_DEFAULT_FACTORY__' : _HAS_DEFAULT_FACTORY ,
688+ '__dataclass_builtins_object__' : object }
643689
644690 body_lines = []
645691 for f in fields :
@@ -670,7 +716,8 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
670716 [self_name ] + _init_params ,
671717 body_lines ,
672718 locals = locals ,
673- return_type = None )
719+ return_type = None ,
720+ annotation_fields = annotation_fields )
674721
675722
676723def _frozen_get_del_attr (cls , fields , func_builder ):
@@ -1337,6 +1384,25 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13371384 or _update_func_cell_for__class__ (member .fdel , cls , newcls )):
13381385 break
13391386
1387+ # Get new annotations to remove references to the original class
1388+ # in forward references
1389+ newcls_ann = annotationlib .get_annotations (
1390+ newcls , format = annotationlib .Format .FORWARDREF )
1391+
1392+ # Fix references in dataclass Fields
1393+ for f in getattr (newcls , _FIELDS ).values ():
1394+ try :
1395+ ann = newcls_ann [f .name ]
1396+ except KeyError :
1397+ pass
1398+ else :
1399+ f .type = ann
1400+
1401+ # Fix the class reference in the __annotate__ method
1402+ init_annotate = newcls .__init__ .__annotate__
1403+ if getattr (init_annotate , "__generated_by_dataclasses__" , False ):
1404+ _update_func_cell_for__class__ (init_annotate , cls , newcls )
1405+
13401406 return newcls
13411407
13421408
0 commit comments