diff --git a/misc/atexit-example.py b/misc/atexit-example.py new file mode 100644 index 0000000..7183193 --- /dev/null +++ b/misc/atexit-example.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python +"""Check accessing the default client from an atexit hook.""" +import atexit + +# Intentionally named with a hyphen to ensure this can't be imported + +import lmstudio as lms + +# TODO: Turn this into a CI test case (perhaps via subprocess invocation?) + +# Prior to lmstudio-python 1.5.0, the atexit hook below would hang on shutdown. +# By the time atexit hooks run, asyncio.to_thread no longer works due to all +# concurrent.futures managed thread pools (including those used by asyncio) +# being shut down before the interpreter waits for non-daemon threads to terminate. +# Since the synchronous client relied on asyncio.to_thread to deliver +# messages from the async background comms thread to the blocking +# foreground thread, the sync API didn't work in this scenario. +# In 1.5.0, the sync message reception was reworked to queue messages entirely +# in the async background thread with blocking async queue reads, eliminating +# the blocking queue write operations, and allowing the client to continue running +# in atexit hooks +@atexit.register +def access_default_client(): + """Ensure default client can be accessed from an atexit hook.""" + print("During shutdown:", end=" ", flush=True) + print(lms.list_loaded_models()) + +print("Prior to shutdown:", lms.list_loaded_models()) diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index a239251..1d68cf8 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -14,14 +14,14 @@ import weakref from concurrent.futures import Future as SyncFuture -from contextlib import ( - AsyncExitStack, -) +from contextlib import AsyncExitStack +from functools import partial from typing import ( Any, Awaitable, Coroutine, Callable, + Iterable, TypeVar, ) @@ -221,6 +221,8 @@ def __init__( # Annoyingly, we have to mark the 
background thread as a daemon thread to # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695 + # TODO: skip thread daemonization when running in a subinterpreter + # (and also disable the convenience API in subinterpreters to avoid hanging on shutdown) super().__init__(name=name, daemon=True) weakref.finalize(self, self.terminate) @@ -278,6 +280,25 @@ def call_in_background(self, func: Callable[[], Any]) -> None: self._task_manager.call_soon_threadsafe(func) +# By default, the weakref finalization atexit hook is registered lazily. +# This can lead to shutdown sequencing issues if SDK users attempt to access +# client instances (such as the default sync client) from atexit hooks +# registered at import time (so they may end up running after the weakref +# finalization hook has already terminated background threads) +# Creating this finalizer here ensures the weakref finalization hook is +# registered at import time, and hence runs *after* any such hooks +# (assuming the lmstudio SDK is imported before the hooks are registered) +def _register_weakref_atexit_hook() -> None: + class C: + pass + + weakref.finalize(C(), int) + + +_register_weakref_atexit_hook() +del _register_weakref_atexit_hook + + class AsyncWebsocketThread(BackgroundThread): def __init__(self, log_context: LogEventContext | None = None) -> None: super().__init__(task_target=self._log_thread_execution) @@ -439,6 +460,20 @@ def send_json_threadsafe(self, message: DictObject) -> None: future = self._task_manager.run_coroutine_threadsafe(self.send_json(message)) future.result() # Block until the message is sent + def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: + """Run given coroutine in the event loop and wait for the result.""" + return self._task_manager.run_coroutine_threadsafe(coro).result() + + def 
rx_queue_get_threadsafe(self, rx_queue: asyncio.Queue[Any]) -> Any: + future = self._task_manager.run_coroutine_threadsafe(rx_queue.get()) + return future.result() + + def rx_queue_put_threadsafe( + self, rx_queue: asyncio.Queue[Any], message: Any + ) -> None: + future = self._task_manager.run_coroutine_threadsafe(rx_queue.put(message)) + return future.result() + async def _receive_json(self) -> Any: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() @@ -502,15 +537,21 @@ def __init__( ws_thread: AsyncWebsocketThread, ws_url: str, auth_details: DictObject, - enqueue_message: Callable[[DictObject | None], bool], + get_queue: Callable[[DictObject | None], asyncio.Queue[Any] | None], + iter_queues: Callable[[], Iterable[asyncio.Queue[Any]]], log_context: LogEventContext, ) -> None: - async def enqueue_async(message: DictObject | None) -> bool: - return await asyncio.to_thread(enqueue_message, message) - + self._get_queue = get_queue + self._iter_queues = iter_queues self._ws_handler = AsyncWebsocketHandler( - ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context + ws_thread.task_manager, + ws_url, + auth_details, + self._enqueue_message, + log_context, ) + self._logger = logger = new_logger(type(self).__name__) + logger.update_context(log_context) def connect(self) -> bool: return self._ws_handler.connect_threadsafe() @@ -521,6 +562,37 @@ def disconnect(self) -> None: def send_json(self, message: DictObject) -> None: self._ws_handler.send_json_threadsafe(message) + def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[], Any]]: + rx_queue: asyncio.Queue[Any] = asyncio.Queue() + return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue) + + async def _enqueue_message(self, message: Any) -> bool: + rx_queue = self._get_queue(message) + if message is None: + return await self.notify_client_termination() > 0 + if rx_queue is None: + return False + await 
rx_queue.put(message) + return True + + async def notify_client_termination(self) -> int: + """Send None to all clients with open receive queues (from background thread).""" + num_clients = 0 + for rx_queue in self._iter_queues(): + await rx_queue.put(None) + num_clients += 1 + self._logger.debug( + f"Notified {num_clients} clients of websocket termination", + num_clients=num_clients, + ) + return num_clients + + def notify_client_termination_threadsafe(self) -> int: + """Send None to all clients with open receive queues (from foreground thread).""" + return self._ws_handler.run_background_coroutine( + self.notify_client_termination() + ) + # These attributes are currently accessed directly... @property def _ws(self) -> AsyncWebSocketSession | None: diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 487a83a..00b1365 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -526,6 +526,8 @@ def _redact_json(data: DictObject | None) -> DictObject | None: return redacted +# TODO: Now that even the sync API uses asyncio.Queue, +# change the multiplexing manager to no longer be generic TQueue = TypeVar("TQueue") diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index 8a368ef..b7d8f28 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -1,8 +1,8 @@ """Sync I/O protocol implementation for the LM Studio remote access API.""" +import asyncio import itertools import time -import queue import weakref from abc import abstractmethod @@ -158,14 +158,14 @@ class SyncChannel(Generic[T]): def __init__( self, channel_id: int, - rx_queue: queue.Queue[Any], + get_message: Callable[[], Any], endpoint: ChannelEndpoint[T, Any, Any], send_json: Callable[[DictObject], None], log_context: LogEventContext, ) -> None: """Initialize synchronous websocket streaming channel.""" self._is_finished = False - self._rx_queue = rx_queue + self._get_message = get_message self._api_channel = ChannelHandler(channel_id, endpoint, 
log_context) self._send_json = send_json @@ -193,7 +193,7 @@ def rx_stream( with sdk_public_api(): # Avoid emitting tracebacks that delve into supporting libraries # (we can't easily suppress the SDK's own frames for iterators) - message = self._rx_queue.get() + message = self._get_message() contents = self._api_channel.handle_rx_message(message) if contents is None: self._is_finished = True @@ -216,12 +216,12 @@ class SyncRemoteCall: def __init__( self, call_id: int, - rx_queue: queue.Queue[Any], + get_message: Callable[[], Any], log_context: LogEventContext, notice_prefix: str = "RPC", ) -> None: """Initialize synchronous remote procedure call.""" - self._rx_queue = rx_queue + self._get_message = get_message self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix) self._logger = logger = new_logger(type(self).__name__) logger.update_context(log_context, call_id=call_id) @@ -234,12 +234,12 @@ def get_rpc_message( def receive_result(self) -> Any: """Receive call response on the receive queue.""" - message = self._rx_queue.get() + message = self._get_message() return self._rpc.handle_rx_message(message) class SyncLMStudioWebsocket( - LMStudioWebsocket[SyncToAsyncWebsocketBridge, queue.Queue[Any]] + LMStudioWebsocket[SyncToAsyncWebsocketBridge, asyncio.Queue[Any]] ): """Synchronous websocket client that handles demultiplexing of reply messages.""" @@ -279,7 +279,8 @@ def connect(self) -> Self: self._ws_thread, self._ws_url, self._auth_details, - self._enqueue_message, + self._get_rx_queue, + self._mux.all_queues, self._logger.event_context, ) if not ws.connect(): @@ -298,35 +299,12 @@ def disconnect(self) -> None: self._ws = None if ws is not None: self._logger.debug(f"Disconnecting websocket session ({self._ws_url})") - self._notify_client_termination() + ws.notify_client_termination_threadsafe() ws.disconnect() self._logger.info(f"Websocket session disconnected ({self._ws_url})") close = disconnect - def _enqueue_message(self, message: Any) -> bool: - if 
message is None: - self._logger.info(f"Websocket session failed ({self._ws_url})") - self._ws = None - return self._notify_client_termination() > 0 - rx_queue = self._mux.map_rx_message(message) - if rx_queue is None: - return False - rx_queue.put(message) - return True - - def _notify_client_termination(self) -> int: - """Send None to all clients with open receive queues.""" - num_clients = 0 - for rx_queue in self._mux.all_queues(): - rx_queue.put(None) - num_clients += 1 - self._logger.debug( - f"Notified {num_clients} clients of websocket termination", - num_clients=num_clients, - ) - return num_clients - def _send_json(self, message: DictObject) -> None: # Callers are expected to call `_ensure_connected` before this method ws = self._ws @@ -334,6 +312,13 @@ def _send_json(self, message: DictObject) -> None: # Background thread handles the exception conversion ws.send_json(message) + def _get_rx_queue(self, message: Any) -> asyncio.Queue[Any] | None: + if message is None: + self._logger.info(f"Websocket session failed ({self._ws_url})") + self._ws = None + return None + return self._mux.map_rx_message(message) + def _connect_to_endpoint(self, channel: SyncChannel[Any]) -> None: """Connect channel to specified endpoint.""" self._ensure_connected("open channel endpoints") @@ -347,19 +332,18 @@ def open_channel( endpoint: ChannelEndpoint[T, Any, Any], ) -> Generator[SyncChannel[T], None, None]: """Open a streaming channel over the websocket.""" - rx_queue: queue.Queue[Any] = queue.Queue() + ws = self._ws + assert ws is not None + rx_queue, getter = ws.new_rx_queue() with self._mux.assign_channel_id(rx_queue) as channel_id: channel = SyncChannel( channel_id, - rx_queue, + getter, endpoint, self._send_json, self._logger.event_context, ) self._connect_to_endpoint(channel) - if self._ws is None: - # Link has been terminated, ensure client gets a response - rx_queue.put(None) yield channel def _send_call( @@ -388,15 +372,14 @@ def remote_call( notice_prefix: str = 
"RPC", ) -> Any: """Make a remote procedure call over the websocket.""" - rx_queue: queue.Queue[Any] = queue.Queue() + ws = self._ws + assert ws is not None + rx_queue, getter = ws.new_rx_queue() with self._mux.assign_call_id(rx_queue) as call_id: rpc = SyncRemoteCall( - call_id, rx_queue, self._logger.event_context, notice_prefix + call_id, getter, self._logger.event_context, notice_prefix ) self._send_call(rpc, endpoint, params) - if self._ws is None: - # Link has been terminated, ensure client gets a response - rx_queue.put(None) return rpc.receive_result()