1313 Iterable ,
1414 Iterator ,
1515 Tuple ,
16+ cast ,
1617 overload ,
1718 AsyncGenerator ,
1819)
1920from collections import deque
2021
21- from ._typing import ACloseable , T , AnyIterable , ADD
22+ from ._typing import ACloseable , R , T , AnyIterable , ADD
2223from ._utility import public_module
2324from ._core import (
2425 ScopedIter ,
3536)
3637
3738S = TypeVar ("S" )
39+ T_co = TypeVar ("T_co" , covariant = True )
3840
3941
4042async def cycle (iterable : AnyIterable [T ]) -> AsyncIterator [T ]:
@@ -542,12 +544,86 @@ async def identity(x: T) -> T:
542544 return x
543545
544546
545- async def groupby (
546- iterable : AnyIterable [Any ],
547- key : Optional [
548- Union [Callable [[Any ], Any ], Callable [[Any ], Awaitable [Any ]]]
549- ] = identity ,
550- ) -> AsyncIterator [Tuple [Any , AsyncIterator [Any ]]]:
547+ class _GroupByState (Generic [R , T_co ]):
548+ """Internal state for the groupby iterator, shared between the parent and groups"""
549+
550+ __slots__ = (
551+ "_iterator" ,
552+ "_key_func" ,
553+ "_current_value" ,
554+ "target_key" ,
555+ "current_key" ,
556+ "current_group" ,
557+ )
558+
559+ _sentinel = cast (T_co , object ())
560+
561+ def __init__ (
562+ self , iterator : AsyncIterator [T_co ], key_func : Callable [[T_co ], Awaitable [R ]]
563+ ):
564+ self ._iterator = iterator
565+ self ._key_func = key_func
566+ self ._current_value = self ._sentinel
567+
568+ async def step (self ) -> None :
569+ # can raise StopAsyncIteration
570+ value = await anext (self ._iterator )
571+ key = await self ._key_func (value )
572+ self ._current_value , self .current_key = value , key
573+
574+ async def maybe_step (self ) -> None :
575+ """Only step if there is no current value"""
576+ if self ._current_value is self ._sentinel :
577+ await self .step ()
578+
579+ def consume_value (self ) -> T_co :
580+ """Return the current value, after removing it from the current state"""
581+ value , self ._current_value = self ._current_value , self ._sentinel
582+ return value
583+
584+ async def aclose (self ) -> None :
585+ """Close the underlying iterator"""
586+ if (group := self .current_group ) is not None :
587+ await group .aclose ()
588+ if isinstance (self ._iterator , ACloseable ):
589+ await self ._iterator .aclose ()
590+
591+
592+ class _Grouper (AsyncIterator [T_co ], Generic [R , T_co ]):
593+ """A single group iterator, part of a series of groups yielded by groupby"""
594+
595+ __slots__ = ("_target_key" , "_state" )
596+
597+ def __init__ (self , target_key : R , state : "_GroupByState[R, T_co]" ) -> None :
598+ self ._target_key = target_key
599+ self ._state = state
600+
601+ async def __anext__ (self ) -> T_co :
602+ state = self ._state
603+ if state .current_group is not self :
604+ raise StopAsyncIteration
605+
606+ await state .maybe_step ()
607+ if self ._target_key != state .current_key :
608+ raise StopAsyncIteration
609+
610+ return state .consume_value ()
611+
612+ async def aclose (self ) -> None :
613+ """Close the group iterator
614+
615+ Note: this does _not_ close the underlying groupby managed iterator;
616+ closing a single group shouldn't affect other groups in the series.
617+
618+ """
619+ state = self ._state
620+ if state .current_group is not self :
621+ return
622+ state .current_group = None
623+
624+
625+ @public_module (__name__ , "groupby" )
626+ class GroupBy (AsyncIterator [Tuple [R , AsyncIterator [T_co ]]], Generic [R , T_co ]):
551627 """
552628 Create an async iterator over consecutive keys and groups from the (async) iterable
553629
@@ -567,49 +643,45 @@ async def groupby(
567643 required up-front for sorting, this loses the advantage of asynchronous,
568644 lazy iteration and evaluation.
569645 """
570- # whether the current group was exhausted and the next begins already
571- exhausted = False
572- # `current_*`: buffer for key/value the current group peeked beyond its end
573- current_key = current_value = nothing = object ()
574- make_key : Callable [[Any ], Awaitable [Any ]] = (
575- _awaitify (key ) if key is not None else identity # type: ignore
576- )
577- async with ScopedIter (iterable ) as async_iter :
578- # fast-forward mode: advance to the next group
579- async def seek_group () -> AsyncIterator [Any ]:
580- nonlocal current_value , current_key , exhausted
581- # Note: `value` always ends up being some T
582- # - value is something: we can never unset it
583- # - value is `nothing`: the previous group was not exhausted,
584- # and we scan at least one new value
585- value : Any = current_value
586- if not exhausted :
587- previous_key = current_key
588- while previous_key == current_key :
589- value = await anext (async_iter )
590- current_key = await make_key (value )
591- current_value = nothing
592- exhausted = False
593- return group (current_key , value = value )
594-
595- # the lazy iterable of all items with the same key
596- async def group (desired_key : Any , value : Any ) -> AsyncIterator [Any ]:
597- nonlocal current_value , current_key , exhausted
598- yield value
599- async for value in async_iter :
600- next_key : Any = await make_key (value )
601- if next_key == desired_key :
602- yield value
603- else :
604- exhausted = True
605- current_value = value
606- current_key = next_key
607- break
608646
647+ __slots__ = ("_state" ,)
648+
649+ def __init__ (
650+ self ,
651+ iterable : AnyIterable [T_co ],
652+ key : Optional [
653+ Union [Callable [[T_co ], R ], Callable [[T_co ], Awaitable [R ]]]
654+ ] = None ,
655+ ):
656+ key_func = (
657+ cast (Callable [[T_co ], Awaitable [R ]], identity )
658+ if key is None
659+ else _awaitify (key )
660+ )
661+ self ._state = _GroupByState (aiter (iterable ), key_func )
662+
663+ async def __anext__ (self ) -> Tuple [R , AsyncIterator [T_co ]]:
664+ state = self ._state
665+ # disable the last group to avoid concurrency
666+ # issues.
667+ state .current_group = None
668+ await state .maybe_step ()
609669 try :
610- while True :
611- next_group = await seek_group ()
612- async with ScopedIter (next_group ) as scoped_group :
613- yield current_key , scoped_group
614- except StopAsyncIteration :
615- return
670+ target_key = state .target_key
671+ except AttributeError :
672+ # no target key yet, skip scanning
673+ pass
674+ else :
675+ # scan to the next group
676+ while state .current_key == target_key :
677+ await state .step ()
678+
679+ state .target_key = current_key = state .current_key
680+ state .current_group = group = _Grouper (current_key , state )
681+ return (current_key , group )
682+
683+ async def aclose (self ) -> None :
684+ await self ._state .aclose ()
685+
686+
687+ groupby = GroupBy
0 commit comments