1616from __future__ import annotations
1717
1818import asyncio
19+ import logging
1920import time
2021from collections .abc import Callable , Coroutine , Sequence
22+ from functools import partial
2123
2224from . import _typings as t
2325
26+ log = logging .getLogger (__name__ )
27+
2428__all__ = ("Waterfall" ,)
2529
2630type AnyCoro = Coroutine [t .Any , t .Any , t .Any ]
@@ -111,11 +115,22 @@ def put(self, item: T) -> None:
111115 raise RuntimeError (msg )
112116 self .queue .put_nowait (item )
113117
118+ def _user_done_callback (self , num : int , future : asyncio .Future [t .Any ]):
119+ if future .cancelled ():
120+ log .warning ("Callback cancelled due to timeout" )
121+ elif exc := future .exception ():
122+ log .error ("Exception in user callback" , exc_info = exc )
123+
124+ for _ in range (num ):
125+ self .queue .task_done ()
126+
114127 async def _dispatch_loop (self ) -> None :
115128 if (loop := self ._event_loop ) is None :
116129 loop = self ._event_loop = asyncio .get_running_loop ()
130+
131+ tasks : set [asyncio .Task [object ]] = set ()
117132 try :
118- tasks : set [ asyncio . Task [ object ]] = set ()
133+ tasks = set ()
119134 while self ._alive :
120135 queue_items : list [T ] = []
121136 iter_start = time .monotonic ()
@@ -127,20 +142,22 @@ async def _dispatch_loop(self) -> None:
127142 continue
128143 else :
129144 queue_items .append (n )
130- if len (queue_items ) >= self .max_quantity :
131- break
132145
133- if not queue_items :
134- continue
146+ if len ( queue_items ) >= self . max_quantity :
147+ break
135148
136- num_items = len (queue_items )
149+ if not queue_items :
150+ continue
137151
152+ # get len before callback may mutate list
153+ num_items = len (queue_items )
138154 t = loop .create_task (self .callback (queue_items ))
155+ del queue_items
156+
139157 tasks .add (t )
140158 t .add_done_callback (tasks .discard )
141-
142- for _ in range (num_items ):
143- self .queue .task_done ()
159+ cb = partial (self ._user_done_callback , num_items )
160+ t .add_done_callback (cb )
144161
145162 finally :
146163 f = loop .create_task (self ._finalize ())
@@ -151,7 +168,15 @@ async def _dispatch_loop(self) -> None:
151168 # PYUPDATE: remove this block at python 3.13 minimum
152169 else :
153170 set_name ("waterfall.finalizer" )
154- await asyncio .wait_for (f , timeout = self .max_wait_finalize )
171+ g = asyncio .gather (f , * tasks , return_exceptions = True )
172+ try :
173+ await asyncio .wait_for (g , timeout = self .max_wait_finalize )
174+ except TimeoutError :
175+ # GatheringFuture.cancel doesnt work here
176+ # due to return_exceptions=True
177+ for t in (f , * tasks ):
178+ if not t .done ():
179+ t .cancel ()
155180
156181 async def _finalize (self ) -> None :
157182 loop = self ._event_loop
@@ -187,15 +212,15 @@ async def _finalize(self) -> None:
187212 remaining_items [p : p + self .max_quantity ]
188213 for p in range (0 , num_remaining , self .max_quantity )
189214 ):
215+ chunk_len = len (chunk )
190216 fut = loop .create_task (self .callback (chunk ))
191- fut .add_done_callback (remaining_tasks .discard )
192217 remaining_tasks .add (fut )
218+ fut .add_done_callback (remaining_tasks .discard )
219+ cb = partial (self ._user_done_callback , chunk_len )
220+ fut .add_done_callback (cb )
193221
194222 timeout = self .max_wait_finalize
195223 _done , pending = await asyncio .wait (remaining_tasks , timeout = timeout )
196224
197225 for task in pending :
198226 task .cancel ()
199-
200- for _ in range (num_remaining ):
201- self .queue .task_done ()
0 commit comments