Skip to content

Commit 5d3e587

Browse files
committed
Make exit behavior in threaded_loop a user choice
1 parent 7eef365 commit 5d3e587

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

src/async_utils/bg_loop.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20+
import concurrent.futures as cf
2021
import threading
2122
from collections.abc import Awaitable, Generator
22-
from concurrent.futures import Future
2323
from contextlib import contextmanager
2424

25+
from . import _typings as t
26+
2527
type _FutureLike[T] = asyncio.Future[T] | Awaitable[T]
2628

2729
__all__ = ["threaded_loop"]
@@ -30,8 +32,9 @@
3032
class LoopWrapper:
3133
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
3234
self._loop = loop
35+
self._futures: set[cf.Future[t.Any]] = set()
3336

34-
def schedule[T](self, coro: _FutureLike[T], /) -> Future[T]:
37+
def schedule[T](self, coro: _FutureLike[T], /) -> cf.Future[T]:
3538
"""Schedule a coroutine to run on the wrapped event loop.
3639
3740
Parameters
@@ -44,7 +47,10 @@ def schedule[T](self, coro: _FutureLike[T], /) -> Future[T]:
4447
asyncio.Future:
4548
A Future wrapping the result.
4649
"""
47-
return asyncio.run_coroutine_threadsafe(coro, self._loop)
50+
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
51+
self._futures.add(future)
52+
future.add_done_callback(self._futures.discard)
53+
return future
4854

4955
async def run[T](self, coro: _FutureLike[T], /) -> T:
5056
"""Schedule and await a coroutine to run on the background loop.
@@ -59,8 +65,31 @@ async def run[T](self, coro: _FutureLike[T], /) -> T:
5965
The returned value of the coroutine run in the background
6066
"""
6167
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
68+
self._futures.add(future)
69+
future.add_done_callback(self._futures.discard)
6270
return await asyncio.wrap_future(future)
6371

72+
def cancel_all(self) -> None:
73+
"""Cancel all remaining futures."""
74+
for future in self._futures:
75+
future.cancel()
76+
77+
def wait_sync(self, timeout: float | None) -> bool:
78+
"""Wait for remaining futures.
79+
80+
Parameters
81+
----------
82+
timeout: float | None
83+
Optionally, how long to wait for
84+
85+
Returns
86+
-------
87+
bool
88+
True if all futures finished, otherwise False
89+
"""
90+
_done, pending = cf.wait(self._futures, timeout=timeout)
91+
return not pending
92+
6493

6594
def run_forever(
6695
loop: asyncio.AbstractEventLoop,
@@ -101,15 +130,18 @@ def run_forever(
101130

102131
@contextmanager
103132
def threaded_loop(
104-
*, use_eager_task_factory: bool = True
133+
*, use_eager_task_factory: bool = True, wait_on_exit: bool = True
105134
) -> Generator[LoopWrapper, None, None]:
106135
"""Create and use a managed event loop in a backround thread.
107136
108137
Starts an event loop on a background thread,
109138
and yields an object with scheduling methods for interacting with
110139
the loop.
111140
112-
Loop is scheduled for shutdown, and thread is joined at contextmanager exit
141+
At context manager exit, if wait_on_exit is True (default), then
142+
the context manager waits on the remaining futures. When it is done, or
143+
if that parameter is False, the loop is event loop is scheduled for shutdown
144+
and the thread is joined.
113145
114146
Yields
115147
------
@@ -118,15 +150,19 @@ def threaded_loop(
118150
"""
119151
loop = asyncio.new_event_loop()
120152
thread = None
153+
wrapper = None
121154
try:
122155
thread = threading.Thread(
123156
target=run_forever,
124157
args=(loop,),
125158
kwargs={"use_eager_task_factory": use_eager_task_factory},
126159
)
127160
thread.start()
128-
yield LoopWrapper(loop)
161+
wrapper = LoopWrapper(loop)
162+
yield wrapper
129163
finally:
164+
if wrapper and wait_on_exit:
165+
wrapper.wait_sync(None)
130166
loop.call_soon_threadsafe(loop.stop)
131167
if thread:
132168
thread.join()

0 commit comments

Comments
 (0)