Skip to content

Commit 0a03ee9

Browse files
authored
Various simplifications, use comprehensions, helps typing. (#193)
* Various simplifications * inline * simplify * revert `if not initial` * revert inline if else * revert next() instead of a loop
1 parent e367dc9 commit 0a03ee9

File tree

4 files changed

+40
-72
lines changed

4 files changed

+40
-72
lines changed

functional/io.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,7 @@ def __iter__(self):
6565
errors=self.errors,
6666
newline=self.newline,
6767
) as file_content:
68-
for line in file_content:
69-
yield line
68+
yield from file_content
7069

7170
def read(self):
7271
# pylint: disable=no-member
@@ -148,14 +147,12 @@ def __iter__(self):
148147
errors=self.errors,
149148
newline=self.newline,
150149
) as file_content:
151-
for line in file_content:
152-
yield line
150+
yield from file_content
153151
else:
154152
with gzip.open(
155153
self.path, mode=self.mode, compresslevel=self.compresslevel
156154
) as file_content:
157-
for line in file_content:
158-
yield line
155+
yield from file_content
159156

160157
def read(self):
161158
with gzip.GzipFile(self.path, compresslevel=self.compresslevel) as gz_file:
@@ -204,8 +201,7 @@ def __iter__(self):
204201
errors=self.errors,
205202
newline=self.newline,
206203
) as file_content:
207-
for line in file_content:
208-
yield line
204+
yield from file_content
209205

210206
def read(self):
211207
with bz2.open(
@@ -265,8 +261,7 @@ def __iter__(self):
265261
errors=self.errors,
266262
newline=self.newline,
267263
) as file_content:
268-
for line in file_content:
269-
yield line
264+
yield from file_content
270265

