2323 TypeAlias ,
2424 TypeVar ,
2525)
26+ from typing_extensions import (
27+ # Native in 3.11+
28+ Self ,
29+ )
2630
2731from anyio import create_task_group , move_on_after
2832from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
4549
4650
4751class AsyncTaskManager :
48- def __init__ (self , * , on_activation : Callable [[], Any ] | None ) -> None :
52+ def __init__ (self , * , on_activation : Callable [[], Any ] | None = None ) -> None :
4953 self ._activated = False
5054 self ._event_loop : asyncio .AbstractEventLoop | None = None
5155 self ._on_activation = on_activation
5256 self ._task_queue : asyncio .Queue [Callable [[], Awaitable [Any ]]] = asyncio .Queue ()
5357 self ._terminate = asyncio .Event ()
5458 self ._terminated = asyncio .Event ()
59+ # For the case where the task manager is run via its context manager
60+ self ._tm_started = asyncio .Event ()
61+ self ._tm_task : asyncio .Task [Any ] | None = None
62+
63+ ACTIVATION_TIMEOUT = 5 # Just starts an async task, should be fast
64+ TERMINATION_TIMEOUT = 20 # May have to shut down TCP links
5565
5666 @property
5767 def activated (self ) -> bool :
@@ -65,6 +75,20 @@ def active(self) -> bool:
6575 and not self ._terminated .is_set ()
6676 )
6777
78+ async def __aenter__ (self ) -> Self :
79+ # Handle reentrancy the same way files do:
80+ # allow nested use as a CM, but close on the first exit
81+ if self ._tm_task is None :
82+ self ._tm_task = asyncio .create_task (self .run_until_terminated ())
83+ with move_on_after (self .ACTIVATION_TIMEOUT ):
84+ await self ._tm_started .wait ()
85+ return self
86+
87+ async def __aexit__ (self , * args : Any ) -> None :
88+ await self .request_termination ()
89+ with move_on_after (self .TERMINATION_TIMEOUT ):
90+ await self ._terminated .wait ()
91+
6892 def check_running_in_task_loop (self , * , allow_inactive : bool = False ) -> bool :
6993 """Returns if running in this manager's event loop, raises RuntimeError otherwise."""
7094 this_loop = self ._event_loop
@@ -138,6 +162,7 @@ def _init_event_loop(self) -> None:
138162 notify = self ._on_activation
139163 if notify is not None :
140164 notify ()
165+ self ._tm_started .set ()
141166
142167 async def run_until_terminated (
143168 self , func : Callable [[], Coroutine [Any , Any , Any ]] | None = None
@@ -218,9 +243,6 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
218243AsyncRemoteCallInfo : TypeAlias = tuple [int , Callable [[], Awaitable [Any ]]]
219244
220245
221- # TODO: Improve code sharing between AsyncWebsocketHandler and
222- # the async-native AsyncLMStudioWebsocket implementation
223- # (likely by migrating the websocket over to using the handler)
224246class AsyncWebsocketHandler :
225247 """Async task handler for a single websocket connection."""
226248
@@ -243,7 +265,7 @@ def __init__(
243265 self ._ws_disconnected = asyncio .Event ()
244266 self ._rx_task : asyncio .Task [None ] | None = None
245267 self ._logger = logger = new_logger (type (self ).__name__ )
246- logger .update_context (log_context )
268+ logger .update_context (log_context , ws_url = ws_url )
247269 self ._mux = MultiplexingManager (logger )
248270
249271 async def connect (self ) -> bool :
@@ -275,7 +297,7 @@ def disconnect_threadsafe(self) -> None:
275297 task_manager .run_coroutine_threadsafe (self .disconnect ()).result ()
276298
277299 async def _logged_ws_handler (self ) -> None :
278- self ._logger .info ("Websocket handling task started" )
300+ self ._logger .debug ("Websocket handling task started" )
279301 try :
280302 await self ._handle_ws ()
281303 except (asyncio .CancelledError , GeneratorExit ):
@@ -287,7 +309,7 @@ async def _logged_ws_handler(self) -> None:
287309 # Ensure connections attempt are unblocked even if the
288310 # background async task errors out completely
289311 self ._connection_attempted .set ()
290- self ._logger .info ("Websocket task terminated" )
312+ self ._logger .debug ("Websocket task terminated" )
291313
292314 async def _handle_ws (self ) -> None :
293315 assert self ._task_manager .check_running_in_task_loop ()
@@ -311,12 +333,12 @@ def _clear_task_state() -> None:
311333 if not await self ._authenticate ():
312334 return
313335 self ._connection_attempted .set ()
314- self ._logger .info (f "Websocket session established ( { self . _ws_url } ) " )
336+ self ._logger .info ("Websocket session established" )
315337 # Task will run until message reception fails or is cancelled
316338 try :
317339 await self ._receive_messages ()
318340 finally :
319- self ._logger .info ("Websocket demultiplexing task terminated." )
341+ self ._logger .debug ("Websocket demultiplexing task terminated." )
320342 # Notify foreground thread of background thread termination
321343 # (this covers termination due to link failure)
322344 await self .notify_client_termination ()
0 commit comments