diff --git a/asyncstdlib/itertools.py b/asyncstdlib/itertools.py index 006da37..febd591 100644 --- a/asyncstdlib/itertools.py +++ b/asyncstdlib/itertools.py @@ -24,7 +24,6 @@ from ._core import ( ScopedIter, awaitify as _awaitify, - Sentinel, borrow as _borrow, ) from .builtins import ( @@ -64,9 +63,6 @@ async def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]: yield item -__ACCUMULATE_SENTINEL = Sentinel("") - - async def add(x: ADD, y: ADD) -> ADD: """The default reduction of :py:func:`~.accumulate`""" return x + y @@ -78,7 +74,7 @@ async def accumulate( Callable[[Any, Any], Any], Callable[[Any, Any], Awaitable[Any]] ] = add, *, - initial: Any = __ACCUMULATE_SENTINEL, + initial: Any = None, ) -> AsyncIterator[Any]: """ An :term:`asynchronous iterator` on the running reduction of ``iterable`` @@ -105,11 +101,7 @@ async def accumulate(iterable, function, *, initial): """ async with ScopedIter(iterable) as item_iter: try: - value = ( - initial - if initial is not __ACCUMULATE_SENTINEL - else await anext(item_iter) - ) + value = initial if initial is not None else await anext(item_iter) except StopAsyncIteration: raise TypeError( "accumulate() of empty sequence with no initial value" diff --git a/asyncstdlib/itertools.pyi b/asyncstdlib/itertools.pyi index f65ff6d..e561d57 100644 --- a/asyncstdlib/itertools.pyi +++ b/asyncstdlib/itertools.pyi @@ -16,13 +16,17 @@ from ._typing import AnyIterable, ADD, T, T1, T2, T3, T4, T5 def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]: ... @overload -def accumulate(iterable: AnyIterable[ADD]) -> AsyncIterator[ADD]: ... +def accumulate( + iterable: AnyIterable[ADD], *, initial: None = ... +) -> AsyncIterator[ADD]: ... @overload def accumulate(iterable: AnyIterable[ADD], *, initial: ADD) -> AsyncIterator[ADD]: ... @overload def accumulate( iterable: AnyIterable[T], function: Callable[[T, T], T] | Callable[[T, T], Awaitable[T]], + *, + initial: None = ..., ) -> AsyncIterator[T]: ... @overload def accumulate( diff --git a/docs/source/api/itertools.rst b/docs/source/api/itertools.rst index 733332f..c61b750 100644 --- a/docs/source/api/itertools.rst +++ b/docs/source/api/itertools.rst @@ -68,6 +68,10 @@ Iterator transforming .. autofunction:: accumulate(iterable: (async) iter T, function: (T, T) → (await) T = add [, initial: T]) :async-for: :T + .. versionchanged:: 3.13.2 + + ``initial=None`` means no initial value is assumed. + .. autofunction:: starmap(function: (*A) → (await) T, iterable: (async) iter (A, ...)) :async-for: :T diff --git a/unittests/test_itertools.py b/unittests/test_itertools.py index 5f88e96..82e4e7a 100644 --- a/unittests/test_itertools.py +++ b/unittests/test_itertools.py @@ -34,6 +34,7 @@ async def reduction(x, y): @sync async def test_accumulate_default(): + """Test the default function of accumulate""" for itertype in (asyncify, list): assert await a.list(a.accumulate(itertype([0, 1]))) == list( itertools.accumulate([0, 1]) @@ -53,10 +54,21 @@ async def test_accumulate_default(): @sync async def test_accumulate_misuse(): + """Test wrong arguments to accumulate""" with pytest.raises(TypeError): assert await a.list(a.accumulate([])) +@sync +async def test_accumulate_initial(): + """Test the `initial` argument to accumulate""" + assert ( + await a.list(a.accumulate(asyncify([1, 2, 3]), initial=None)) + == await a.list(a.accumulate(asyncify([1, 2, 3]))) + == list(itertools.accumulate([1, 2, 3], initial=None)) + ) + + batched_cases = [ (range(10), 2, [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]), (range(10), 3, [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]),