33from __future__ import annotations
44
55import asyncio
6+ import contextlib
67from collections .abc import Callable , Coroutine
78from typing import TYPE_CHECKING
89
@@ -23,11 +24,15 @@ def set_result(self, exec_id: str) -> None:
2324 if not self ._future .done ():
2425 self ._future .set_result (exec_id )
2526
26- def set_exception (self , exception : Exception ) -> None :
27+ def set_exception (self , exception : BaseException ) -> None :
2728 """Set an exception if the batch execution failed."""
2829 if not self ._future .done ():
2930 self ._future .set_exception (exception )
3031
32+ def is_done (self ) -> bool :
33+ """Check if the execution has completed (either with result or exception)."""
34+ return self ._future .done ()
35+
3136 def __await__ (self ):
3237 """Make this awaitable."""
3338 return self ._future .__await__ ()
@@ -84,64 +89,92 @@ async def add(
8489 :param label: Label for the action group
8590 :return: QueuedExecution that resolves to exec_id when batch executes
8691 """
92+ batch_to_execute = None
93+
8794 async with self ._lock :
8895 # If mode or label changes, flush existing queue first
8996 if self ._pending_actions and (
9097 mode != self ._pending_mode or label != self ._pending_label
9198 ):
92- await self ._flush_now ()
99+ batch_to_execute = self ._prepare_flush ()
93100
94101 # Add actions to pending queue
95102 self ._pending_actions .extend (actions )
96103 self ._pending_mode = mode
97104 self ._pending_label = label
98105
99- # Create waiter for this caller
106+ # Create waiter for this caller. This waiter is added to the current
107+ # batch being built, even if we flushed a previous batch above due to
108+ # a mode/label change. This ensures the waiter belongs to the batch
109+ # containing the actions we just added.
100110 waiter = QueuedExecution ()
101111 self ._pending_waiters .append (waiter )
102112
103113 # If we hit max actions, flush immediately
104114 if len (self ._pending_actions ) >= self ._max_actions :
105- await self ._flush_now ()
106- else :
115+ # Prepare the current batch for flushing (which includes the actions
116+ # we just added). If we already flushed due to mode change, this is
117+ # a second batch.
118+ new_batch = self ._prepare_flush ()
119+ # Execute the first batch if it exists, then the second
120+ if batch_to_execute :
121+ await self ._execute_batch (* batch_to_execute )
122+ batch_to_execute = new_batch
123+ elif self ._flush_task is None or self ._flush_task .done ():
107124 # Schedule delayed flush if not already scheduled
108- if self ._flush_task is None or self ._flush_task .done ():
109- self ._flush_task = asyncio .create_task (self ._delayed_flush ())
125+ self ._flush_task = asyncio .create_task (self ._delayed_flush ())
126+
127+ # Execute batch outside the lock if we flushed
128+ if batch_to_execute :
129+ await self ._execute_batch (* batch_to_execute )
110130
111- return waiter
131+ return waiter
112132
113133 async def _delayed_flush (self ) -> None :
114134 """Wait for the delay period, then flush the queue."""
115- await asyncio .sleep (self ._delay )
116- async with self ._lock :
117- if not self ._pending_actions :
118- return
119-
120- # Take snapshot and clear state while holding lock
121- actions = self ._pending_actions
122- mode = self ._pending_mode
123- label = self ._pending_label
124- waiters = self ._pending_waiters
125-
126- self ._pending_actions = []
127- self ._pending_mode = None
128- self ._pending_label = None
129- self ._pending_waiters = []
130- self ._flush_task = None
131-
132- # Execute outside the lock
135+ waiters : list [QueuedExecution ] = []
133136 try :
134- exec_id = await self ._executor (actions , mode , label )
135- for waiter in waiters :
136- waiter .set_result (exec_id )
137- except Exception as exc :
137+ await asyncio .sleep (self ._delay )
138+ async with self ._lock :
139+ if not self ._pending_actions :
140+ return
141+
142+ # Take snapshot and clear state while holding lock
143+ actions = self ._pending_actions
144+ mode = self ._pending_mode
145+ label = self ._pending_label
146+ waiters = self ._pending_waiters
147+
148+ self ._pending_actions = []
149+ self ._pending_mode = None
150+ self ._pending_label = None
151+ self ._pending_waiters = []
152+ self ._flush_task = None
153+
154+ # Execute outside the lock
155+ try :
156+ exec_id = await self ._executor (actions , mode , label )
157+ for waiter in waiters :
158+ waiter .set_result (exec_id )
159+ except Exception as exc :
160+ for waiter in waiters :
161+ waiter .set_exception (exc )
162+ except asyncio .CancelledError as exc :
163+ # Ensure all waiters are notified if this task is cancelled
138164 for waiter in waiters :
139165 waiter .set_exception (exc )
166+ raise
167+
168+ def _prepare_flush (
169+ self ,
170+ ) -> tuple [list [Action ], CommandMode | None , str | None , list [QueuedExecution ]]:
171+ """Prepare a flush by taking snapshot and clearing state (must be called with lock held).
140172
141- async def _flush_now (self ) -> None :
142- """Execute pending actions immediately (must be called with lock held)."""
173+ Returns a tuple of (actions, mode, label, waiters) that should be executed
174+ outside the lock using _execute_batch().
175+ """
143176 if not self ._pending_actions :
144- return
177+ return ([], None , None , [])
145178
146179 # Cancel any pending flush task
147180 if self ._flush_task and not self ._flush_task .done ():
@@ -160,8 +193,19 @@ async def _flush_now(self) -> None:
160193 self ._pending_label = None
161194 self ._pending_waiters = []
162195
163- # Execute the batch (must release lock before calling executor to avoid deadlock)
164- # Note: This is called within a lock context, we'll execute outside
196+ return (actions , mode , label , waiters )
197+
198+ async def _execute_batch (
199+ self ,
200+ actions : list [Action ],
201+ mode : CommandMode | None ,
202+ label : str | None ,
203+ waiters : list [QueuedExecution ],
204+ ) -> None :
205+ """Execute a batch of actions and notify waiters (must be called without lock)."""
206+ if not actions :
207+ return
208+
165209 try :
166210 exec_id = await self ._executor (actions , mode , label )
167211 # Notify all waiters
@@ -173,42 +217,49 @@ async def _flush_now(self) -> None:
173217 waiter .set_exception (exc )
174218 raise
175219
176- async def flush (self ) -> list [ str ] :
220+ async def flush (self ) -> None :
177221 """Force flush all pending actions immediately.
178222
179- :return: List of exec_ids from flushed batches
223+ This method forces the queue to execute any pending batched actions
224+ without waiting for the delay timer. The execution results are delivered
225+ to the corresponding QueuedExecution objects returned by add().
226+
227+ This method is useful for forcing immediate execution without having to
228+ wait for the delay timer to expire.
180229 """
230+ batch_to_execute = None
181231 async with self ._lock :
182- if not self ._pending_actions :
183- return []
184-
185- # Since we can only have one batch pending at a time,
186- # this will return a single exec_id in a list
187- exec_ids : list [str ] = []
232+ if self ._pending_actions :
233+ batch_to_execute = self ._prepare_flush ()
188234
189- try :
190- await self ._flush_now ()
191- # If flush succeeded, we can't actually return the exec_id here
192- # since it's delivered via the waiters. This method is mainly
193- # for forcing a flush, not retrieving results.
194- # Return empty list to indicate flush completed
195- except Exception :
196- # If flush fails, the exception will be propagated to waiters
197- # and also raised here
198- raise
199-
200- return exec_ids
235+ # Execute outside the lock
236+ if batch_to_execute :
237+ await self ._execute_batch (* batch_to_execute )
201238
202239 def get_pending_count (self ) -> int :
203- """Get the number of actions currently waiting in the queue."""
240+ """Get the (approximate) number of actions currently waiting in the queue.
241+
242+ This method does not acquire the internal lock and therefore returns a
243+ best-effort snapshot that may be slightly out of date if the queue is
244+ being modified concurrently by other coroutines.
245+ """
204246 return len (self ._pending_actions )
205247
206248 async def shutdown (self ) -> None :
207249 """Shutdown the queue, flushing any pending actions."""
250+ batch_to_execute = None
208251 async with self ._lock :
209252 if self ._flush_task and not self ._flush_task .done ():
210- self ._flush_task .cancel ()
253+ task = self ._flush_task
254+ task .cancel ()
211255 self ._flush_task = None
256+ # Wait for cancellation to complete
257+ with contextlib .suppress (asyncio .CancelledError ):
258+ await task
212259
213260 if self ._pending_actions :
214- await self ._flush_now ()
261+ batch_to_execute = self ._prepare_flush ()
262+
263+ # Execute outside the lock
264+ if batch_to_execute :
265+ await self ._execute_batch (* batch_to_execute )
0 commit comments