2020 Coroutine ,
2121 Callable ,
2222 Generator ,
23+ Self ,
2324 TypeAlias ,
2425 TypeVar ,
2526)
4546
4647
4748class AsyncTaskManager :
48- def __init__ (self , * , on_activation : Callable [[], Any ] | None ) -> None :
49+ def __init__ (self , * , on_activation : Callable [[], Any ] | None = None ) -> None :
4950 self ._activated = False
5051 self ._event_loop : asyncio .AbstractEventLoop | None = None
5152 self ._on_activation = on_activation
5253 self ._task_queue : asyncio .Queue [Callable [[], Awaitable [Any ]]] = asyncio .Queue ()
5354 self ._terminate = asyncio .Event ()
5455 self ._terminated = asyncio .Event ()
56+ # For the case where the task manager is run via its context manager
57+ self ._tm_started = asyncio .Event ()
58+ self ._tm_task : asyncio .Task [Any ] | None = None
59+
60+ ACTIVATION_TIMEOUT = 5 # Just starts an async task, should be fast
61+ TERMINATION_TIMEOUT = 20 # May have to shut down TCP links
5562
5663 @property
5764 def activated (self ) -> bool :
@@ -65,6 +72,20 @@ def active(self) -> bool:
6572 and not self ._terminated .is_set ()
6673 )
6774
75+ async def __aenter__ (self ) -> Self :
76+ # Handle reentrancy the same way files do:
77+ # allow nested use as a CM, but close on the first exit
78+ if self ._tm_task is None :
79+ self ._tm_task = asyncio .create_task (self .run_until_terminated ())
80+ with move_on_after (self .ACTIVATION_TIMEOUT ):
81+ await self ._tm_started .wait ()
82+ return self
83+
84+ async def __aexit__ (self , * args : Any ) -> None :
85+ await self .request_termination ()
86+ with move_on_after (self .TERMINATION_TIMEOUT ):
87+ await self ._terminated .wait ()
88+
6889 def check_running_in_task_loop (self , * , allow_inactive : bool = False ) -> bool :
6990 """Returns if running in this manager's event loop, raises RuntimeError otherwise."""
7091 this_loop = self ._event_loop
@@ -138,6 +159,7 @@ def _init_event_loop(self) -> None:
138159 notify = self ._on_activation
139160 if notify is not None :
140161 notify ()
162+ self ._tm_started .set ()
141163
142164 async def run_until_terminated (
143165 self , func : Callable [[], Coroutine [Any , Any , Any ]] | None = None
@@ -218,9 +240,6 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
218240AsyncRemoteCallInfo : TypeAlias = tuple [int , Callable [[], Awaitable [Any ]]]
219241
220242
221- # TODO: Improve code sharing between AsyncWebsocketHandler and
222- # the async-native AsyncLMStudioWebsocket implementation
223- # (likely by migrating the websocket over to using the handler)
224243class AsyncWebsocketHandler :
225244 """Async task handler for a single websocket connection."""
226245
0 commit comments