Skip to content

Commit 144837c

Browse files
committed
support for async iterables in coro_fns
1 parent 8ac7613 commit 144837c

File tree

2 files changed

+215
-5
lines changed

2 files changed

+215
-5
lines changed

Lib/asyncio/staggered.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ async def staggered_race(coro_fns, delay, *, loop=None):
3838
Args:
3939
coro_fns: an iterable of coroutine functions, i.e. callables that
4040
return a coroutine object when called. Use ``functools.partial`` or
41-
lambdas to pass arguments.
41+
lambdas to pass arguments. Can also be an async iterable.
4242
4343
delay: amount of time, in seconds, between starting coroutines. If
4444
``None``, the coroutines will run sequentially.
@@ -62,10 +62,19 @@ async def staggered_race(coro_fns, delay, *, loop=None):
6262
coroutine's entry is ``None``.
6363
6464
"""
65-
# TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
65+
# Support for async iterables in coro_fns
66+
try:
67+
# Try to get an async iterator
68+
aiter_coro_fns = aiter(coro_fns)
69+
is_async_iterable = True
70+
enum_coro_fns = None # Not used for async iterables
71+
except TypeError:
72+
# Not an async iterable, use regular iteration
73+
enum_coro_fns = enumerate(coro_fns)
74+
is_async_iterable = False
75+
aiter_coro_fns = None # Not used for regular iterables
6676
loop = loop or events.get_running_loop()
6777
parent_task = tasks.current_task(loop)
68-
enum_coro_fns = enumerate(coro_fns)
6978
winner_result = None
7079
winner_index = None
7180
unhandled_exceptions = []
@@ -106,8 +115,12 @@ async def run_one_coro(ok_to_start, previous_failed) -> None:
106115
await tasks.wait_for(previous_failed.wait(), delay)
107116
# Get the next coroutine to run
108117
try:
109-
this_index, coro_fn = next(enum_coro_fns)
110-
except StopIteration:
118+
if is_async_iterable:
119+
coro_fn = await anext(aiter_coro_fns)
120+
this_index = len(exceptions) # Track index manually for async iterables
121+
else:
122+
this_index, coro_fn = next(enum_coro_fns)
123+
except (StopIteration, StopAsyncIteration):
111124
return
112125
# Start task that will run the next coroutine
113126
this_failed = locks.Event()

Lib/test/test_asyncio/test_staggered.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,200 @@ async def coro_fn():
149149
raise
150150

151151
self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])
152+
153+
async def test_async_iterable_empty(self):
154+
async def empty_async_iterable():
155+
if False:
156+
yield lambda: asyncio.sleep(0)
157+
158+
winner, index, excs = await staggered_race(
159+
empty_async_iterable(),
160+
delay=None,
161+
)
162+
163+
self.assertIs(winner, None)
164+
self.assertIs(index, None)
165+
self.assertEqual(excs, [])
166+
167+
async def test_async_iterable_one_successful(self):
168+
async def async_coro_generator():
169+
async def coro(index):
170+
return f'Async Res: {index}'
171+
172+
yield lambda: coro(0)
173+
yield lambda: coro(1)
174+
175+
winner, index, excs = await staggered_race(
176+
async_coro_generator(),
177+
delay=None,
178+
)
179+
180+
self.assertEqual(winner, 'Async Res: 0')
181+
self.assertEqual(index, 0)
182+
self.assertEqual(excs, [None])
183+
184+
async def test_async_iterable_first_error_second_successful(self):
185+
async def async_coro_generator():
186+
async def coro(index):
187+
if index == 0:
188+
raise ValueError(f'Async Error: {index}')
189+
return f'Async Res: {index}'
190+
191+
yield lambda: coro(0)
192+
yield lambda: coro(1)
193+
194+
winner, index, excs = await staggered_race(
195+
async_coro_generator(),
196+
delay=None,
197+
)
198+
199+
self.assertEqual(winner, 'Async Res: 1')
200+
self.assertEqual(index, 1)
201+
self.assertEqual(len(excs), 2)
202+
self.assertIsInstance(excs[0], ValueError)
203+
self.assertEqual(str(excs[0]), 'Async Error: 0')
204+
self.assertIs(excs[1], None)
205+
206+
async def test_async_iterable_first_timeout_second_successful(self):
207+
async def async_coro_generator():
208+
async def coro(index):
209+
if index == 0:
210+
await asyncio.sleep(10)
211+
return f'Async Res: {index}'
212+
213+
yield lambda: coro(0)
214+
yield lambda: coro(1)
215+
216+
winner, index, excs = await staggered_race(
217+
async_coro_generator(),
218+
delay=0.1,
219+
)
220+
221+
self.assertEqual(winner, 'Async Res: 1')
222+
self.assertEqual(index, 1)
223+
self.assertEqual(len(excs), 2)
224+
self.assertIsInstance(excs[0], asyncio.CancelledError)
225+
self.assertIs(excs[1], None)
226+
227+
async def test_async_iterable_none_successful(self):
228+
async def async_coro_generator():
229+
async def coro(index):
230+
raise ValueError(f'Async Error: {index}')
231+
232+
yield lambda: coro(0)
233+
yield lambda: coro(1)
234+
235+
winner, index, excs = await staggered_race(
236+
async_coro_generator(),
237+
delay=None,
238+
)
239+
240+
self.assertIs(winner, None)
241+
self.assertIs(index, None)
242+
self.assertEqual(len(excs), 2)
243+
self.assertIsInstance(excs[0], ValueError)
244+
self.assertEqual(str(excs[0]), 'Async Error: 0')
245+
self.assertIsInstance(excs[1], ValueError)
246+
self.assertEqual(str(excs[1]), 'Async Error: 1')
247+
248+
async def test_async_iterable_multiple_winners(self):
249+
event = asyncio.Event()
250+
251+
async def async_coro_generator():
252+
async def coro(index):
253+
await event.wait()
254+
return f'Async Index: {index}'
255+
256+
async def do_set():
257+
event.set()
258+
await asyncio.Event().wait()
259+
260+
yield lambda: coro(0)
261+
yield lambda: coro(1)
262+
yield do_set
263+
264+
winner, index, excs = await staggered_race(
265+
async_coro_generator(),
266+
delay=0.1,
267+
)
268+
269+
self.assertEqual(winner, 'Async Index: 0')
270+
self.assertEqual(index, 0)
271+
self.assertEqual(len(excs), 3)
272+
self.assertIsNone(excs[0])
273+
self.assertIsInstance(excs[1], asyncio.CancelledError)
274+
self.assertIsInstance(excs[2], asyncio.CancelledError)
275+
276+
async def test_async_iterable_with_delay(self):
277+
results = []
278+
279+
async def async_coro_generator():
280+
async def coro(index):
281+
results.append(f'Started: {index}')
282+
await asyncio.sleep(0.05)
283+
return f'Result: {index}'
284+
285+
yield lambda: coro(0)
286+
yield lambda: coro(1)
287+
yield lambda: coro(2)
288+
289+
winner, index, excs = await staggered_race(
290+
async_coro_generator(),
291+
delay=0.02,
292+
)
293+
294+
self.assertEqual(winner, 'Result: 0')
295+
self.assertEqual(index, 0)
296+
297+
self.assertGreaterEqual(len(excs), 1)
298+
self.assertIsNone(excs[0])
299+
300+
self.assertIn('Started: 0', results)
301+
302+
async def test_async_iterable_mixed_with_regular(self):
303+
async def coro(index):
304+
return f'Mixed Res: {index}'
305+
306+
winner, index, excs = await staggered_race(
307+
[lambda: coro(0), lambda: coro(1)],
308+
delay=None,
309+
)
310+
311+
self.assertEqual(winner, 'Mixed Res: 0')
312+
self.assertEqual(index, 0)
313+
self.assertEqual(excs, [None])
314+
315+
async def test_async_iterable_cancelled(self):
316+
log = []
317+
318+
async def async_coro_generator():
319+
async def coro_fn():
320+
try:
321+
await asyncio.sleep(0.1)
322+
except asyncio.CancelledError:
323+
log.append("async cancelled")
324+
raise
325+
326+
yield coro_fn
327+
328+
with self.assertRaises(TimeoutError):
329+
async with asyncio.timeout(0.01):
330+
await staggered_race(async_coro_generator(), delay=None)
331+
332+
self.assertListEqual(log, ["async cancelled"])
333+
334+
async def test_async_iterable_stop_async_iteration(self):
335+
async def async_coro_generator():
336+
async def coro():
337+
return "success"
338+
339+
yield lambda: coro()
340+
341+
winner, index, excs = await staggered_race(
342+
async_coro_generator(),
343+
delay=None,
344+
)
345+
346+
self.assertEqual(winner, "success")
347+
self.assertEqual(index, 0)
348+
self.assertEqual(excs, [None])

0 commit comments

Comments
 (0)