Skip to content

Commit be3ce86

Browse files
authored
Merge pull request #355 from eriknw/fix/curry_module
Set `__module__` on curried objects. This can fix pickling global curried objects
2 parents 7fb83c8 + 0ef082c commit be3ce86

File tree

3 files changed

+181
-15
lines changed

3 files changed

+181
-15
lines changed

toolz/functoolz.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ def __init__(self, *args, **kwargs):
200200

201201
self.__doc__ = getattr(func, '__doc__', None)
202202
self.__name__ = getattr(func, '__name__', '<curry>')
203+
self.__module__ = getattr(func, '__module__', None)
204+
self.__qualname__ = getattr(func, '__qualname__', None)
203205
self._sigspec = None
204206
self._has_unknown_args = None
205207

@@ -324,27 +326,43 @@ def __get__(self, instance, owner):
324326
def __reduce__(self):
325327
func = self.func
326328
modname = getattr(func, '__module__', None)
327-
funcname = getattr(func, '__name__', None)
328-
if modname and funcname:
329-
module = import_module(modname)
330-
obj = getattr(module, funcname, None)
331-
if obj is self:
332-
return funcname
333-
elif isinstance(obj, curry) and obj.func is func:
334-
func = '%s.%s' % (modname, funcname)
329+
qualname = getattr(func, '__qualname__', None)
330+
if qualname is None: # pragma: py3 no cover
331+
qualname = getattr(func, '__name__', None)
332+
is_decorated = None
333+
if modname and qualname:
334+
attrs = []
335+
obj = import_module(modname)
336+
for attr in qualname.split('.'):
337+
if isinstance(obj, curry): # pragma: py2 no cover
338+
attrs.append('func')
339+
obj = obj.func
340+
obj = getattr(obj, attr, None)
341+
if obj is None:
342+
break
343+
attrs.append(attr)
344+
if isinstance(obj, curry) and obj.func is func:
345+
is_decorated = obj is self
346+
qualname = '.'.join(attrs)
347+
func = '%s:%s' % (modname, qualname)
335348

336349
# functools.partial objects can't be pickled
337350
userdict = tuple((k, v) for k, v in self.__dict__.items()
338-
if k != '_partial')
339-
state = (type(self), func, self.args, self.keywords, userdict)
351+
if k not in ('_partial', '_sigspec'))
352+
state = (type(self), func, self.args, self.keywords, userdict,
353+
is_decorated)
340354
return (_restore_curry, state)
341355

342356

343-
def _restore_curry(cls, func, args, kwargs, userdict):
357+
def _restore_curry(cls, func, args, kwargs, userdict, is_decorated):
344358
if isinstance(func, str):
345-
modname, funcname = func.rsplit('.', 1)
346-
module = import_module(modname)
347-
func = getattr(module, funcname).func
359+
modname, qualname = func.rsplit(':', 1)
360+
obj = import_module(modname)
361+
for attr in qualname.split('.'):
362+
obj = getattr(obj, attr)
363+
if is_decorated:
364+
return obj
365+
func = obj.func
348366
obj = cls(func, *args, **(kwargs or {}))
349367
obj.__dict__.update(userdict)
350368
return obj

toolz/tests/test_functoolz.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,16 +285,26 @@ def foo(a, b, c=1):
285285
def test_curry_attributes_writable():
286286
def foo(a, b, c=1):
287287
return a + b + c
288-
288+
foo.__qualname__ = 'this.is.foo'
289289
f = curry(foo, 1, c=2)
290+
assert f.__qualname__ == 'this.is.foo'
290291
f.__name__ = 'newname'
291292
f.__doc__ = 'newdoc'
293+
f.__module__ = 'newmodule'
294+
f.__qualname__ = 'newqualname'
292295
assert f.__name__ == 'newname'
293296
assert f.__doc__ == 'newdoc'
297+
assert f.__module__ == 'newmodule'
298+
assert f.__qualname__ == 'newqualname'
294299
if hasattr(f, 'func_name'):
295300
assert f.__name__ == f.func_name
296301

297302

303+
def test_curry_module():
304+
from toolz.curried.exceptions import merge
305+
assert merge.__module__ == 'toolz.curried.exceptions'
306+
307+
298308
def test_curry_comparable():
299309
def foo(a, b, c=1):
300310
return a + b + c

toolz/tests/test_serialization.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from toolz import *
22
import toolz
3+
import toolz.curried.exceptions
34
import pickle
5+
from toolz.compatibility import PY3, PY33, PY34
6+
from toolz.utils import raises
47

58

