@@ -1218,9 +1218,31 @@ def _get_slots(cls):
12181218 raise TypeError (f"Slots of '{ cls .__name__ } ' cannot be determined" )
12191219
12201220
1221+ def _update_func_cell_for__class__ (f , oldcls , newcls ):
1222+ # Returns True if we update a cell, else False.
1223+ if f is None :
1224+ # f will be None in the case of a property where not all of
1225+ # fget, fset, and fdel are used. Nothing to do in that case.
1226+ return False
1227+ try :
1228+ idx = f .__code__ .co_freevars .index ("__class__" )
1229+ except ValueError :
1230+ # This function doesn't reference __class__, so nothing to do.
1231+ return False
1232+ # Fix the cell to point to the new class, if it's already pointing
1233+ # at the old class. I'm not convinced that the "is oldcls" test
1234+ # is needed, but other than performance can't hurt.
1235+ closure = f .__closure__ [idx ]
1236+ if closure .cell_contents is oldcls :
1237+ closure .cell_contents = newcls
1238+ return True
1239+ return False
1240+
1241+
12211242def _add_slots (cls , is_frozen , weakref_slot ):
1222- # Need to create a new class, since we can't set __slots__
1223- # after a class has been created.
1243+ # Need to create a new class, since we can't set __slots__ after a
1244+ # class has been created, and the @dataclass decorator is called
1245+ # after the class is created.
12241246
12251247 # Make sure __slots__ isn't already set.
12261248 if '__slots__' in cls .__dict__ :
@@ -1259,18 +1281,37 @@ def _add_slots(cls, is_frozen, weakref_slot):
12591281
12601282 # And finally create the class.
12611283 qualname = getattr (cls , '__qualname__' , None )
1262- cls = type (cls )(cls .__name__ , cls .__bases__ , cls_dict )
1284+ newcls = type (cls )(cls .__name__ , cls .__bases__ , cls_dict )
12631285 if qualname is not None :
1264- cls .__qualname__ = qualname
1286+ newcls .__qualname__ = qualname
12651287
12661288 if is_frozen :
12671289 # Need this for pickling frozen classes with slots.
12681290 if '__getstate__' not in cls_dict :
1269- cls .__getstate__ = _dataclass_getstate
1291+ newcls .__getstate__ = _dataclass_getstate
12701292 if '__setstate__' not in cls_dict :
1271- cls .__setstate__ = _dataclass_setstate
1272-
1273- return cls
1293+ newcls .__setstate__ = _dataclass_setstate
1294+
1295+ # Fix up any closures which reference __class__. This is used to
1296+ # fix zero argument super so that it points to the correct class
1297+ # (the newly created one, which we're returning) and not the
1298+ # original class. We can break out of this loop as soon as we
1299+ # make an update, since all closures for a class will share a
1300+ # given cell.
1301+ for member in newcls .__dict__ .values ():
1302+ # If this is a wrapped function, unwrap it.
1303+ member = inspect .unwrap (member )
1304+
1305+ if isinstance (member , types .FunctionType ):
1306+ if _update_func_cell_for__class__ (member , cls , newcls ):
1307+ break
1308+ elif isinstance (member , property ):
1309+ if (_update_func_cell_for__class__ (member .fget , cls , newcls )
1310+ or _update_func_cell_for__class__ (member .fset , cls , newcls )
1311+ or _update_func_cell_for__class__ (member .fdel , cls , newcls )):
1312+ break
1313+
1314+ return newcls
12741315
12751316
12761317def dataclass (cls = None , / , * , init = True , repr = True , eq = True , order = False ,
0 commit comments