Skip to content

Commit 2bbd80a

Browse files
committed
Add new "count" function similar to itertools.count
1 parent 21cf4dc commit 2bbd80a

File tree

2 files changed

+268
-0
lines changed

2 files changed

+268
-0
lines changed

domdf_python_tools/iterative.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
# 3rd party
5353
from natsort import natsorted, ns
54+
from typing_extensions import final
5455

5556
# this package
5657
from domdf_python_tools.utils import magnitude
@@ -70,9 +71,12 @@
7071
"extend",
7172
"extend_with",
7273
"extend_with_none",
74+
"count",
75+
"AnyNum",
7376
]
7477

7578
_T = TypeVar("_T")
79+
AnyNum = TypeVar("AnyNum", float, complex)
7680

7781

7882
def chunks(l: Sequence[_T], n: int) -> Iterator[Sequence[_T]]:
@@ -414,3 +418,69 @@ def extend_with_none(sequence: Iterable[_T], minsize: int) -> Sequence[Optional[
414418
filler: Sequence[Optional[_T]] = [None] * max(0, minsize - len(output))
415419

416420
return tuple((*output, *filler))
421+
422+
423+
def count(start: AnyNum = 0, step: AnyNum = 1) -> Iterator[AnyNum]:
424+
"""
425+
Make an iterator which returns evenly spaced values starting with number ``start``.
426+
427+
Often used as an argument to :func:`map` to generate consecutive data points.
428+
Also, used with :func:`zip` to add sequence numbers.
429+
430+
.. versionadded:: 2.7.0
431+
432+
:param start:
433+
:param step: The step between values.
434+
435+
.. seealso::
436+
437+
:func:`itertools.count`.
438+
The difference is that this returns more exact floats, whereas the values from :func:`itertools.count` drift.
439+
"""
440+
441+
if not isinstance(start, (int, float, complex)):
442+
raise TypeError("a number is required")
443+
if not isinstance(step, (int, float, complex)):
444+
raise TypeError("a number is required")
445+
446+
# count(10) --> 10 11 12 13 14 ...
447+
# count(2.5, 0.5) -> 2.5 3.0 3.5 ...
448+
449+
pos: int = 0
450+
451+
def get_next():
452+
if pos:
453+
return start + (step * pos)
454+
else:
455+
return start
456+
457+
@final
458+
class count(Iterator[AnyNum]):
459+
460+
def __next__(self):
461+
nonlocal pos
462+
463+
val = get_next()
464+
pos += 1
465+
466+
return val
467+
468+
def __iter__(self):
469+
return self
470+
471+
if isinstance(step, int) and step == 1:
472+
473+
def __repr__(self):
474+
return f"{self.__class__.__name__}({get_next()})"
475+
else:
476+
477+
def __repr__(self):
478+
return f"{self.__class__.__name__}{get_next(), step}"
479+
480+
def __init_subclass__(cls, **kwargs):
481+
raise TypeError("type 'domdf_python_tools.iterative.count' is not an acceptable base type")
482+
483+
count.__qualname__ = count.__name__ = "count"
484+
485+
return count() # type: ignore
486+

tests/test_iterative.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,19 @@
66
77
"""
88

9+
# test_count, test_count_with_stride and pickletest
10+
# adapted from https://github.com/python/cpython/blob/master/Lib/test/test_itertools.py
11+
# Licensed under the Python Software Foundation License Version 2.
12+
# Copyright © 2001-2021 Python Software Foundation. All rights reserved.
13+
# Copyright © 2000 BeOpen.com. All rights reserved.
14+
# Copyright © 1995-2000 Corporation for National Research Initiatives. All rights reserved.
15+
# Copyright © 1991-1995 Stichting Mathematisch Centrum. All rights reserved.
16+
#
17+
918
# stdlib
19+
import pickle
20+
import sys
21+
from itertools import islice
1022
from random import shuffle
1123
from types import GeneratorType
1224
from typing import List, Tuple
@@ -20,6 +32,7 @@
2032
from domdf_python_tools.iterative import (
2133
Len,
2234
chunks,
35+
count,
2336
double_chain,
2437
extend,
2538
extend_with,
@@ -273,3 +286,188 @@ def test_extend_with_none():
273286
def test_extend_with_int():
274287
expects = ('a', 'b', 'c', 'd', 'e', 'f', 'g', 0, 0, 0)
275288
assert tuple(extend_with("abcdefg", 10, 0)) == expects
289+
290+
291+
def lzip(*args):
292+
return list(zip(*args))
293+
294+
295+
def take(n, seq):
296+
"""
297+
Convenience function for partially consuming a long of infinite iterable
298+
"""
299+
300+
return list(islice(seq, n))
301+
302+
303+
def test_count():
304+
assert lzip("abc", count()) == [('a', 0), ('b', 1), ('c', 2)]
305+
assert lzip("abc", count(3)) == [('a', 3), ('b', 4), ('c', 5)]
306+
assert take(2, lzip("abc", count(3))) == [('a', 3), ('b', 4)]
307+
assert take(2, zip("abc", count(-1))) == [('a', -1), ('b', 0)]
308+
assert take(2, zip("abc", count(-3))) == [('a', -3), ('b', -2)]
309+
310+
with pytest.raises(TypeError, match=r"count\(\) takes from 0 to 2 positional arguments but 3 were given"):
311+
count(2, 3, 4) # type: ignore
312+
313+
with pytest.raises(TypeError, match="a number is required"):
314+
count('a') # type: ignore
315+
316+
assert take(10, count(sys.maxsize - 5)) == list(range(sys.maxsize - 5, sys.maxsize + 5))
317+
assert take(10, count(-sys.maxsize - 5)) == list(range(-sys.maxsize - 5, -sys.maxsize + 5))
318+
assert take(3, count(3.25)) == [3.25, 4.25, 5.25]
319+
assert take(3, count(3.25 - 4j)) == [3.25 - 4j, 4.25 - 4j, 5.25 - 4j]
320+
321+
BIGINT = 1 << 1000
322+
assert take(3, count(BIGINT)) == [BIGINT, BIGINT + 1, BIGINT + 2]
323+
324+
c = count(3)
325+
assert repr(c) == "count(3)"
326+
next(c)
327+
assert repr(c) == "count(4)"
328+
c = count(-9)
329+
assert repr(c) == "count(-9)"
330+
next(c)
331+
assert next(c) == -8
332+
333+
assert repr(count(10.25)) == "count(10.25)"
334+
assert repr(count(10.0)) == "count(10.0)"
335+
assert type(next(count(10.0))) == float
336+
337+
for i in (-sys.maxsize - 5, -sys.maxsize + 5, -10, -1, 0, 10, sys.maxsize - 5, sys.maxsize + 5):
338+
# Test repr
339+
r1 = repr(count(i))
340+
r2 = "count(%r)".__mod__(i)
341+
assert r1 == r2
342+
343+
# # check copy, deepcopy, pickle
344+
# for value in -3, 3, sys.maxsize - 5, sys.maxsize + 5:
345+
# c = count(value)
346+
# assert next(copy.copy(c)) == value
347+
# assert next(copy.deepcopy(c)) == value
348+
# for proto in range(pickle.HIGHEST_PROTOCOL + 1):
349+
# pickletest(proto, count(value))
350+
351+
# check proper internal error handling for large "step' sizes
352+
count(1, sys.maxsize + 5)
353+
sys.exc_info()
354+
355+
356+
def test_count_with_stride():
357+
assert lzip("abc", count(2, 3)) == [('a', 2), ('b', 5), ('c', 8)]
358+
assert lzip("abc", count(start=2, step=3)) == [('a', 2), ('b', 5), ('c', 8)]
359+
assert lzip("abc", count(step=-1)) == [('a', 0), ('b', -1), ('c', -2)]
360+
361+
with pytest.raises(TypeError, match="a number is required"):
362+
count('a', 'b') # type: ignore
363+
364+
with pytest.raises(TypeError, match="a number is required"):
365+
count(5, 'b') # type: ignore
366+
367+
assert lzip("abc", count(2, 0)) == [('a', 2), ('b', 2), ('c', 2)]
368+
assert lzip("abc", count(2, 1)) == [('a', 2), ('b', 3), ('c', 4)]
369+
assert lzip("abc", count(2, 3)) == [('a', 2), ('b', 5), ('c', 8)]
370+
assert take(20, count(sys.maxsize - 15, 3)) == take(20, range(sys.maxsize - 15, sys.maxsize + 100, 3))
371+
assert take(20, count(-sys.maxsize - 15, 3)) == take(20, range(-sys.maxsize - 15, -sys.maxsize + 100, 3))
372+
assert take(3, count(10, sys.maxsize + 5)) == list(range(10, 10 + 3 * (sys.maxsize + 5), sys.maxsize + 5))
373+
assert take(3, count(2, 1.25)) == [2, 3.25, 4.5]
374+
assert take(3, count(2, 3.25 - 4j)) == [2, 5.25 - 4j, 8.5 - 8j]
375+
376+
BIGINT = 1 << 1000
377+
assert take(3, count(step=BIGINT)) == [0, BIGINT, 2 * BIGINT]
378+
assert repr(take(3, count(10, 2.5))) == repr([10, 12.5, 15.0])
379+
380+
c = count(3, 5)
381+
assert repr(c) == "count(3, 5)"
382+
next(c)
383+
assert repr(c) == "count(8, 5)"
384+
c = count(-9, 0)
385+
assert repr(c) == "count(-9, 0)"
386+
next(c)
387+
assert repr(c) == "count(-9, 0)"
388+
c = count(-9, -3)
389+
assert repr(c) == "count(-9, -3)"
390+
next(c)
391+
assert repr(c) == "count(-12, -3)"
392+
assert repr(c) == "count(-12, -3)"
393+
394+
assert repr(count(10.5, 1.25)) == "count(10.5, 1.25)"
395+
assert repr(count(10.5, 1)) == "count(10.5)" # suppress step=1 when it's an int
396+
assert repr(count(10.5, 1.00)) == "count(10.5, 1.0)" # do show float values like 1.0
397+
assert repr(count(10, 1.00)) == "count(10, 1.0)"
398+
399+
c = count(10, 1.0)
400+
assert type(next(c)) == int
401+
assert type(next(c)) == float
402+
403+
for i in (-sys.maxsize - 5, -sys.maxsize + 5, -10, -1, 0, 10, sys.maxsize - 5, sys.maxsize + 5):
404+
for j in (-sys.maxsize - 5, -sys.maxsize + 5, -10, -1, 0, 1, 10, sys.maxsize - 5, sys.maxsize + 5):
405+
# Test repr
406+
r1 = repr(count(i, j))
407+
408+
if j == 1:
409+
r2 = ("count(%r)" % i)
410+
else:
411+
r2 = (f'count({i!r}, {j!r})')
412+
assert r1 == r2
413+
414+
# for proto in range(pickle.HIGHEST_PROTOCOL + 1):
415+
# pickletest(proto, count(i, j))
416+
417+
418+
def pickletest(protocol, it, stop=4, take=1, compare=None):
419+
"""
420+
Test that an iterator is the same after pickling, also when part-consumed
421+
"""
422+
423+
def expand(it, i=0):
424+
# Recursively expand iterables, within sensible bounds
425+
if i > 10:
426+
raise RuntimeError("infinite recursion encountered")
427+
if isinstance(it, str):
428+
return it
429+
try:
430+
l = list(islice(it, stop))
431+
except TypeError:
432+
return it # can't expand it
433+
return [expand(e, i + 1) for e in l]
434+
435+
# Test the initial copy against the original
436+
dump = pickle.dumps(it, protocol)
437+
i2 = pickle.loads(dump)
438+
assert type(it) == type(i2)
439+
a, b = expand(it), expand(i2)
440+
assert a == b
441+
if compare:
442+
c = expand(compare)
443+
assert a == c
444+
445+
# Take from the copy, and create another copy and compare them.
446+
i3 = pickle.loads(dump)
447+
took = 0
448+
try:
449+
for i in range(take):
450+
next(i3)
451+
took += 1
452+
except StopIteration:
453+
pass # in case there is less data than 'take'
454+
455+
dump = pickle.dumps(i3, protocol)
456+
i4 = pickle.loads(dump)
457+
a, b = expand(i3), expand(i4)
458+
assert a == b
459+
if compare:
460+
c = expand(compare[took:])
461+
assert a == c
462+
463+
464+
def test_subclassing_count():
465+
CountType = type(count(1))
466+
467+
with pytest.raises(
468+
TypeError,
469+
match="type 'domdf_python_tools.iterative.count' is not an acceptable base type",
470+
):
471+
472+
class MyCount(CountType): # type: ignore
473+
pass

0 commit comments

Comments
 (0)