Skip to content

Commit ffac221

Browse files
committed
Set __module__ on curried objects. This can fix pickling global curried objects.
Should fix dask/distributed#725 This was an oversight that should have been handled in #326. We only tested objects defined in `toolz.functoolz`, so pickling accidentally worked, because they were in the same module as `curry`.
1 parent 7fb83c8 commit ffac221

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

toolz/functoolz.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def __init__(self, *args, **kwargs):
200200

201201
self.__doc__ = getattr(func, '__doc__', None)
202202
self.__name__ = getattr(func, '__name__', '<curry>')
203+
self.__module__ = getattr(func, '__module__', None)
203204
self._sigspec = None
204205
self._has_unknown_args = None
205206

toolz/tests/test_functoolz.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,19 @@ def foo(a, b, c=1):
289289
f = curry(foo, 1, c=2)
290290
f.__name__ = 'newname'
291291
f.__doc__ = 'newdoc'
292+
f.__module__ = 'newmodule'
292293
assert f.__name__ == 'newname'
293294
assert f.__doc__ == 'newdoc'
295+
assert f.__module__ == 'newmodule'
294296
if hasattr(f, 'func_name'):
295297
assert f.__name__ == f.func_name
296298

297299

300+
def test_curry_module():
301+
from toolz.curried.exceptions import merge
302+
assert merge.__module__ == 'toolz.curried.exceptions'
303+
304+
298305
def test_curry_comparable():
299306
def foo(a, b, c=1):
300307
return a + b + c

toolz/tests/test_serialization.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from toolz import *
22
import toolz
3+
import toolz.curried.exceptions
34
import pickle
45

56

@@ -55,3 +56,9 @@ def test_flip():
5556
g1 = flip(f)(1)
5657
g2 = pickle.loads(pickle.dumps(g1))
5758
assert g1(2) == g2(2) == f(2, 1)
59+
60+
61+
def test_curried_exceptions():
62+
# This tests a global curried object that isn't defined in toolz.functoolz
63+
merge = pickle.loads(pickle.dumps(toolz.curried.exceptions.merge))
64+
assert merge is toolz.curried.exceptions.merge

0 commit comments

Comments
 (0)