Skip to content

Commit 07359ad

Browse files
committed
Use __qualname__ to help serialize curried objects.
This isn't pretty, but hopefully tests provide adequate coverage. This includes a signficant change to `curry`: we defined `__getattr__` to get attributes from the wrapped func.
1 parent ffac221 commit 07359ad

File tree

2 files changed

+146
-13
lines changed

2 files changed

+146
-13
lines changed

toolz/functoolz.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ def keywords(self):
259259
def func_name(self):
260260
return self.__name__
261261

262+
def __getattr__(self, attr):
263+
return getattr(self.func, attr)
264+
262265
def __str__(self):
263266
return str(self.func)
264267

@@ -325,27 +328,42 @@ def __get__(self, instance, owner):
325328
def __reduce__(self):
326329
func = self.func
327330
modname = getattr(func, '__module__', None)
328-
funcname = getattr(func, '__name__', None)
329-
if modname and funcname:
330-
module = import_module(modname)
331-
obj = getattr(module, funcname, None)
332-
if obj is self:
333-
return funcname
334-
elif isinstance(obj, curry) and obj.func is func:
335-
func = '%s.%s' % (modname, funcname)
331+
qualname = getattr(func, '__qualname__', None)
332+
if qualname is None:
333+
qualname = getattr(func, '__name__', None)
334+
is_decorated = None
335+
if modname and qualname:
336+
attrs = []
337+
obj = import_module(modname)
338+
for attr in qualname.split('.'):
339+
if isinstance(obj, curry):
340+
attrs.append('func')
341+
obj = obj.func
342+
obj = getattr(obj, attr, None)
343+
if obj is None:
344+
break
345+
attrs.append(attr)
346+
if isinstance(obj, curry) and obj.func is func:
347+
is_decorated = obj is self
348+
qualname = '.'.join(attrs)
349+
func = '%s:%s' % (modname, qualname)
336350

337351
# functools.partial objects can't be pickled
338352
userdict = tuple((k, v) for k, v in self.__dict__.items()
339353
if k != '_partial')
340-
state = (type(self), func, self.args, self.keywords, userdict)
354+
state = (type(self), func, self.args, self.keywords, userdict, is_decorated)
341355
return (_restore_curry, state)
342356

343357

344-
def _restore_curry(cls, func, args, kwargs, userdict):
358+
def _restore_curry(cls, func, args, kwargs, userdict, is_decorated):
345359
if isinstance(func, str):
346-
modname, funcname = func.rsplit('.', 1)
347-
module = import_module(modname)
348-
func = getattr(module, funcname).func
360+
modname, qualname = func.rsplit(':', 1)
361+
obj = import_module(modname)
362+
for attr in qualname.split('.'):
363+
obj = getattr(obj, attr)
364+
if is_decorated:
365+
return obj
366+
func = obj.func
349367
obj = cls(func, *args, **(kwargs or {}))
350368
obj.__dict__.update(userdict)
351369
return obj

toolz/tests/test_serialization.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import toolz
33
import toolz.curried.exceptions
44
import pickle
5+
from toolz.compatibility import PY3
56

67

