2929)
3030
3131# Synchronous API still uses an async websocket (just in a background thread)
32- from anyio import create_task_group
32+ from anyio import create_task_group , get_cancelled_exc_class
3333from exceptiongroup import suppress
3434from httpx_ws import aconnect_ws , AsyncWebSocketSession , HTTPXWSException
3535
4747T = TypeVar ("T" )
4848
4949
50- class BackgroundThread (threading .Thread ):
51- """Background async event loop thread."""
52-
53- def __init__ (
54- self ,
55- task_target : Callable [[], Coroutine [Any , Any , Any ]] | None = None ,
56- name : str | None = None ,
57- ) -> None :
58- # Accepts the same args as `threading.Thread`, *except*:
59- # * a `task_target` coroutine replaces the `target` function
60- # * No `daemon` option (always runs as a daemon)
61- # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
62- # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
63- self ._task_target = task_target
64- self ._loop_started = threading .Event ()
65- self ._terminate = asyncio .Event ()
66- self ._event_loop : asyncio .AbstractEventLoop | None = None
67- # Annoyingly, we have to mark the background thread as a daemon thread to
68- # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
69- # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
70- super ().__init__ (name = name , daemon = True )
71- weakref .finalize (self , self .terminate )
72-
73- def run (self ) -> None :
74- """Run an async event loop in the background thread."""
75- # Only public to override threading.Thread.run
76- asyncio .run (self ._run_until_terminated ())
77-
78- def wait_for_loop (self ) -> asyncio .AbstractEventLoop | None :
79- """Wait for the event loop to start from a synchronous foreground thread."""
80- if self ._event_loop is None and not self ._loop_started .is_set ():
81- self ._loop_started .wait ()
82- return self ._event_loop
83-
84- async def wait_for_loop_async (self ) -> asyncio .AbstractEventLoop | None :
85- """Wait for the event loop to start from an asynchronous foreground thread."""
86- return await asyncio .to_thread (self .wait_for_loop )
50+ class _BackgroundTaskHandlerMixin :
51+ # Subclasses need to handle providing these instance attributes
52+ _event_loop : asyncio .AbstractEventLoop | None
53+ _task_target : Callable [[], Coroutine [Any , Any , Any ]] | None
54+ _terminate : asyncio .Event
8755
8856 def called_in_background_loop (self ) -> bool :
8957 """Returns true if currently running in this thread's event loop, false otherwise."""
@@ -123,10 +91,12 @@ async def terminate_async(self) -> bool:
12391 """Request termination of the event loop from an asynchronous foreground thread."""
12492 return await asyncio .to_thread (self .terminate )
12593
94+ def _init_event_loop (self ) -> None :
95+ self ._event_loop = asyncio .get_running_loop ()
96+
12697 async def _run_until_terminated (self ) -> None :
12798 """Run task in the background thread until termination is requested."""
128- self ._event_loop = asyncio .get_running_loop ()
129- self ._loop_started .set ()
99+ self ._init_event_loop ()
130100 # Use anyio and exceptiongroup to handle the lack of native task
131101 # and exception groups prior to Python 3.11
132102 raise_on_termination , terminated_exc = self ._raise_on_termination ()
@@ -163,6 +133,49 @@ def schedule_background_task(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T
163133 assert loop is not None
164134 return asyncio .run_coroutine_threadsafe (coro , loop )
165135
136+
137+ class BackgroundThread (_BackgroundTaskHandlerMixin , threading .Thread ):
138+ """Background async event loop thread."""
139+
140+ def __init__ (
141+ self ,
142+ task_target : Callable [[], Coroutine [Any , Any , Any ]] | None = None ,
143+ name : str | None = None ,
144+ ) -> None :
145+ # Accepts the same args as `threading.Thread`, *except*:
146+ # * a `task_target` coroutine replaces the `target` function
147+ # * No `daemon` option (always runs as a daemon)
148+ # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
149+ # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
150+ self ._task_target = task_target
151+ self ._loop_started = threading .Event ()
152+ self ._terminate = asyncio .Event ()
153+ self ._event_loop : asyncio .AbstractEventLoop | None = None
154+ # Annoyingly, we have to mark the background thread as a daemon thread to
155+ # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
156+ # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
157+ super ().__init__ (name = name , daemon = True )
158+ weakref .finalize (self , self .terminate )
159+
160+ def run (self ) -> None :
161+ """Run an async event loop in the background thread."""
162+ # Only public to override threading.Thread.run
163+ asyncio .run (self ._run_until_terminated ())
164+
165+ def _init_event_loop (self ) -> None :
166+ super ()._init_event_loop ()
167+ self ._loop_started .set ()
168+
169+ def wait_for_loop (self ) -> asyncio .AbstractEventLoop | None :
170+ """Wait for the event loop to start from a synchronous foreground thread."""
171+ if self ._event_loop is None and not self ._loop_started .is_set ():
172+ self ._loop_started .wait ()
173+ return self ._event_loop
174+
175+ async def wait_for_loop_async (self ) -> asyncio .AbstractEventLoop | None :
176+ """Wait for the event loop to start from an asynchronous foreground thread."""
177+ return await asyncio .to_thread (self .wait_for_loop )
178+
166179 def run_background_task (self , coro : Coroutine [Any , Any , T ]) -> T :
167180 """Run given coroutine in the background event loop and wait for the result."""
168181 return self .schedule_background_task (coro ).result ()
@@ -178,62 +191,83 @@ def call_in_background(self, callback: Callable[[], Any]) -> None:
178191 loop .call_soon_threadsafe (callback )
179192
180193
181- # TODO: Allow multiple websockets to share a single event loop thread
182- # (reduces thread usage in sync API, blocker for async API migration)
183194class AsyncWebsocketThread (BackgroundThread ):
195+ def __init__ (self , log_context : LogEventContext | None = None ) -> None :
196+ super ().__init__ (task_target = self ._run_main_task )
197+ self ._logger = logger = get_logger (type (self ).__name__ )
198+ logger .update_context (log_context , thread_id = self .name )
199+
200+ async def _run_main_task (self ) -> None :
201+ self ._logger .info ("Websocket handling thread started" )
202+ never_set = asyncio .Event ()
203+ try :
204+ # Run the event loop until termination is requested
205+ await never_set .wait ()
206+ except get_cancelled_exc_class ():
207+ pass
208+ except BaseException :
209+ err_msg = "Terminating websocket thread due to exception"
210+ self ._logger .debug (err_msg , exc_info = True )
211+ self ._logger .info ("Websocket thread terminated" )
212+
213+
214+ # TODO: Improve code sharing between AsyncWebsocketHandler and
215+ # the async-native AsyncLMStudioWebsocket implementation
216+ class AsyncWebsocketHandler (_BackgroundTaskHandlerMixin ):
217+ """Async task handler for a single websocket connection."""
218+
184219 def __init__ (
185220 self ,
221+ ws_thread : AsyncWebsocketThread ,
186222 ws_url : str ,
187223 auth_details : DictObject ,
188224 enqueue_message : Callable [[DictObject ], bool ],
189- log_context : LogEventContext ,
225+ log_context : LogEventContext | None = None ,
190226 ) -> None :
191227 self ._auth_details = auth_details
192228 self ._connection_attempted = asyncio .Event ()
193229 self ._connection_failure : Exception | None = None
194230 self ._auth_failure : Any | None = None
195231 self ._terminate = asyncio .Event ()
232+ self ._ws_thread = ws_thread
196233 self ._ws_url = ws_url
197234 self ._ws : AsyncWebSocketSession | None = None
198235 self ._rx_task : asyncio .Task [None ] | None = None
199236 self ._queue_message = enqueue_message
200- super (). __init__ ( task_target = self . _run_main_task )
237+ self . _logger = get_logger ( type ( self ). __name__ )
201238 self ._logger = logger = get_logger (type (self ).__name__ )
202- logger .update_context (log_context , thread_id = self . name )
239+ logger .update_context (log_context )
203240
204241 def connect (self ) -> bool :
205- if not self .is_alive ():
206- self .start ()
207- loop = self .wait_for_loop () # Block until connection has been attempted
242+ ws_thread = self ._ws_thread
243+ if not ws_thread .is_alive ():
244+ raise RuntimeError ("Websocket handling thread has failed unexpectedly" )
245+ loop = ws_thread .wait_for_loop () # Block until loop is available
208246 if loop is None :
209- return False
247+ raise RuntimeError ("Websocket handling thread has no event loop" )
248+ ws_thread .schedule_background_task (self ._run_until_terminated ())
210249 asyncio .run_coroutine_threadsafe (
211250 self ._connection_attempted .wait (), loop
212251 ).result ()
213252 return self ._ws is not None
214253
215- def disconnect (self ) -> None :
216- if self ._ws is not None :
217- self .terminate ()
218- # Ensure thread has terminated
219- self .join ()
220-
221- async def _run_main_task (self ) -> None :
222- self ._logger .info ("Websocket thread started" )
254+ async def _task_target (self ) -> None :
255+ self ._logger .info ("Websocket handling task started" )
256+ self ._init_event_loop ()
223257 try :
224- await self ._main_task ()
258+ await self ._handle_ws ()
259+ except get_cancelled_exc_class ():
260+ pass
225261 except BaseException :
226- err_msg = "Terminating websocket thread due to exception"
262+ err_msg = "Terminating websocket task due to exception"
227263 self ._logger .debug (err_msg , exc_info = True )
228264 finally :
229265 # Ensure the foreground thread is unblocked even if the
230266 # background async task errors out completely
231267 self ._connection_attempted .set ()
232- self ._logger .info ("Websocket thread terminated" )
268+ self ._logger .info ("Websocket task terminated" )
233269
234- # TODO: Improve code sharing between this background thread async websocket
235- # and the async-native AsyncLMStudioWebsocket implementation
236- async def _main_task (self ) -> None :
270+ async def _handle_ws (self ) -> None :
237271 resources = AsyncExitStack ()
238272 try :
239273 ws : AsyncWebSocketSession = await resources .enter_async_context (
@@ -274,6 +308,10 @@ async def _send_json(self, message: DictObject) -> None:
274308 self ._logger .debug (str (err ), exc_info = True )
275309 raise err from None
276310
311+ def send_json (self , message : DictObject ) -> None :
312+ future = self .schedule_background_task (self ._send_json (message ))
313+ future .result () # Block until the message is sent
314+
277315 async def _receive_json (self ) -> Any :
278316 # This is only called if the websocket has been created
279317 assert self .called_in_background_loop ()
@@ -335,8 +373,6 @@ async def _demultiplexing_task(self) -> None:
335373 finally :
336374 self ._logger .info ("Websocket closed, terminating demultiplexing task." )
337375
338- raise_on_termination , terminated_exc = self ._raise_on_termination ()
339-
340376 async def _receive_messages (self ) -> None :
341377 """Process received messages until task is cancelled."""
342378 while True :
@@ -349,6 +385,38 @@ async def _receive_messages(self) -> None:
349385 self ._terminate .set ()
350386 break
351387
388+
389+ class SyncToAsyncWebsocketBridge :
390+ def __init__ (
391+ self ,
392+ ws_thread : AsyncWebsocketThread ,
393+ ws_url : str ,
394+ auth_details : DictObject ,
395+ enqueue_message : Callable [[DictObject ], bool ],
396+ log_context : LogEventContext ,
397+ ) -> None :
398+ self ._ws_handler = AsyncWebsocketHandler (
399+ ws_thread , ws_url , auth_details , enqueue_message , log_context
400+ )
401+
402+ def connect (self ) -> bool :
403+ return self ._ws_handler .connect ()
404+
405+ def disconnect (self ) -> None :
406+ self ._ws_handler .terminate ()
407+
352408 def send_json (self , message : DictObject ) -> None :
353- # Block until message has been sent
354- self .run_background_task (self ._send_json (message ))
409+ self ._ws_handler .send_json (message )
410+
411+ # These attributes are currently accessed directly...
412+ @property
413+ def _ws (self ) -> AsyncWebSocketSession | None :
414+ return self ._ws_handler ._ws
415+
416+ @property
417+ def _connection_failure (self ) -> Exception | None :
418+ return self ._ws_handler ._connection_failure
419+
420+ @property
421+ def _auth_failure (self ) -> Any | None :
422+ return self ._ws_handler ._auth_failure
0 commit comments