Skip to content

Commit ef6792a

Browse files
committed
Improve waterfall behavior
- no more empty chunk callbacks
- logging of errored and cancelled callbacks
1 parent a6c2d70 commit ef6792a

File tree

2 files changed

+40
-15
lines changed

2 files changed

+40
-15
lines changed

src/async_utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__author__ = "Michael Hall"
1010
__license__ = "Apache-2.0"
1111
__copyright__ = "Copyright 2020-Present Michael Hall"
12-
__version__ = "2025.02.18b"
12+
__version__ = "2025.03.08b"
1313

1414
import os
1515
import sys

src/async_utils/waterfall.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,15 @@
1616
from __future__ import annotations
1717

1818
import asyncio
19+
import logging
1920
import time
2021
from collections.abc import Callable, Coroutine, Sequence
22+
from functools import partial
2123

2224
from . import _typings as t
2325

26+
log = logging.getLogger(__name__)
27+
2428
__all__ = ("Waterfall",)
2529

2630
type AnyCoro = Coroutine[t.Any, t.Any, t.Any]
@@ -111,11 +115,22 @@ def put(self, item: T) -> None:
111115
raise RuntimeError(msg)
112116
self.queue.put_nowait(item)
113117

118+
def _user_done_callback(
    self, num: int, future: asyncio.Future[t.Any]
) -> None:
    """Log the outcome of a user callback task and settle its queue items.

    Intended to be attached (via ``functools.partial`` binding ``num``) as a
    done-callback on the task that runs the user callback for one chunk.

    Parameters
    ----------
    num:
        Number of queue items that were handed to the callback; the queue's
        ``task_done`` is called this many times.
    future:
        The finished task/future wrapping the user callback.
    """
    # Check cancellation first: Future.exception() raises CancelledError
    # when called on a cancelled future.
    if future.cancelled():
        log.warning("Callback cancelled due to timeout")
    # Compare against None explicitly so an exception object that happens
    # to be falsy (e.g. defines __bool__/__len__) is still logged.
    elif (exc := future.exception()) is not None:
        log.error("Exception in user callback", exc_info=exc)

    # Items are marked done regardless of how the callback finished, so
    # anything waiting on the queue is not blocked by a failed chunk.
    for _ in range(num):
        self.queue.task_done()
126+
114127
async def _dispatch_loop(self) -> None:
115128
if (loop := self._event_loop) is None:
116129
loop = self._event_loop = asyncio.get_running_loop()
130+
131+
tasks: set[asyncio.Task[object]] = set()
117132
try:
118-
tasks: set[asyncio.Task[object]] = set()
133+
tasks = set()
119134
while self._alive:
120135
queue_items: list[T] = []
121136
iter_start = time.monotonic()
@@ -127,20 +142,22 @@ async def _dispatch_loop(self) -> None:
127142
continue
128143
else:
129144
queue_items.append(n)
130-
if len(queue_items) >= self.max_quantity:
131-
break
132145

133-
if not queue_items:
134-
continue
146+
if len(queue_items) >= self.max_quantity:
147+
break
135148

136-
num_items = len(queue_items)
149+
if not queue_items:
150+
continue
137151

152+
# get len before callback may mutate list
153+
num_items = len(queue_items)
138154
t = loop.create_task(self.callback(queue_items))
155+
del queue_items
156+
139157
tasks.add(t)
140158
t.add_done_callback(tasks.discard)
141-
142-
for _ in range(num_items):
143-
self.queue.task_done()
159+
cb = partial(self._user_done_callback, num_items)
160+
t.add_done_callback(cb)
144161

145162
finally:
146163
f = loop.create_task(self._finalize())
@@ -151,7 +168,15 @@ async def _dispatch_loop(self) -> None:
151168
# PYUPDATE: remove this block at python 3.13 minimum
152169
else:
153170
set_name("waterfall.finalizer")
154-
await asyncio.wait_for(f, timeout=self.max_wait_finalize)
171+
g = asyncio.gather(f, *tasks, return_exceptions=True)
172+
try:
173+
await asyncio.wait_for(g, timeout=self.max_wait_finalize)
174+
except TimeoutError:
175+
# GatheringFuture.cancel doesn't work here
176+
# due to return_exceptions=True
177+
for t in (f, *tasks):
178+
if not t.done():
179+
t.cancel()
155180

156181
async def _finalize(self) -> None:
157182
loop = self._event_loop
@@ -187,15 +212,15 @@ async def _finalize(self) -> None:
187212
remaining_items[p : p + self.max_quantity]
188213
for p in range(0, num_remaining, self.max_quantity)
189214
):
215+
chunk_len = len(chunk)
190216
fut = loop.create_task(self.callback(chunk))
191-
fut.add_done_callback(remaining_tasks.discard)
192217
remaining_tasks.add(fut)
218+
fut.add_done_callback(remaining_tasks.discard)
219+
cb = partial(self._user_done_callback, chunk_len)
220+
fut.add_done_callback(cb)
193221

194222
timeout = self.max_wait_finalize
195223
_done, pending = await asyncio.wait(remaining_tasks, timeout=timeout)
196224

197225
for task in pending:
198226
task.cancel()
199-
200-
for _ in range(num_remaining):
201-
self.queue.task_done()

0 commit comments

Comments
 (0)