Skip to content

Commit 20d8aef

Browse files
committed
Explicitly import and curry objects in toolz.curried for nicer IDE integration. #323
We could either have a script to create the curried/__init__.py file or test that the file is correct. I opted for the latter that uses our original logic to construct curried namespace, and the test includes very informative error messages. The purpose of this is to remove the red squiggles in editors that try to inspect Python files and warn on mispelled names. This is a serious quality-of-life issue for some people.
1 parent fe56571 commit 20d8aef

File tree

3 files changed

+120
-25
lines changed

3 files changed

+120
-25
lines changed

toolz/curried/__init__.py

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,78 @@
2323
See Also:
2424
toolz.functoolz.curry
2525
"""
26-
from . import exceptions
27-
from . import operator
2826
import toolz
27+
from . import operator
28+
from toolz import (
29+
comp,
30+
complement,
31+
compose,
32+
concat,
33+
concatv,
34+
count,
35+
curry,
36+
diff,
37+
dissoc,
38+
first,
39+
flip,
40+
frequencies,
41+
identity,
42+
interleave,
43+
isdistinct,
44+
isiterable,
45+
juxt,
46+
last,
47+
memoize,
48+
merge_sorted,
49+
peek,
50+
pipe,
51+
second,
52+
thread_first,
53+
thread_last,
54+
)
55+
from .exceptions import merge, merge_with
2956

57+
accumulate = toolz.curry(toolz.accumulate)
58+
assoc = toolz.curry(toolz.assoc)
59+
assoc_in = toolz.curry(toolz.assoc_in)
60+
cons = toolz.curry(toolz.cons)
61+
countby = toolz.curry(toolz.countby)
62+
do = toolz.curry(toolz.do)
63+
drop = toolz.curry(toolz.drop)
64+
excepts = toolz.curry(toolz.excepts)
65+
filter = toolz.curry(toolz.filter)
66+
get = toolz.curry(toolz.get)
67+
get_in = toolz.curry(toolz.get_in)
68+
groupby = toolz.curry(toolz.groupby)
69+
interpose = toolz.curry(toolz.interpose)
70+
itemfilter = toolz.curry(toolz.itemfilter)
71+
itemmap = toolz.curry(toolz.itemmap)
72+
iterate = toolz.curry(toolz.iterate)
73+
join = toolz.curry(toolz.join)
74+
keyfilter = toolz.curry(toolz.keyfilter)
75+
keymap = toolz.curry(toolz.keymap)
76+
map = toolz.curry(toolz.map)
77+
mapcat = toolz.curry(toolz.mapcat)
78+
nth = toolz.curry(toolz.nth)
79+
partial = toolz.curry(toolz.partial)
80+
partition = toolz.curry(toolz.partition)
81+
partition_all = toolz.curry(toolz.partition_all)
82+
partitionby = toolz.curry(toolz.partitionby)
83+
pluck = toolz.curry(toolz.pluck)
84+
random_sample = toolz.curry(toolz.random_sample)
85+
reduce = toolz.curry(toolz.reduce)
86+
reduceby = toolz.curry(toolz.reduceby)
87+
remove = toolz.curry(toolz.remove)
88+
sliding_window = toolz.curry(toolz.sliding_window)
89+
sorted = toolz.curry(toolz.sorted)
90+
tail = toolz.curry(toolz.tail)
91+
take = toolz.curry(toolz.take)
92+
take_nth = toolz.curry(toolz.take_nth)
93+
topk = toolz.curry(toolz.topk)
94+
unique = toolz.curry(toolz.unique)
95+
update_in = toolz.curry(toolz.update_in)
96+
valfilter = toolz.curry(toolz.valfilter)
97+
valmap = toolz.curry(toolz.valmap)
3098

31-
def _should_curry(func):
32-
if not callable(func) or isinstance(func, toolz.curry):
33-
return False
34-
nargs = toolz.functoolz.num_required_args(func)
35-
if nargs is None or nargs > 1:
36-
return True
37-
return nargs == 1 and toolz.functoolz.has_keywords(func)
38-
39-
40-
def _curry_namespace(ns):
41-
return dict(
42-
(name, toolz.curry(f) if _should_curry(f) else f)
43-
for name, f in ns.items() if '__' not in name
44-
)
45-
46-
47-
locals().update(toolz.merge(
48-
_curry_namespace(vars(toolz)),
49-
_curry_namespace(vars(exceptions)),
50-
))
51-
52-
# Clean up the namespace.
53-
del _should_curry
5499
del exceptions
55100
del toolz

toolz/curried/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@ def merge_with(func, d, *dicts, **kwargs):
1313
def merge(d, *dicts, **kwargs):
1414
return toolz.merge(d, *dicts, **kwargs)
1515

16+
1617
merge_with.__doc__ = toolz.merge_with.__doc__
1718
merge.__doc__ = toolz.merge.__doc__

toolz/tests/test_curried.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import toolz.curried
33
from toolz.curried import (take, first, second, sorted, merge_with, reduce,
44
merge, operator as cop)
5+
from toolz.compatibility import import_module
56
from collections import defaultdict
67
from operator import add
78

@@ -62,3 +63,51 @@ def test_curried_operator():
6263

6364
# Make sure this isn't totally empty.
6465
assert len(set(vars(cop)) & set(['add', 'sub', 'mul'])) == 3
66+
67+
68+
def test_curried_namespace():
69+
exceptions = import_module('toolz.curried.exceptions')
70+
namespace = {}
71+
72+
def should_curry(func):
73+
if not callable(func) or isinstance(func, toolz.curry):
74+
return False
75+
nargs = toolz.functoolz.num_required_args(func)
76+
if nargs is None or nargs > 1:
77+
return True
78+
return nargs == 1 and toolz.functoolz.has_keywords(func)
79+
80+
81+
def curry_namespace(ns):
82+
return dict(
83+
(name, toolz.curry(f) if should_curry(f) else f)
84+
for name, f in ns.items() if '__' not in name
85+
)
86+
87+
from_toolz = curry_namespace(vars(toolz))
88+
from_exceptions = curry_namespace(vars(exceptions))
89+
namespace.update(toolz.merge(from_toolz, from_exceptions))
90+
91+
namespace = toolz.valfilter(callable, namespace)
92+
curried_namespace = toolz.valfilter(callable, toolz.curried.__dict__)
93+
94+
if namespace != curried_namespace:
95+
missing = set(namespace) - set(curried_namespace)
96+
if missing:
97+
raise AssertionError('There are missing functions in toolz.curried:\n %s'
98+
% ' \n'.join(sorted(missing)))
99+
extra = set(curried_namespace) - set(namespace)
100+
if extra:
101+
raise AssertionError('There are extra functions in toolz.curried:\n %s'
102+
% ' \n'.join(sorted(extra)))
103+
unequal = toolz.merge_with(list, namespace, curried_namespace)
104+
unequal = toolz.valfilter(lambda x: x[0] != x[1], unequal)
105+
messages = []
106+
for name, (orig_func, auto_func) in sorted(unequal.items()):
107+
if name in from_exceptions:
108+
messages.append('%s should come from toolz.curried.exceptions' % name)
109+
elif should_curry(getattr(toolz, name)):
110+
messages.append('%s should be curried from toolz' % name)
111+
else:
112+
messages.append('%s should come from toolz and NOT be curried' % name)
113+
raise AssertionError('\n'.join(messages))

0 commit comments

Comments
 (0)