271266
def read(self):
272267
with lzma.open(
@@ -290,14 +285,12 @@ def read(self):
290285
def get_read_function(filename: str, disable_compression: bool):
291286
if disable_compression:
292287
return ReusableFile
293-
else:
294-
with open(filename, "rb") as f:
295-
start_bytes = f.read(N_COMPRESSION_CHECK_BYTES)
296-
for cls in COMPRESSION_CLASSES:
297-
if cls.is_compressed(start_bytes):
298-
return cls
299-
300-
return ReusableFile
288+
with open(filename, "rb") as f:
289+
start_bytes = f.read(N_COMPRESSION_CHECK_BYTES)
290+
for cls in COMPRESSION_CLASSES:
291+
if cls.is_compressed(start_bytes):
292+
return cls
293+
return ReusableFile
301294

302295

303296
def universal_write_open(

functional/pipeline.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,9 @@ def _transform(self, *transforms):
195195
:param transform: transform to apply or list of transforms to apply
196196
:return: transformed sequence
197197
"""
198-
sequence = None
198+
sequence = self
199199
for transform in transforms:
200-
if sequence:
201-
sequence = Sequence(sequence, transform=transform, no_wrap=self.no_wrap)
202-
else:
203-
sequence = Sequence(self, transform=transform, no_wrap=self.no_wrap)
200+
sequence = Sequence(sequence, transform=transform, no_wrap=self.no_wrap)
204201
return sequence
205202

206203
@property
@@ -631,11 +628,7 @@ def count(self, func):
631628
:param func: predicate to count elements on
632629
:return: count of elements that satisfy predicate
633630
"""
634-
n = 0
635-
for element in self:
636-
if func(element):
637-
n += 1
638-
return n
631+
return sum(bool(func(element)) for element in self)
639632

640633
def len(self):
641634
"""
@@ -726,10 +719,7 @@ def exists(self, func):
726719
:param func: existence check function
727720
:return: True if any element satisfies func
728721
"""
729-
for element in self:
730-
if func(element):
731-
return True
732-
return False
722+
return any(func(element) for element in self)
733723

734724
def for_all(self, func):
735725
"""
@@ -744,10 +734,7 @@ def for_all(self, func):
744734
:param func: function to check truth value of all elements with
745735
:return: True if all elements make func evaluate to True
746736
"""
747-
for element in self:
748-
if not func(element):
749-
return False
750-
return True
737+
return all(func(element) for element in self)
751738

752739
def max(self):
753740
"""
@@ -872,10 +859,7 @@ def find(self, func):
872859
:param func: function to find with
873860
:return: first element to satisfy func or None
874861
"""
875-
for element in self:
876-
if func(element):
877-
return element
878-
return None
862+
return next((element for element in self if func(element)), None)
879863

880864
def flatten(self):
881865
"""
@@ -1479,9 +1463,7 @@ def to_dict(self, default=None):
14791463
value and used for collections.defaultdict
14801464
:return: dictionary from sequence of (Key, Value) elements
14811465
"""
1482-
dictionary = {}
1483-
for e in self.sequence:
1484-
dictionary[e[0]] = e[1]
1466+
dictionary = dict(self.sequence)
14851467
if default is None:
14861468
return dictionary
14871469
else:
@@ -1882,6 +1864,7 @@ def make_set(it):
18821864
"""
18831865
if func is None:
18841866
return partial(extend, aslist=aslist, final=final, name=name, parallel=parallel)
1867+
assert func is not None # this is for mypy
18851868

18861869
@wraps(func)
18871870
def wrapper(self, *args, **kwargs):

functional/test/test_functional.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_repr(self):
6565
self.assertEqual(repr(l), repr(self.seq(l)))
6666

6767
def test_lineage_name(self):
68-
f = lambda x: x
68+
f = lambda x: x # noqa: E731
6969
self.assertEqual(f.__name__, name(f))
7070
f = "test"
7171
self.assertEqual("test", name(f))
@@ -294,9 +294,8 @@ def test_drop_right(self):
294294

295295
def test_drop_while(self):
296296
l = [1, 2, 3, 4, 5, 6, 7, 8]
297-
f = lambda x: x < 4
298297
expect = [4, 5, 6, 7, 8]
299-
result = self.seq(l).drop_while(f)
298+
result = self.seq(l).drop_while(lambda x: x < 4)
300299
self.assertIteratorEqual(expect, result)
301300
self.assert_type(result)
302301

@@ -311,9 +310,8 @@ def test_take(self):
311310

312311
def test_take_while(self):
313312
l = [1, 2, 3, 4, 5, 6, 7, 8]
314-
f = lambda x: x < 4
315313
expect = [1, 2, 3]
316-
result = self.seq(l).take_while(f)
314+
result = self.seq(l).take_while(lambda x: x < 4)
317315
self.assertIteratorEqual(result, expect)
318316
self.assert_type(result)
319317

@@ -342,18 +340,16 @@ def test_symmetric_difference(self):
342340
self.assertSetEqual(result.set(), set(expect))
343341

344342
def test_map(self):
345-
f = lambda x: x * 2
346343
l = [1, 2, 0, 5]
347344
expect = [2, 4, 0, 10]
348-
result = self.seq(l).map(f)
345+
result = self.seq(l).map(lambda x: x * 2)
349346
self.assertIteratorEqual(expect, result)
350347
self.assert_type(result)
351348

352349
def test_select(self):
353-
f = lambda x: x * 2
354350
l = [1, 2, 0, 5]
355351
expect = [2, 4, 0, 10]
356-
result = self.seq(l).select(f)
352+
result = self.seq(l).select(lambda x: x * 2)
357353
self.assertIteratorEqual(expect, result)
358354
self.assert_type(result)
359355

@@ -369,16 +365,17 @@ def test_starmap(self):
369365
self.assert_type(result)
370366

371367
def test_filter(self):
372-
f = lambda x: x > 0
373368
l = [0, -1, 5, 10]
374369
expect = [5, 10]
375370
s = self.seq(l)
376-
result = s.filter(f)
371+
result = s.filter(lambda x: x > 0)
377372
self.assertIteratorEqual(expect, result)
378373
self.assert_type(result)
379374

380375
def test_where(self):
381-
f = lambda x: x > 0
376+
def f(x):
377+
return x > 0
378+
382379
l = [0, -1, 5, 10]
383380
expect = [5, 10]
384381
s = self.seq(l)
@@ -1013,7 +1010,7 @@ def test_cache(self):
10131010
if self.seq is pseq:
10141011
raise self.skipTest("pseq doesn't support functions with side-effects")
10151012
calls = []
1016-
func = lambda x: calls.append(x)
1013+
func = calls.append
10171014
result = self.seq(1, 2, 3).map(func).cache().map(lambda x: x).to_list()
10181015
self.assertEqual(len(calls), 3)
10191016
self.assertEqual(result, [None, None, None])

functional/util.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import collections
21
import math
2+
from collections.abc import Iterable
33
from functools import reduce
44
from itertools import chain, count, islice, takewhile
55
from multiprocessing import Pool, cpu_count
@@ -49,7 +49,7 @@ def is_namedtuple(val):
4949
bases = val_type.__bases__
5050
if len(bases) != 1 or bases[0] != tuple:
5151
return False
52-
fields = getattr(val_type, "_fields", None)
52+
fields = getattr(val_type, "_fields")
5353
return all(isinstance(n, str) for n in fields)
5454

5555

@@ -69,7 +69,7 @@ def identity(arg):
6969

7070
def is_iterable(val):
7171
"""
72-
Check if val is not a list, but is a collections.Iterable type. This is used to determine
72+
Check if val is not a list, but is a Iterable type. This is used to determine
7373
when list() should be called on val
7474
7575
>>> l = [1, 2]
@@ -79,19 +79,15 @@ def is_iterable(val):
7979
True
8080
8181
:param val: value to check
82-
:return: True if it is not a list, but is a collections.Iterable
82+
:return: True if it is not a list, but is a Iterable
8383
"""
84-
if isinstance(val, list):
85-
return False
86-
return isinstance(val, collections.abc.Iterable)
84+
return not isinstance(val, list) and isinstance(val, Iterable)
8785

8886

89-
def is_tabulatable(val):
90-
if is_primitive(val):
91-
return False
92-
if is_iterable(val) or is_namedtuple(val) or isinstance(val, list):
93-
return True
94-
return False
87+
def is_tabulatable(val: object) -> bool:
88+
return not is_primitive(val) and (
89+
is_iterable(val) or is_namedtuple(val) or isinstance(val, list)
90+
)
9591

9692

9793
def split_every(parts, iterable):
@@ -117,7 +113,7 @@ def unpack(packed):
117113
"""
118114
func, args = serializer.loads(packed)
119115
result = func(*args)
120-
if isinstance(result, collections.abc.Iterable):
116+
if isinstance(result, Iterable):
121117
return list(result)
122118
return None
123119

@@ -164,8 +160,7 @@ def lazy_parallelize(func, result, processes=None, partition_size=None):
164160
with Pool(processes=processes) as pool:
165161
partitions = split_every(partition_size, iter(result))
166162
packed_partitions = (pack(func, (partition,)) for partition in partitions)
167-
for pool_result in pool.imap(unpack, packed_partitions):
168-
yield pool_result
163+
yield from pool.imap(unpack, packed_partitions)
169164

170165

171166
def compute_partition_size(result, processes):

0 commit comments

Comments
 (0)