44
55import contextlib
66
7+ from . import events
8+ from . import exceptions as exceptions_mod
79from . import locks
810from . import tasks
9- from . import taskgroups
1011
11- class _Done (Exception ):
12- pass
1312
1413async def staggered_race (coro_fns , delay , * , loop = None ):
1514 """Run coroutines with staggered start times and take the first to finish.
@@ -43,6 +42,8 @@ async def staggered_race(coro_fns, delay, *, loop=None):
4342 delay: amount of time, in seconds, between starting coroutines. If
4443 ``None``, the coroutines will run sequentially.
4544
45+ loop: the event loop to use.
46+
4647 Returns:
4748 tuple *(winner_result, winner_index, exceptions)* where
4849
@@ -61,11 +62,36 @@ async def staggered_race(coro_fns, delay, *, loop=None):
6162
6263 """
6364 # TODO: when we have aiter() and anext(), allow async iterables in coro_fns.
65+ loop = loop or events .get_running_loop ()
66+ enum_coro_fns = enumerate (coro_fns )
6467 winner_result = None
6568 winner_index = None
6669 exceptions = []
70+ running_tasks = []
71+
72+ async def run_one_coro (previous_failed ) -> None :
73+ # Wait for the previous task to finish, or for delay seconds
74+ if previous_failed is not None :
75+ with contextlib .suppress (exceptions_mod .TimeoutError ):
76+ # Use asyncio.wait_for() instead of asyncio.wait() here, so
77+ # that if we get cancelled at this point, Event.wait() is also
78+ # cancelled, otherwise there will be a "Task destroyed but it is
79+ # pending" later.
80+ await tasks .wait_for (previous_failed .wait (), delay )
81+ # Get the next coroutine to run
82+ try :
83+ this_index , coro_fn = next (enum_coro_fns )
84+ except StopIteration :
85+ return
86+ # Start task that will run the next coroutine
87+ this_failed = locks .Event ()
88+ next_task = loop .create_task (run_one_coro (this_failed ))
89+ running_tasks .append (next_task )
90+ assert len (running_tasks ) == this_index + 2
91+ # Prepare place to put this coroutine's exceptions if not won
92+ exceptions .append (None )
93+ assert len (exceptions ) == this_index + 1
6794
68- async def run_one_coro (this_index , coro_fn , this_failed ):
6995 try :
7096 result = await coro_fn ()
7197 except (SystemExit , KeyboardInterrupt ):
@@ -79,23 +105,34 @@ async def run_one_coro(this_index, coro_fn, this_failed):
79105 assert winner_index is None
80106 winner_index = this_index
81107 winner_result = result
82- raise _Done
83-
108+ # Cancel all other tasks. We take care to not cancel the current
109+ # task as well. If we do so, then since there is no `await` after
110+ # here and CancelledError are usually thrown at one, we will
111+ # encounter a curious corner case where the current task will end
112+ # up as done() == True, cancelled() == False, exception() ==
113+ # asyncio.CancelledError. This behavior is specified in
114+ # https://bugs.python.org/issue30048
115+ for i , t in enumerate (running_tasks ):
116+ if i != this_index :
117+ t .cancel ()
118+
119+ first_task = loop .create_task (run_one_coro (None ))
120+ running_tasks .append (first_task )
84121 try :
85- tg = taskgroups . TaskGroup ()
86- # Intentionally override the loop in the TaskGroup to avoid
87- # using the running loop, preserving backwards compatibility
88- # TaskGroup only starts using `_loop` after `__aenter__`
89- # so overriding it here is safe.
90- tg . _loop = loop
91- async with tg :
92- for this_index , coro_fn in enumerate ( coro_fns ):
93- this_failed = locks . Event ()
94- exceptions . append ( None )
95- tg . create_task ( run_one_coro ( this_index , coro_fn , this_failed ))
96- with contextlib . suppress ( TimeoutError ):
97- await tasks . wait_for ( this_failed . wait (), delay )
98- except* _Done :
99- pass
100-
101- return winner_result , winner_index , exceptions
122+ # Wait for a growing list of tasks to all finish: poor man's version of
123+ # curio's TaskGroup or trio's nursery
124+ done_count = 0
125+ while done_count != len ( running_tasks ):
126+ done , _ = await tasks . wait ( running_tasks )
127+ done_count = len ( done )
128+ # If run_one_coro raises an unhandled exception, it's probably a
129+ # programming error, and I want to see it.
130+ if __debug__ :
131+ for d in done :
132+ if d . done () and not d . cancelled () and d . exception ():
133+ raise d . exception ()
134+ return winner_result , winner_index , exceptions
135+ finally :
136+ # Make sure no tasks are left running if we leave this function
137+ for t in running_tasks :
138+ t . cancel ()
0 commit comments