Skip to content
Open
1 change: 1 addition & 0 deletions toolz/curried/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
drop = toolz.curry(toolz.drop)
excepts = toolz.curry(toolz.excepts)
filter = toolz.curry(toolz.filter)
flat = toolz.curry(toolz.flat)
get = toolz.curry(toolz.get)
get_in = toolz.curry(toolz.get_in)
groupby = toolz.curry(toolz.groupby)
Expand Down
14 changes: 13 additions & 1 deletion toolz/itertoolz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
'first', 'second', 'nth', 'last', 'get', 'concat', 'concatv',
'mapcat', 'cons', 'interpose', 'frequencies', 'reduceby', 'iterate',
'sliding_window', 'partition', 'partition_all', 'count', 'pluck',
'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample')
'join', 'tail', 'diff', 'topk', 'peek', 'peekn', 'random_sample',
'flat')


def remove(predicate, seq):
Expand Down Expand Up @@ -1055,3 +1056,14 @@ def random_sample(prob, seq, random_state=None):

random_state = Random(random_state)
return filter(lambda _: random_state.random() < prob, seq)


def flat(level, seq):
""" Flatten a possible nested sequence by n levels """
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fill out this docstring soon.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool. Don't forget to also point to concat(seq), which flattens a sequence one level.

if level < 0:
raise ValueError("level must be >= 0")
for item in seq:
if level == 0 or not hasattr(item, '__iter__'):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably better to have outside the for loop:

if level == 0:
    yield from seq
    return

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That works really well.

yield item
else:
yield from flat(level - 1, item)
17 changes: 16 additions & 1 deletion toolz/tests/test_itertoolz.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import itertools
from itertools import starmap
from toolz.utils import raises
Expand All @@ -13,7 +14,7 @@
reduceby, iterate, accumulate,
sliding_window, count, partition,
partition_all, take_nth, pluck, join,
diff, topk, peek, peekn, random_sample)
diff, topk, peek, peekn, random_sample, flat)
from operator import add, mul


Expand Down Expand Up @@ -547,3 +548,17 @@ def test_random_sample():
assert mk_rsample(b"a") == mk_rsample(u"a")

assert raises(TypeError, lambda: mk_rsample([]))


def test_flat():
seq = [1, 2, 3, 4]
assert list(flat(0, seq)) == seq
assert list(flat(1, seq)) == seq

seq = [1, [2, [3]]]
assert list(flat(0, seq)) == seq
assert list(flat(1, seq)) == [1, 2, [3]]
assert list(flat(2, seq)) == [1, 2, 3]

with pytest.raises(ValueError):
list(flat(-1, seq))