Skip to content

Commit cb745e8

Browse files
Merge pull request #155 from mjpieters/refacored_groupby
Refactor groupby to use classes
2 parents ce2d0d4 + 5d23d88 commit cb745e8

File tree

2 files changed

+130
-56
lines changed

2 files changed

+130
-56
lines changed

asyncstdlib/itertools.py

Lines changed: 123 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
Iterable,
1414
Iterator,
1515
Tuple,
16+
cast,
1617
overload,
1718
AsyncGenerator,
1819
)
1920
from collections import deque
2021

21-
from ._typing import ACloseable, T, AnyIterable, ADD
22+
from ._typing import ACloseable, R, T, AnyIterable, ADD
2223
from ._utility import public_module
2324
from ._core import (
2425
ScopedIter,
@@ -35,6 +36,7 @@
3536
)
3637

3738
S = TypeVar("S")
39+
T_co = TypeVar("T_co", covariant=True)
3840

3941

4042
async def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]:
@@ -542,12 +544,86 @@ async def identity(x: T) -> T:
542544
return x
543545

544546

545-
async def groupby(
546-
iterable: AnyIterable[Any],
547-
key: Optional[
548-
Union[Callable[[Any], Any], Callable[[Any], Awaitable[Any]]]
549-
] = identity,
550-
) -> AsyncIterator[Tuple[Any, AsyncIterator[Any]]]:
547+
class _GroupByState(Generic[R, T_co]):
548+
"""Internal state for the groupby iterator, shared between the parent and groups"""
549+
550+
__slots__ = (
551+
"_iterator",
552+
"_key_func",
553+
"_current_value",
554+
"target_key",
555+
"current_key",
556+
"current_group",
557+
)
558+
559+
_sentinel = cast(T_co, object())
560+
561+
def __init__(
562+
self, iterator: AsyncIterator[T_co], key_func: Callable[[T_co], Awaitable[R]]
563+
):
564+
self._iterator = iterator
565+
self._key_func = key_func
566+
self._current_value = self._sentinel
567+
568+
async def step(self) -> None:
569+
# can raise StopAsyncIteration
570+
value = await anext(self._iterator)
571+
key = await self._key_func(value)
572+
self._current_value, self.current_key = value, key
573+
574+
async def maybe_step(self) -> None:
575+
"""Only step if there is no current value"""
576+
if self._current_value is self._sentinel:
577+
await self.step()
578+
579+
def consume_value(self) -> T_co:
580+
"""Return the current value, after removing it from the current state"""
581+
value, self._current_value = self._current_value, self._sentinel
582+
return value
583+
584+
async def aclose(self) -> None:
585+
"""Close the underlying iterator"""
586+
if (group := self.current_group) is not None:
587+
await group.aclose()
588+
if isinstance(self._iterator, ACloseable):
589+
await self._iterator.aclose()
590+
591+
592+
class _Grouper(AsyncIterator[T_co], Generic[R, T_co]):
593+
"""A single group iterator, part of a series of groups yielded by groupby"""
594+
595+
__slots__ = ("_target_key", "_state")
596+
597+
def __init__(self, target_key: R, state: "_GroupByState[R, T_co]") -> None:
598+
self._target_key = target_key
599+
self._state = state
600+
601+
async def __anext__(self) -> T_co:
602+
state = self._state
603+
if state.current_group is not self:
604+
raise StopAsyncIteration
605+
606+
await state.maybe_step()
607+
if self._target_key != state.current_key:
608+
raise StopAsyncIteration
609+
610+
return state.consume_value()
611+
612+
async def aclose(self) -> None:
613+
"""Close the group iterator
614+
615+
Note: this does _not_ close the underlying groupby managed iterator;
616+
closing a single group shouldn't affect other groups in the series.
617+
618+
"""
619+
state = self._state
620+
if state.current_group is not self:
621+
return
622+
state.current_group = None
623+
624+
625+
@public_module(__name__, "groupby")
626+
class GroupBy(AsyncIterator[Tuple[R, AsyncIterator[T_co]]], Generic[R, T_co]):
551627
"""
552628
Create an async iterator over consecutive keys and groups from the (async) iterable
553629
@@ -567,49 +643,45 @@ async def groupby(
567643
required up-front for sorting, this loses the advantage of asynchronous,
568644
lazy iteration and evaluation.
569645
"""
570-
# whether the current group was exhausted and the next begins already
571-
exhausted = False
572-
# `current_*`: buffer for key/value the current group peeked beyond its end
573-
current_key = current_value = nothing = object()
574-
make_key: Callable[[Any], Awaitable[Any]] = (
575-
_awaitify(key) if key is not None else identity # type: ignore
576-
)
577-
async with ScopedIter(iterable) as async_iter:
578-
# fast-forward mode: advance to the next group
579-
async def seek_group() -> AsyncIterator[Any]:
580-
nonlocal current_value, current_key, exhausted
581-
# Note: `value` always ends up being some T
582-
# - value is something: we can never unset it
583-
# - value is `nothing`: the previous group was not exhausted,
584-
# and we scan at least one new value
585-
value: Any = current_value
586-
if not exhausted:
587-
previous_key = current_key
588-
while previous_key == current_key:
589-
value = await anext(async_iter)
590-
current_key = await make_key(value)
591-
current_value = nothing
592-
exhausted = False
593-
return group(current_key, value=value)
594-
595-
# the lazy iterable of all items with the same key
596-
async def group(desired_key: Any, value: Any) -> AsyncIterator[Any]:
597-
nonlocal current_value, current_key, exhausted
598-
yield value
599-
async for value in async_iter:
600-
next_key: Any = await make_key(value)
601-
if next_key == desired_key:
602-
yield value
603-
else:
604-
exhausted = True
605-
current_value = value
606-
current_key = next_key
607-
break
608646

