11"""Async I/O protocol implementation for the LM Studio remote access API."""
22
33import asyncio
4- import asyncio .queues
54import warnings
65
76from abc import abstractmethod
2827 TypeIs ,
2928)
3029
30+ from anyio import create_task_group
31+ from anyio .abc import TaskGroup
3132from httpx import RequestError , HTTPStatusError
3233from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
3334
@@ -163,7 +164,10 @@ async def rx_stream(
163164 # Avoid emitting tracebacks that delve into supporting libraries
164165 # (we can't easily suppress the SDK's own frames for iterators)
165166 message = await self ._rx_queue .get ()
166- contents = self ._api_channel .handle_rx_message (message )
167+ if message is None :
168+ contents = None
169+ else :
170+ contents = self ._api_channel .handle_rx_message (message )
167171 if contents is None :
168172 self ._is_finished = True
169173 break
@@ -204,6 +208,8 @@ def get_rpc_message(
204208 async def receive_result (self ) -> Any :
205209 """Receive call response on the receive queue."""
206210 message = await self ._rx_queue .get ()
211+ if message is None :
212+ return None
207213 return self ._rpc .handle_rx_message (message )
208214
209215
@@ -220,8 +226,10 @@ def __init__(
220226 ) -> None :
221227 """Initialize asynchronous websocket client."""
222228 super ().__init__ (ws_url , auth_details , log_context )
223- self ._resource_manager = AsyncExitStack ()
229+ self ._resource_manager = rm = AsyncExitStack ()
230+ rm .push_async_callback (self ._notify_client_termination )
224231 self ._rx_task : asyncio .Task [None ] | None = None
232+ self ._terminate = asyncio .Event ()
225233
226234 @property
227235 def _httpx_ws (self ) -> AsyncWebSocketSession | None :
@@ -241,7 +249,9 @@ async def __aexit__(self, *args: Any) -> None:
241249 async def _send_json (self , message : DictObject ) -> None :
242250 # Callers are expected to call `_ensure_connected` before this method
243251 ws = self ._ws
244- assert ws is not None
252+ if ws is None :
253+ # Assume app is shutting down and the owning task has already been cancelled
254+ return
245255 try :
246256 await ws .send_json (message )
247257 except Exception as exc :
@@ -253,7 +263,9 @@ async def _send_json(self, message: DictObject) -> None:
253263 async def _receive_json (self ) -> Any :
254264 # Callers are expected to call `_ensure_connected` before this method
255265 ws = self ._ws
256- assert ws is not None
266+ if ws is None :
267+ # Assume app is shutting down and the owning task has already been cancelled
268+ return
257269 try :
258270 return await ws .receive_json ()
259271 except Exception as exc :
@@ -291,7 +303,7 @@ async def connect(self) -> Self:
291303 self ._rx_task = rx_task = asyncio .create_task (self ._receive_messages ())
292304
293305 async def _terminate_rx_task () -> None :
294- rx_task . cancel ()
306+ self . _terminate . set ()
295307 try :
296308 await rx_task
297309 except asyncio .CancelledError :
@@ -305,19 +317,34 @@ async def disconnect(self) -> None:
305317 """Drop the LM Studio API connection."""
306318 self ._ws = None
307319 self ._rx_task = None
308- await self ._notify_client_termination ()
320+ self ._terminate . set ()
309321 await self ._resource_manager .aclose ()
310322 self ._logger .info (f"Websocket session disconnected ({ self ._ws_url } )" )
311323
312324 aclose = disconnect
313325
326+ async def _cancel_on_termination (self , tg : TaskGroup ) -> None :
327+ await self ._terminate .wait ()
328+ tg .cancel_scope .cancel ()
329+
314330 async def _process_next_message (self ) -> bool :
315331 """Process the next message received on the websocket.
316332
317333 Returns True if a message queue was updated.
318334 """
319335 self ._ensure_connected ("receive messages" )
320- message = await self ._receive_json ()
336+ async with create_task_group () as tg :
337+ tg .start_soon (self ._cancel_on_termination , tg )
338+ try :
339+ message = await self ._receive_json ()
340+ except (LMStudioWebsocketError , HTTPXWSException ):
341+ if self ._ws is not None and not self ._terminate .is_set ():
342+ # Websocket failed unexpectedly (rather than due to client shutdown)
343+ self ._logger .error ("Websocket failed, terminating session." )
344+ self ._terminate .set ()
345+ tg .cancel_scope .cancel ()
346+ if self ._terminate .is_set ():
347+ return (await self ._notify_client_termination ()) > 0
321348 rx_queue = self ._mux .map_rx_message (message )
322349 if rx_queue is None :
323350 return False
@@ -326,18 +353,20 @@ async def _process_next_message(self) -> bool:
326353
327354 async def _receive_messages (self ) -> None :
328355 """Process received messages until connection is terminated."""
329- while True :
330- try :
331- await self ._process_next_message ()
332- except (LMStudioWebsocketError , HTTPXWSException ):
333- self ._logger .exception ("Websocket failed, terminating session." )
334- await self .disconnect ()
335- break
356+ while not self ._terminate .is_set ():
357+ await self ._process_next_message ()
336358
337- async def _notify_client_termination (self ) -> None :
359+ async def _notify_client_termination (self ) -> int :
338360 """Send None to all clients with open receive queues."""
361+ num_clients = 0
339362 for rx_queue in self ._mux .all_queues ():
340363 await rx_queue .put (None )
364+ num_clients += 1
365+ self ._logger .info (
366+ f"Notified { num_clients } clients of websocket termination" ,
367+ num_clients = num_clients ,
368+ )
369+ return num_clients
341370
342371 async def _connect_to_endpoint (self , channel : AsyncChannel [Any ]) -> None :
343372 """Connect channel to specified endpoint."""
@@ -362,6 +391,9 @@ async def open_channel(
362391 self ._logger .event_context ,
363392 )
364393 await self ._connect_to_endpoint (channel )
394+ if self ._terminate .is_set ():
395+ # Link has been terminated, ensure client gets a response
396+ await rx_queue .put (None )
365397 yield channel
366398
367399 async def _send_call (
@@ -396,6 +428,9 @@ async def remote_call(
396428 call_id , rx_queue , self ._logger .event_context , notice_prefix
397429 )
398430 await self ._send_call (rpc , endpoint , params )
431+ if self ._terminate .is_set ():
432+ # Link has been terminated, ensure client gets a response
433+ await rx_queue .put (None )
399434 return await rpc .receive_result ()
400435
401436
0 commit comments