11"""Sync I/O protocol implementation for the LM Studio remote access API."""
22
3+ import asyncio
34import itertools
45import time
5- import queue
66import weakref
77
88from abc import abstractmethod
@@ -158,14 +158,14 @@ class SyncChannel(Generic[T]):
158158 def __init__ (
159159 self ,
160160 channel_id : int ,
161- rx_queue : queue . Queue [ Any ],
161+ get_message : Callable [[], Any ],
162162 endpoint : ChannelEndpoint [T , Any , Any ],
163163 send_json : Callable [[DictObject ], None ],
164164 log_context : LogEventContext ,
165165 ) -> None :
166166 """Initialize synchronous websocket streaming channel."""
167167 self ._is_finished = False
168- self ._rx_queue = rx_queue
168+ self ._get_message = get_message
169169 self ._api_channel = ChannelHandler (channel_id , endpoint , log_context )
170170 self ._send_json = send_json
171171
@@ -193,7 +193,7 @@ def rx_stream(
193193 with sdk_public_api ():
194194 # Avoid emitting tracebacks that delve into supporting libraries
195195 # (we can't easily suppress the SDK's own frames for iterators)
196- message = self ._rx_queue . get ()
196+ message = self ._get_message ()
197197 contents = self ._api_channel .handle_rx_message (message )
198198 if contents is None :
199199 self ._is_finished = True
@@ -216,12 +216,12 @@ class SyncRemoteCall:
216216 def __init__ (
217217 self ,
218218 call_id : int ,
219- rx_queue : queue . Queue [ Any ],
219+ get_message : Callable [[], Any ],
220220 log_context : LogEventContext ,
221221 notice_prefix : str = "RPC" ,
222222 ) -> None :
223223 """Initialize synchronous remote procedure call."""
224- self ._rx_queue = rx_queue
224+ self ._get_message = get_message
225225 self ._rpc = RemoteCallHandler (call_id , log_context , notice_prefix )
226226 self ._logger = logger = new_logger (type (self ).__name__ )
227227 logger .update_context (log_context , call_id = call_id )
@@ -234,12 +234,12 @@ def get_rpc_message(
234234
235235 def receive_result (self ) -> Any :
236236 """Receive call response on the receive queue."""
237- message = self ._rx_queue . get ()
237+ message = self ._get_message ()
238238 return self ._rpc .handle_rx_message (message )
239239
240240
241241class SyncLMStudioWebsocket (
242- LMStudioWebsocket [SyncToAsyncWebsocketBridge , queue .Queue [Any ]]
242+ LMStudioWebsocket [SyncToAsyncWebsocketBridge , asyncio .Queue [Any ]]
243243):
244244 """Synchronous websocket client that handles demultiplexing of reply messages."""
245245
@@ -279,7 +279,8 @@ def connect(self) -> Self:
279279 self ._ws_thread ,
280280 self ._ws_url ,
281281 self ._auth_details ,
282- self ._enqueue_message ,
282+ self ._get_rx_queue ,
283+ self ._mux .all_queues ,
283284 self ._logger .event_context ,
284285 )
285286 if not ws .connect ():
@@ -298,42 +299,26 @@ def disconnect(self) -> None:
298299 self ._ws = None
299300 if ws is not None :
300301 self ._logger .debug (f"Disconnecting websocket session ({ self ._ws_url } )" )
301- self . _notify_client_termination ()
302+ ws . notify_client_termination_threadsafe ()
302303 ws .disconnect ()
303304 self ._logger .info (f"Websocket session disconnected ({ self ._ws_url } )" )
304305
305306 close = disconnect
306307
307- def _enqueue_message (self , message : Any ) -> bool :
308- if message is None :
309- self ._logger .info (f"Websocket session failed ({ self ._ws_url } )" )
310- self ._ws = None
311- return self ._notify_client_termination () > 0
312- rx_queue = self ._mux .map_rx_message (message )
313- if rx_queue is None :
314- return False
315- rx_queue .put (message )
316- return True
317-
318- def _notify_client_termination (self ) -> int :
319- """Send None to all clients with open receive queues."""
320- num_clients = 0
321- for rx_queue in self ._mux .all_queues ():
322- rx_queue .put (None )
323- num_clients += 1
324- self ._logger .debug (
325- f"Notified { num_clients } clients of websocket termination" ,
326- num_clients = num_clients ,
327- )
328- return num_clients
329-
330308 def _send_json (self , message : DictObject ) -> None :
331309 # Callers are expected to call `_ensure_connected` before this method
332310 ws = self ._ws
333311 assert ws is not None
334312 # Background thread handles the exception conversion
335313 ws .send_json (message )
336314
315+ def _get_rx_queue (self , message : Any ) -> asyncio .Queue [Any ] | None :
316+ if message is None :
317+ self ._logger .info (f"Websocket session failed ({ self ._ws_url } )" )
318+ self ._ws = None
319+ return None
320+ return self ._mux .map_rx_message (message )
321+
337322 def _connect_to_endpoint (self , channel : SyncChannel [Any ]) -> None :
338323 """Connect channel to specified endpoint."""
339324 self ._ensure_connected ("open channel endpoints" )
@@ -347,19 +332,18 @@ def open_channel(
347332 endpoint : ChannelEndpoint [T , Any , Any ],
348333 ) -> Generator [SyncChannel [T ], None , None ]:
349334 """Open a streaming channel over the websocket."""
350- rx_queue : queue .Queue [Any ] = queue .Queue ()
335+ ws = self ._ws
336+ assert ws is not None
337+ rx_queue , getter = ws .new_rx_queue ()
351338 with self ._mux .assign_channel_id (rx_queue ) as channel_id :
352339 channel = SyncChannel (
353340 channel_id ,
354- rx_queue ,
341+ getter ,
355342 endpoint ,
356343 self ._send_json ,
357344 self ._logger .event_context ,
358345 )
359346 self ._connect_to_endpoint (channel )
360- if self ._ws is None :
361- # Link has been terminated, ensure client gets a response
362- rx_queue .put (None )
363347 yield channel
364348
365349 def _send_call (
@@ -388,15 +372,14 @@ def remote_call(
388372 notice_prefix : str = "RPC" ,
389373 ) -> Any :
390374 """Make a remote procedure call over the websocket."""
391- rx_queue : queue .Queue [Any ] = queue .Queue ()
375+ ws = self ._ws
376+ assert ws is not None
377+ rx_queue , getter = ws .new_rx_queue ()
392378 with self ._mux .assign_call_id (rx_queue ) as call_id :
393379 rpc = SyncRemoteCall (
394- call_id , rx_queue , self ._logger .event_context , notice_prefix
380+ call_id , getter , self ._logger .event_context , notice_prefix
395381 )
396382 self ._send_call (rpc , endpoint , params )
397- if self ._ws is None :
398- # Link has been terminated, ensure client gets a response
399- rx_queue .put (None )
400383 return rpc .receive_result ()
401384
402385
0 commit comments