647+
__slots__ = ("_state",)
648+
649+
def __init__(
650+
self,
651+
iterable: AnyIterable[T_co],
652+
key: Optional[
653+
Union[Callable[[T_co], R], Callable[[T_co], Awaitable[R]]]
654+
] = None,
655+
):
656+
key_func = (
657+
cast(Callable[[T_co], Awaitable[R]], identity)
658+
if key is None
659+
else _awaitify(key)
660+
)
661+
self._state = _GroupByState(aiter(iterable), key_func)
662+
663+
async def __anext__(self) -> Tuple[R, AsyncIterator[T_co]]:
664+
state = self._state
665+
# disable the last group to avoid concurrency
666+
# issues.
667+
state.current_group = None
668+
await state.maybe_step()
609669
try:
610-
while True:
611-
next_group = await seek_group()
612-
async with ScopedIter(next_group) as scoped_group:
613-
yield current_key, scoped_group
614-
except StopAsyncIteration:
615-
return
670+
target_key = state.target_key
671+
except AttributeError:
672+
# no target key yet, skip scanning
673+
pass
674+
else:
675+
# scan to the next group
676+
while state.current_key == target_key:
677+
await state.step()
678+
679+
state.target_key = current_key = state.current_key
680+
state.current_group = group = _Grouper(current_key, state)
681+
return (current_key, group)
682+
683+
async def aclose(self) -> None:
684+
await self._state.aclose()
685+
686+
687+
groupby = GroupBy

asyncstdlib/itertools.pyi

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -223,13 +223,15 @@ def zip_longest(
223223
fillvalue: F,
224224
) -> AsyncIterator[tuple[T | F, ...]]: ...
225225

226-
K = TypeVar("K")
226+
K_co = TypeVar("K_co", covariant=True)
227+
T_co = TypeVar("T_co", covariant=True)
227228

228229
@overload
229230
def groupby(
230-
iterable: AnyIterable[T], key: None = ...
231-
) -> AsyncIterator[tuple[T, AsyncIterator[T]]]: ...
231+
iterable: AnyIterable[T_co], key: None = ...
232+
) -> AsyncIterator[tuple[T_co, AsyncIterator[T_co]]]: ...
232233
@overload
233234
def groupby(
234-
iterable: AnyIterable[T], key: Callable[[T], Awaitable[K]] | Callable[[T], K]
235-
) -> AsyncIterator[tuple[K, AsyncIterator[T]]]: ...
235+
iterable: AnyIterable[T_co],
236+
key: Callable[[T_co], Awaitable[K_co]] | Callable[[T], K_co],
237+
) -> AsyncIterator[tuple[K_co, AsyncIterator[T_co]]]: ...

0 commit comments

Comments
 (0)