69
def test_compose():
@@ -55,3 +58,138 @@ def test_flip():
5558
g1 = flip(f)(1)
5659
g2 = pickle.loads(pickle.dumps(g1))
5760
assert g1(2) == g2(2) == f(2, 1)
61+
62+
63+
def test_curried_exceptions():
64+
# This tests a global curried object that isn't defined in toolz.functoolz
65+
merge = pickle.loads(pickle.dumps(toolz.curried.exceptions.merge))
66+
assert merge is toolz.curried.exceptions.merge
67+
68+
69+
@toolz.curry
70+
class GlobalCurried(object):
71+
def __init__(self, x, y):
72+
self.x = x
73+
self.y = y
74+
75+
@toolz.curry
76+
def f1(self, a, b):
77+
return self.x + self.y + a + b
78+
79+
def g1(self):
80+
pass
81+
82+
def __reduce__(self):
83+
"""Allow us to serialize instances of GlobalCurried"""
84+
return (GlobalCurried, (self.x, self.y))
85+
86+
@toolz.curry
87+
class NestedCurried(object):
88+
def __init__(self, x, y):
89+
self.x = x
90+
self.y = y
91+
92+
@toolz.curry
93+
def f2(self, a, b):
94+
return self.x + self.y + a + b
95+
96+
def g2(self):
97+
pass
98+
99+
def __reduce__(self):
100+
"""Allow us to serialize instances of NestedCurried"""
101+
return (GlobalCurried.NestedCurried, (self.x, self.y))
102+
103+
class Nested(object):
104+
def __init__(self, x, y):
105+
self.x = x
106+
self.y = y
107+
108+
@toolz.curry
109+
def f3(self, a, b):
110+
return self.x + self.y + a + b
111+
112+
def g3(self):
113+
pass
114+
115+
116+
def test_curried_qualname():
117+
if not PY3:
118+
return
119+
120+
def preserves_identity(obj):
121+
return pickle.loads(pickle.dumps(obj)) is obj
122+
123+
assert preserves_identity(GlobalCurried)
124+
assert preserves_identity(GlobalCurried.func.f1)
125+
assert preserves_identity(GlobalCurried.func.NestedCurried)
126+
assert preserves_identity(GlobalCurried.func.NestedCurried.func.f2)
127+
assert preserves_identity(GlobalCurried.func.Nested.f3)
128+
129+
global_curried1 = GlobalCurried(1)
130+
global_curried2 = pickle.loads(pickle.dumps(global_curried1))
131+
assert global_curried1 is not global_curried2
132+
assert global_curried1(2).f1(3, 4) == global_curried2(2).f1(3, 4) == 10
133+
134+
global_curried3 = global_curried1(2)
135+
global_curried4 = pickle.loads(pickle.dumps(global_curried3))
136+
assert global_curried3 is not global_curried4
137+
assert global_curried3.f1(3, 4) == global_curried4.f1(3, 4) == 10
138+
139+
func1 = global_curried1(2).f1(3)
140+
func2 = pickle.loads(pickle.dumps(func1))
141+
assert func1 is not func2
142+
assert func1(4) == func2(4) == 10
143+
144+
nested_curried1 = GlobalCurried.func.NestedCurried(1)
145+
nested_curried2 = pickle.loads(pickle.dumps(nested_curried1))
146+
assert nested_curried1 is not nested_curried2
147+
assert nested_curried1(2).f2(3, 4) == nested_curried2(2).f2(3, 4) == 10
148+
149+
# If we add `curry.__getattr__` forwarding, the following tests will pass
150+
151+
# if not PY33 and not PY34:
152+
# assert preserves_identity(GlobalCurried.func.g1)
153+
# assert preserves_identity(GlobalCurried.func.NestedCurried.func.g2)
154+
# assert preserves_identity(GlobalCurried.func.Nested)
155+
# assert preserves_identity(GlobalCurried.func.Nested.g3)
156+
#
157+
# # Rely on curry.__getattr__
158+
# assert preserves_identity(GlobalCurried.f1)
159+
# assert preserves_identity(GlobalCurried.NestedCurried)
160+
# assert preserves_identity(GlobalCurried.NestedCurried.f2)
161+
# assert preserves_identity(GlobalCurried.Nested.f3)
162+
# if not PY33 and not PY34:
163+
# assert preserves_identity(GlobalCurried.g1)
164+
# assert preserves_identity(GlobalCurried.NestedCurried.g2)
165+
# assert preserves_identity(GlobalCurried.Nested)
166+
# assert preserves_identity(GlobalCurried.Nested.g3)
167+
#
168+
# nested_curried3 = nested_curried1(2)
169+
# nested_curried4 = pickle.loads(pickle.dumps(nested_curried3))
170+
# assert nested_curried3 is not nested_curried4
171+
# assert nested_curried3.f2(3, 4) == nested_curried4.f2(3, 4) == 10
172+
#
173+
# func1 = nested_curried1(2).f2(3)
174+
# func2 = pickle.loads(pickle.dumps(func1))
175+
# assert func1 is not func2
176+
# assert func1(4) == func2(4) == 10
177+
#
178+
# if not PY33 and not PY34:
179+
# nested3 = GlobalCurried.func.Nested(1, 2)
180+
# nested4 = pickle.loads(pickle.dumps(nested3))
181+
# assert nested3 is not nested4
182+
# assert nested3.f3(3, 4) == nested4.f3(3, 4) == 10
183+
#
184+
# func1 = nested3.f3(3)
185+
# func2 = pickle.loads(pickle.dumps(func1))
186+
# assert func1 is not func2
187+
# assert func1(4) == func2(4) == 10
188+
189+
190+
def test_curried_bad_qualname():
191+
@toolz.curry
192+
class Bad(object):
193+
__qualname__ = 'toolz.functoolz.not.a.valid.path'
194+
195+
assert raises(pickle.PicklingError, lambda: pickle.dumps(Bad))

0 commit comments

Comments
 (0)