11"""Async I/O protocol implementation for the LM Studio remote access API."""
22
33import asyncio
4- import asyncio .queues
54import warnings
65
76from abc import abstractmethod
2827 TypeIs ,
2928)
3029
30+ from anyio import create_task_group
31+ from anyio .abc import TaskGroup
3132from httpx import RequestError , HTTPStatusError
3233from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
3334
@@ -168,7 +169,10 @@ async def rx_stream(
168169 # Avoid emitting tracebacks that delve into supporting libraries
169170 # (we can't easily suppress the SDK's own frames for iterators)
170171 message = await self ._rx_queue .get ()
171- contents = self ._api_channel .handle_rx_message (message )
172+ if message is None :
173+ contents = None
174+ else :
175+ contents = self ._api_channel .handle_rx_message (message )
172176 if contents is None :
173177 self ._is_finished = True
174178 break
@@ -209,6 +213,8 @@ def get_rpc_message(
209213 async def receive_result (self ) -> Any :
210214 """Receive call response on the receive queue."""
211215 message = await self ._rx_queue .get ()
216+ if message is None :
217+ return None
212218 return self ._rpc .handle_rx_message (message )
213219
214220
@@ -225,8 +231,10 @@ def __init__(
225231 ) -> None :
226232 """Initialize asynchronous websocket client."""
227233 super ().__init__ (ws_url , auth_details , log_context )
228- self ._resource_manager = AsyncExitStack ()
234+ self ._resource_manager = rm = AsyncExitStack ()
235+ rm .push_async_callback (self ._notify_client_termination )
229236 self ._rx_task : asyncio .Task [None ] | None = None
237+ self ._terminate = asyncio .Event ()
230238
231239 @property
232240 def _httpx_ws (self ) -> AsyncWebSocketSession | None :
@@ -246,7 +254,9 @@ async def __aexit__(self, *args: Any) -> None:
246254 async def _send_json (self , message : DictObject ) -> None :
247255 # Callers are expected to call `_ensure_connected` before this method
248256 ws = self ._ws
249- assert ws is not None
257+ if ws is None :
258+ # Assume app is shutting down and the owning task has already been cancelled
259+ return
250260 try :
251261 await ws .send_json (message )
252262 except Exception as exc :
@@ -258,7 +268,9 @@ async def _send_json(self, message: DictObject) -> None:
258268 async def _receive_json (self ) -> Any :
259269 # Callers are expected to call `_ensure_connected` before this method
260270 ws = self ._ws
261- assert ws is not None
271+ if ws is None :
272+ # Assume app is shutting down and the owning task has already been cancelled
273+ return
262274 try :
263275 return await ws .receive_json ()
264276 except Exception as exc :
@@ -296,7 +308,7 @@ async def connect(self) -> Self:
296308 self ._rx_task = rx_task = asyncio .create_task (self ._receive_messages ())
297309
298310 async def _terminate_rx_task () -> None :
299- rx_task . cancel ()
311+ self . _terminate . set ()
300312 try :
301313 await rx_task
302314 except asyncio .CancelledError :
@@ -310,19 +322,34 @@ async def disconnect(self) -> None:
310322 """Drop the LM Studio API connection."""
311323 self ._ws = None
312324 self ._rx_task = None
313- await self ._notify_client_termination ()
325+ self ._terminate . set ()
314326 await self ._resource_manager .aclose ()
315327 self ._logger .info (f"Websocket session disconnected ({ self ._ws_url } )" )
316328
317329 aclose = disconnect
318330
331+ async def _cancel_on_termination (self , tg : TaskGroup ) -> None :
332+ await self ._terminate .wait ()
333+ tg .cancel_scope .cancel ()
334+
319335 async def _process_next_message (self ) -> bool :
320336 """Process the next message received on the websocket.
321337
322338 Returns True if a message queue was updated.
323339 """
324340 self ._ensure_connected ("receive messages" )
325- message = await self ._receive_json ()
341+ async with create_task_group () as tg :
342+ tg .start_soon (self ._cancel_on_termination , tg )
343+ try :
344+ message = await self ._receive_json ()
345+ except (LMStudioWebsocketError , HTTPXWSException ):
346+ if self ._ws is not None and not self ._terminate .is_set ():
347+ # Websocket failed unexpectedly (rather than due to client shutdown)
348+ self ._logger .error ("Websocket failed, terminating session." )
349+ self ._terminate .set ()
350+ tg .cancel_scope .cancel ()
351+ if self ._terminate .is_set ():
352+ return (await self ._notify_client_termination ()) > 0
326353 rx_queue = self ._mux .map_rx_message (message )
327354 if rx_queue is None :
328355 return False
@@ -331,18 +358,20 @@ async def _process_next_message(self) -> bool:
331358
332359 async def _receive_messages (self ) -> None :
333360 """Process received messages until connection is terminated."""
334- while True :
335- try :
336- await self ._process_next_message ()
337- except (LMStudioWebsocketError , HTTPXWSException ):
338- self ._logger .exception ("Websocket failed, terminating session." )
339- await self .disconnect ()
340- break
361+ while not self ._terminate .is_set ():
362+ await self ._process_next_message ()
341363
342- async def _notify_client_termination (self ) -> None :
364+ async def _notify_client_termination (self ) -> int :
343365 """Send None to all clients with open receive queues."""
366+ num_clients = 0
344367 for rx_queue in self ._mux .all_queues ():
345368 await rx_queue .put (None )
369+ num_clients += 1
370+ self ._logger .info (
371+ f"Notified { num_clients } clients of websocket termination" ,
372+ num_clients = num_clients ,
373+ )
374+ return num_clients
346375
347376 async def _connect_to_endpoint (self , channel : AsyncChannel [Any ]) -> None :
348377 """Connect channel to specified endpoint."""
@@ -367,6 +396,9 @@ async def open_channel(
367396 self ._logger .event_context ,
368397 )
369398 await self ._connect_to_endpoint (channel )
399+ if self ._terminate .is_set ():
400+ # Link has been terminated, ensure client gets a response
401+ await rx_queue .put (None )
370402 yield channel
371403
372404 async def _send_call (
@@ -401,6 +433,9 @@ async def remote_call(
401433 call_id , rx_queue , self ._logger .event_context , notice_prefix
402434 )
403435 await self ._send_call (rpc , endpoint , params )
436+ if self ._terminate .is_set ():
437+ # Link has been terminated, ensure client gets a response
438+ await rx_queue .put (None )
404439 return await rpc .receive_result ()
405440
406441
0 commit comments