Skip to content

Commit aec80d2

Browse files
Scott Sandersonrgbkrk
authored andcommitted
BUG: Fix crash when pickling dynamic class cycles.
Fixes a bug where we would fail to pickle a class created inside a function if that class participated in a cycle with its own __dict__. Such cycles occur, for example, when a class defines a method that makes a Python 2-style super call, because we have a cycle from class -> __dict__ -> function -> __closure__ -> class. The fix for this is to use the same technique we use to dynamically-created functions: we first pickle an empty "skeleton class", which we memoize before pickling the rest of the class' __dict__. We then invoke a reduce function that re-attaches the class' attributes from the __dict__.
1 parent c89dc9d commit aec80d2

File tree

2 files changed

+127
-12
lines changed

2 files changed

+127
-12
lines changed

cloudpickle/cloudpickle.py

Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,71 @@ def _save_subimports(self, code, top_level_dependencies):
393393
# then discards the reference to it
394394
self.write(pickle.POP)
395395

396+
def save_dynamic_class(self, obj):
397+
"""
398+
Save a class that can't be stored as module global.
399+
400+
This method is used to serialize classes that are defined inside
401+
functions, or that otherwise can't be serialized as attribute lookups
402+
from global modules.
403+
"""
404+
clsdict = dict(obj.__dict__) # copy dict proxy to a dict
405+
if not isinstance(clsdict.get('__dict__', None), property):
406+
# don't extract dict that are properties
407+
clsdict.pop('__dict__', None)
408+
clsdict.pop('__weakref__', None)
409+
410+
# hack as __new__ is stored differently in the __dict__
411+
new_override = clsdict.get('__new__', None)
412+
if new_override:
413+
clsdict['__new__'] = obj.__new__
414+
415+
save = self.save
416+
write = self.write
417+
418+
# We write pickle instructions explicitly here to handle the
419+
# possibility that the type object participates in a cycle with its own
420+
# __dict__. We first write an empty "skeleton" version of the class and
421+
# memoize it before writing the class' __dict__ itself. We then write
422+
# instructions to "rehydrate" the skeleton class by restoring the
423+
# attributes from the __dict__.
424+
#
425+
# A type can appear in a cycle with its __dict__ if an instance of the
426+
# type appears in the type's __dict__ (which happens for the stdlib
427+
# Enum class), or if the type defines methods that close over the name
428+
# of the type, (which is common for Python 2-style super() calls).
429+
430+
# Push the rehydration function.
431+
save(_rehydrate_skeleton_class)
432+
433+
# Mark the start of the args for the rehydration function.
434+
write(pickle.MARK)
435+
436+
# On PyPy, __doc__ is a readonly attribute, so we need to include it in
437+
# the initial skeleton class. This is safe because we know that the
438+
# doc can't participate in a cycle with the original class.
439+
doc_dict = {'__doc__': clsdict.pop('__doc__', None)}
440+
441+
# Create and memoize an empty class with obj's name and bases.
442+
save(type(obj))
443+
save((
444+
obj.__name__,
445+
obj.__bases__,
446+
doc_dict,
447+
))
448+
write(pickle.REDUCE)
449+
self.memoize(obj)
450+
451+
# Now save the rest of obj's __dict__. Any references to obj
452+
# encountered while saving will point to the skeleton class.
453+
save(clsdict)
454+
455+
# Write a tuple of (skeleton_class, clsdict).
456+
write(pickle.TUPLE)
457+
458+
# Call _rehydrate_skeleton_class(skeleton_class, clsdict)
459+
write(pickle.REDUCE)
460+
396461
def save_function_tuple(self, func):
397462
""" Pickles an actual func object.
398463
@@ -513,6 +578,12 @@ def save_builtin_function(self, obj):
513578
dispatch[types.BuiltinFunctionType] = save_builtin_function
514579

515580
def save_global(self, obj, name=None, pack=struct.pack):
581+
"""
582+
Save a "global".
583+
584+
The name of this method is somewhat misleading: all types get
585+
dispatched here.
586+
"""
516587
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
517588
if obj in _BUILTIN_TYPE_NAMES:
518589
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
@@ -536,18 +607,7 @@ def save_global(self, obj, name=None, pack=struct.pack):
536607

537608
typ = type(obj)
538609
if typ is not obj and isinstance(obj, (type, types.ClassType)):
539-
d = dict(obj.__dict__) # copy dict proxy to a dict
540-
if not isinstance(d.get('__dict__', None), property):
541-
# don't extract dict that are properties
542-
d.pop('__dict__', None)
543-
d.pop('__weakref__', None)
544-
545-
# hack as __new__ is stored differently in the __dict__
546-
new_override = d.get('__new__', None)
547-
if new_override:
548-
d['__new__'] = obj.__new__
549-
550-
self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
610+
self.save_dynamic_class(obj)
551611
else:
552612
raise pickle.PicklingError("Can't pickle %r" % obj)
553613

@@ -986,6 +1046,16 @@ def _make_skel_func(code, cell_count, base_globals=None):
9861046
return types.FunctionType(code, base_globals, None, None, closure)
9871047

9881048

1049+
def _rehydrate_skeleton_class(skeleton_class, class_dict):
1050+
"""Put attributes from `class_dict` back on `skeleton_class`.
1051+
1052+
See CloudPickler.save_dynamic_class for more info.
1053+
"""
1054+
for attrname, attr in class_dict.items():
1055+
setattr(skeleton_class, attrname, attr)
1056+
return skeleton_class
1057+
1058+
9891059
def _find_module(mod_name):
9901060
"""
9911061
Iterate over each part instead of calling imp.find_module directly.

tests/cloudpickle_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,51 @@ def g():
197197
g = pickle_depickle(f())
198198
self.assertEqual(g(), 2)
199199

200+
def test_dynamically_generated_class_that_uses_super(self):
201+
202+
class Base(object):
203+
def method(self):
204+
return 1
205+
206+
class Derived(Base):
207+
"Derived Docstring"
208+
def method(self):
209+
return super(Derived, self).method() + 1
210+
211+
self.assertEqual(Derived().method(), 2)
212+
213+
# Pickle and unpickle the class.
214+
UnpickledDerived = pickle_depickle(Derived)
215+
self.assertEqual(UnpickledDerived().method(), 2)
216+
217+
# We have special logic for handling __doc__ because it's a readonly
218+
# attribute on PyPy.
219+
self.assertEqual(UnpickledDerived.__doc__, "Derived Docstring")
220+
221+
# Pickle and unpickle an instance.
222+
orig_d = Derived()
223+
d = pickle_depickle(orig_d)
224+
self.assertEqual(d.method(), 2)
225+
226+
def test_cycle_in_classdict_globals(self):
227+
228+
class C(object):
229+
230+
def it_works(self):
231+
return "woohoo!"
232+
233+
C.C_again = C
234+
C.instance_of_C = C()
235+
236+
depickled_C = pickle_depickle(C)
237+
depickled_instance = pickle_depickle(C())
238+
239+
# Test instance of depickled class.
240+
self.assertEqual(depickled_C().it_works(), "woohoo!")
241+
self.assertEqual(depickled_C.C_again().it_works(), "woohoo!")
242+
self.assertEqual(depickled_C.instance_of_C.it_works(), "woohoo!")
243+
self.assertEqual(depickled_instance.it_works(), "woohoo!")
244+
200245
@pytest.mark.skipif(sys.version_info >= (3, 4)
201246
and sys.version_info < (3, 4, 3),
202247
reason="subprocess has a bug in 3.4.0 to 3.4.2")

0 commit comments

Comments
 (0)