78
def test_compose():
@@ -62,3 +63,117 @@ def test_curried_exceptions():
6263
# This tests a global curried object that isn't defined in toolz.functoolz
6364
merge = pickle.loads(pickle.dumps(toolz.curried.exceptions.merge))
6465
assert merge is toolz.curried.exceptions.merge
66+
67+
68+
@toolz.curry
69+
class GlobalCurried(object):
70+
def __init__(self, x, y):
71+
self.x = x
72+
self.y = y
73+
74+
@toolz.curry
75+
def f1(self, a, b):
76+
return self.x + self.y + a + b
77+
78+
def g1(self):
79+
pass
80+
81+
def __reduce__(self):
82+
"""Allow us to serialize instances of GlobalCurried"""
83+
return (GlobalCurried, (self.x, self.y))
84+
85+
@toolz.curry
86+
class NestedCurried(object):
87+
def __init__(self, x, y):
88+
self.x = x
89+
self.y = y
90+
91+
@toolz.curry
92+
def f2(self, a, b):
93+
return self.x + self.y + a + b
94+
95+
def g2(self):
96+
pass
97+
98+
def __reduce__(self):
99+
"""Allow us to serialize instances of NestedCurried"""
100+
return (GlobalCurried.NestedCurried, (self.x, self.y))
101+
102+
class Nested(object):
103+
def __init__(self, x, y):
104+
self.x = x
105+
self.y = y
106+
107+
@toolz.curry
108+
def f3(self, a, b):
109+
return self.x + self.y + a + b
110+
111+
def g3(self):
112+
pass
113+
114+
115+
def test_curried_qualname():
116+
if not PY3:
117+
return
118+
def preserves_identity(obj):
119+
return pickle.loads(pickle.dumps(obj)) is obj
120+
121+
assert preserves_identity(GlobalCurried)
122+
assert preserves_identity(GlobalCurried.func.f1)
123+
assert preserves_identity(GlobalCurried.func.g1)
124+
assert preserves_identity(GlobalCurried.func.NestedCurried)
125+
assert preserves_identity(GlobalCurried.func.NestedCurried.func.f2)
126+
assert preserves_identity(GlobalCurried.func.NestedCurried.func.g2)
127+
assert preserves_identity(GlobalCurried.func.Nested)
128+
assert preserves_identity(GlobalCurried.func.Nested.f3)
129+
assert preserves_identity(GlobalCurried.func.Nested.g3)
130+
131+
# Rely on curry.__getattr__
132+
assert preserves_identity(GlobalCurried.f1)
133+
assert preserves_identity(GlobalCurried.g1)
134+
assert preserves_identity(GlobalCurried.NestedCurried)
135+
assert preserves_identity(GlobalCurried.NestedCurried.f2)
136+
assert preserves_identity(GlobalCurried.NestedCurried.g2)
137+
assert preserves_identity(GlobalCurried.Nested)
138+
assert preserves_identity(GlobalCurried.Nested.f3)
139+
assert preserves_identity(GlobalCurried.Nested.g3)
140+
141+
global_curried1 = GlobalCurried(1)
142+
global_curried2 = pickle.loads(pickle.dumps(global_curried1))
143+
assert global_curried1 is not global_curried2
144+
assert global_curried1(2).f1(3, 4) == global_curried2(2).f1(3, 4) == 10
145+
146+
global_curried3 = global_curried1(2)
147+
global_curried4 = pickle.loads(pickle.dumps(global_curried3))
148+
assert global_curried3 is not global_curried4
149+
assert global_curried3.f1(3, 4) == global_curried4.f1(3, 4) == 10
150+
151+
func1 = global_curried1(2).f1(3)
152+
func2 = pickle.loads(pickle.dumps(func1))
153+
assert func1 is not func2
154+
assert func1(4) == func2(4) == 10
155+
156+
nested_curried1 = GlobalCurried.NestedCurried(1)
157+
nested_curried2 = pickle.loads(pickle.dumps(nested_curried1))
158+
assert nested_curried1 is not nested_curried2
159+
assert nested_curried1(2).f2(3, 4) == nested_curried2(2).f2(3, 4) == 10
160+
161+
nested_curried3 = nested_curried1(2)
162+
nested_curried4 = pickle.loads(pickle.dumps(nested_curried3))
163+
assert nested_curried3 is not nested_curried4
164+
assert nested_curried3.f2(3, 4) == nested_curried4.f2(3, 4) == 10
165+
166+
func1 = nested_curried1(2).f2(3)
167+
func2 = pickle.loads(pickle.dumps(func1))
168+
assert func1 is not func2
169+
assert func1(4) == func2(4) == 10
170+
171+
nested3 = GlobalCurried.Nested(1, 2)
172+
nested4 = pickle.loads(pickle.dumps(nested3))
173+
assert nested3 is not nested4
174+
assert nested3.f3(3, 4) == nested4.f3(3, 4) == 10
175+
176+
func1 = nested3.f3(3)
177+
func2 = pickle.loads(pickle.dumps(func1))
178+
assert func1 is not func2
179+
assert func1(4) == func2(4) == 10

0 commit comments

Comments
 (0)