diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py
index a9d917c..34217f0 100644
--- a/src/lmstudio/_ws_impl.py
+++ b/src/lmstudio/_ws_impl.py
@@ -1,7 +1,7 @@
 """Shared core async websocket implementation for the LM Studio remote access API."""
 
-# Sync API: runs in background thread with sync queues
-# Async convenience API: runs in background thread with async queues
+# Sync API: runs in dedicated background thread
+# Async convenience API (once implemented): runs in dedicated background thread
 # Async structured API: runs in foreground event loop
 
 # Callback handling rules:
@@ -10,36 +10,37 @@
 # * All callbacks must be invoked from the *foreground* thread/event loop
 
 import asyncio
-import threading
-import weakref
 
 from concurrent.futures import Future as SyncFuture
-from contextlib import AsyncExitStack
+from contextlib import AsyncExitStack, contextmanager
 from functools import partial
 from typing import (
     Any,
     Awaitable,
     Coroutine,
     Callable,
-    Iterable,
+    Generator,
+    TypeAlias,
     TypeVar,
 )
 
-# Synchronous API still uses an async websocket (just in a background thread)
 from anyio import create_task_group, move_on_after
 from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
 
 from .schemas import DictObject
-from .json_api import LMStudioWebsocket, LMStudioWebsocketError
-
-from ._logging import new_logger, LogEventContext
+from .json_api import (
+    LMStudioWebsocket,
+    LMStudioWebsocketError,
+    MultiplexingManager,
+    RxQueue,
+)
+from ._logging import LogEventContext, new_logger
 
 # Allow the core client websocket management to be shared across all SDK interaction APIs
 # See https://discuss.python.org/t/daemon-threads-and-background-task-termination/77604
 # (Note: this implementation has the elements needed to run on *current* Python versions
 # and omits the generalised features that the SDK doesn't need)
-# Already used by the sync API, async client is still to be migrated
 
 T = TypeVar("T")
 
@@ -194,6 +195,17 @@ def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T
             raise RuntimeError(f"{self!r} is currently inactive.")
         return asyncio.run_coroutine_threadsafe(coro, loop)
 
+    def call_threadsafe(self, func: Callable[[], T]) -> SyncFuture[T]:
+        """Call non-blocking function in the background event loop and make the result available.
+
+        Important: function must NOT access any scoped resources from the calling scope.
+        """
+
+        async def coro() -> T:
+            return func()
+
+        return self.run_coroutine_threadsafe(coro())
+
     def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
         """Call given non-blocking function in the background event loop."""
         loop = self._event_loop
@@ -202,126 +214,13 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
         return loop.call_soon_threadsafe(func)
 
 
-class BackgroundThread(threading.Thread):
-    """Background async event loop thread."""
-
-    def __init__(
-        self,
-        task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
-        name: str | None = None,
-    ) -> None:
-        # Accepts the same args as `threading.Thread`, *except*:
-        # * a `task_target` coroutine replaces the `target` function
-        # * No `daemon` option (always runs as a daemon)
-        # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
-        # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
-        self._task_target = task_target
-        self._loop_started = loop_started = threading.Event()
-        self._task_manager = AsyncTaskManager(on_activation=loop_started.set)
-        # Annoyingly, we have to mark the background thread as a daemon thread to
-        # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
-        # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
-        # TODO: skip thread daemonization when running in a subinterpreter
-        # (and also disable the convenience API in subinterpreters to avoid hanging on shutdown)
-        super().__init__(name=name, daemon=True)
-        weakref.finalize(self, self.terminate)
-
-    @property
-    def task_manager(self) -> AsyncTaskManager:
-        return self._task_manager
-
-    def start(self, wait_for_loop: bool = True) -> None:
-        """Start background thread and (optionally) wait for the event loop to be ready."""
-        super().start()
-        if wait_for_loop:
-            self.wait_for_loop()
-
-    def run(self) -> None:
-        """Run an async event loop in the background thread."""
-        # Only public to override threading.Thread.run
-        asyncio.run(self._task_manager.run_until_terminated(self._task_target))
-
-    def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
-        """Wait for the event loop to start from a synchronous foreground thread."""
-        if self._task_manager._event_loop is None and not self._task_manager.activated:
-            self._loop_started.wait()
-        return self._task_manager._event_loop
-
-    async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
-        """Wait for the event loop to start from an asynchronous foreground thread."""
-        return await asyncio.to_thread(self.wait_for_loop)
-
-    def terminate(self) -> bool:
-        """Request termination of the event loop from a synchronous foreground thread."""
-        return self._task_manager.request_termination_threadsafe().result()
-
-    async def terminate_async(self) -> bool:
-        """Request termination of the event loop from an asynchronous foreground thread."""
-        return await asyncio.to_thread(self.terminate)
-
-    def schedule_background_task(self, func: Callable[[], Any]) -> None:
-        """Schedule given task in the event loop from a synchronous foreground thread."""
-        self._task_manager.schedule_task_threadsafe(func)
-
-    async def schedule_background_task_async(self, func: Callable[[], Any]) -> None:
-        """Schedule given task in the event loop from an asynchronous foreground thread."""
-        return await asyncio.to_thread(self.schedule_background_task, func)
-
-    def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run given coroutine in the event loop and wait for the result.""" - return self._task_manager.run_coroutine_threadsafe(coro).result() - - async def run_background_coroutine_async(self, coro: Coroutine[Any, Any, T]) -> T: - """Run given coroutine in the event loop and await the result.""" - return await asyncio.to_thread(self.run_background_coroutine, coro) - - def call_in_background(self, func: Callable[[], Any]) -> None: - """Call given non-blocking function in the background event loop.""" - self._task_manager.call_soon_threadsafe(func) - - -# By default, the weakref finalization atexit hook is registered lazily. -# This can lead to shutdown sequencing issues if SDK users attempt to access -# client instances (such as the default sync client) from atexit hooks -# registered at import time (so they may end up running after the weakref -# finalization hook has already terminated background threads) -# Creating this finalizer here ensures the weakref finalization hook is -# registered at import time, and hence runs *after* any such hooks -# (assuming the lmstudio SDK is imported before the hooks are registered) -def _register_weakref_atexit_hook() -> None: - class C: - pass - - weakref.finalize(C(), int) - - -_register_weakref_atexit_hook() -del _register_weakref_atexit_hook - - -class AsyncWebsocketThread(BackgroundThread): - def __init__(self, log_context: LogEventContext | None = None) -> None: - super().__init__(task_target=self._log_thread_execution) - self._logger = logger = new_logger(type(self).__name__) - logger.update_context(log_context, thread_id=self.name) - - async def _log_thread_execution(self) -> None: - self._logger.info("Websocket handling thread started") - never_set = asyncio.Event() - try: - # Run the event loop until termination is requested - await never_set.wait() - except (asyncio.CancelledError, GeneratorExit): - raise - except BaseException: - err_msg = "Terminating websocket thread due to exception" - self._logger.debug(err_msg, exc_info=True) - finally: - self._logger.info("Websocket thread terminated") +AsyncChannelInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]] +AsyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]] # TODO: Improve code sharing between AsyncWebsocketHandler and # the async-native AsyncLMStudioWebsocket implementation +# (likely by migrating the websocket over to using the handler) class AsyncWebsocketHandler: """Async task handler for a single websocket connection.""" @@ -332,7 +231,6 @@ def __init__( task_manager: AsyncTaskManager, ws_url: str, auth_details: DictObject, - enqueue_message: Callable[[DictObject | None], Awaitable[bool]], log_context: LogEventContext | None = None, ) -> None: self._auth_details = auth_details @@ -344,9 +242,9 @@ def __init__( self._ws: AsyncWebSocketSession | None = None self._ws_disconnected = asyncio.Event() self._rx_task: asyncio.Task[None] | None = None - self._enqueue_message = enqueue_message self._logger = logger = new_logger(type(self).__name__) logger.update_context(log_context) + self._mux = MultiplexingManager(logger) async def connect(self) -> bool: """Connect websocket from the task manager's event loop.""" @@ -386,7 +284,7 @@ async def _logged_ws_handler(self) -> None: err_msg = "Terminating websocket task due to exception" self._logger.debug(err_msg, exc_info=True) finally: - # Ensure the foreground thread is unblocked even if the + # Ensure connections attempt are unblocked even if the # background async task errors out completely self._connection_attempted.set() 
self._logger.info("Websocket task terminated") @@ -403,9 +301,7 @@ async def _handle_ws(self) -> None: raise def _clear_task_state() -> None: - # Break the reference cycle with the foreground thread - del self._enqueue_message - # Websocket is about to be disconnected + # Websocket is about to be disconnected (if it isn't already) self._ws = None resources.callback(_clear_task_state) @@ -423,7 +319,7 @@ def _clear_task_state() -> None: self._logger.info("Websocket demultiplexing task terminated.") # Notify foreground thread of background thread termination # (this covers termination due to link failure) - await self._enqueue_message(None) + await self.notify_client_termination() dc_timeout = self.WS_DISCONNECT_TIMEOUT with move_on_after(dc_timeout, shield=True) as cancel_scope: # Workaround an anyio/httpx-ws issue with task cancellation: @@ -447,6 +343,9 @@ async def send_json(self, message: DictObject) -> None: ws = self._ws if ws is None: # Assume app is shutting down and the owning task has already been cancelled + rx_queue = self._mux.map_tx_message(message) + if rx_queue is not None: + await rx_queue.put(None) return try: await ws.send_json(message) @@ -464,9 +363,47 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: """Run given coroutine in the event loop and wait for the result.""" return self._task_manager.run_coroutine_threadsafe(coro).result() - def rx_queue_get_threadsafe( - self, rx_queue: asyncio.Queue[Any], timeout: float | None - ) -> Any: + @contextmanager + def open_channel(self) -> Generator[AsyncChannelInfo, None, None]: + assert self._task_manager.check_running_in_task_loop() + rx_queue: RxQueue = asyncio.Queue() + with self._mux.assign_channel_id(rx_queue) as call_id: + yield call_id, rx_queue.get + + @contextmanager + def start_call(self) -> Generator[AsyncRemoteCallInfo, None, None]: + assert self._task_manager.check_running_in_task_loop() + rx_queue: RxQueue = asyncio.Queue() + with self._mux.assign_call_id(rx_queue) as call_id: + yield call_id, rx_queue.get + + def new_threadsafe_rx_queue(self) -> tuple[RxQueue, Callable[[float | None], Any]]: + rx_queue: RxQueue = asyncio.Queue() + return rx_queue, partial(self._rx_queue_get_threadsafe, rx_queue) + + def acquire_channel_id_threadsafe(self, rx_queue: RxQueue) -> int: + future = self._task_manager.call_threadsafe( + partial(self._mux.acquire_channel_id, rx_queue) + ) + return future.result() # Wait for background thread to assign the ID + + def release_channel_id_threadsafe(self, channel_id: int, rx_queue: RxQueue) -> None: + self._task_manager.call_soon_threadsafe( + partial(self._mux.release_channel_id, channel_id, rx_queue) + ) + + def acquire_call_id_threadsafe(self, rx_queue: RxQueue) -> int: + future = self._task_manager.call_threadsafe( + partial(self._mux.acquire_call_id, rx_queue) + ) + return future.result() # Wait for background thread to assign the ID + + def release_call_id_threadsafe(self, call_id: int, rx_queue: RxQueue) -> None: + self._task_manager.call_soon_threadsafe( + partial(self._mux.release_call_id, call_id, rx_queue) + ) + + def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) -> Any: future = self._task_manager.run_coroutine_threadsafe(rx_queue.get()) try: return future.result(timeout) @@ -474,12 +411,6 @@ def rx_queue_get_threadsafe( future.cancel() raise - def rx_queue_put_threadsafe( - self, rx_queue: asyncio.Queue[Any], message: Any - ) -> None: - future = self._task_manager.run_coroutine_threadsafe(rx_queue.put(message)) - return 
future.result() - async def _receive_json(self) -> Any: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() @@ -536,46 +467,12 @@ async def _receive_messages(self) -> None: self._logger.error("Websocket failed, terminating session.") break - -class SyncToAsyncWebsocketBridge: - def __init__( - self, - ws_thread: AsyncWebsocketThread, - ws_url: str, - auth_details: DictObject, - get_queue: Callable[[DictObject | None], asyncio.Queue[Any] | None], - iter_queues: Callable[[], Iterable[asyncio.Queue[Any]]], - log_context: LogEventContext, - ) -> None: - self._get_queue = get_queue - self._iter_queues = iter_queues - self._ws_handler = AsyncWebsocketHandler( - ws_thread.task_manager, - ws_url, - auth_details, - self._enqueue_message, - log_context, - ) - self._logger = logger = new_logger(type(self).__name__) - logger.update_context(log_context) - - def connect(self) -> bool: - return self._ws_handler.connect_threadsafe() - - def disconnect(self) -> None: - self._ws_handler.disconnect_threadsafe() - - def send_json(self, message: DictObject) -> None: - self._ws_handler.send_json_threadsafe(message) - - def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[float | None], Any]]: - rx_queue: asyncio.Queue[Any] = asyncio.Queue() - return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue) - async def _enqueue_message(self, message: Any) -> bool: - rx_queue = self._get_queue(message) if message is None: + self._logger.info(f"Websocket session failed ({self._ws_url})") + self._ws = None return await self.notify_client_termination() > 0 + rx_queue = self._mux.map_rx_message(message) if rx_queue is None: return False await rx_queue.put(message) @@ -584,7 +481,7 @@ async def _enqueue_message(self, message: Any) -> bool: async def notify_client_termination(self) -> int: """Send None to all clients with open receive queues (from background thread).""" num_clients = 0 - for rx_queue in self._iter_queues(): + for rx_queue in self._mux.all_queues(): await rx_queue.put(None) num_clients += 1 self._logger.debug( @@ -595,19 +492,4 @@ async def notify_client_termination(self) -> int: def notify_client_termination_threadsafe(self) -> int: """Send None to all clients with open receive queues (from foreground thread).""" - return self._ws_handler.run_background_coroutine( - self.notify_client_termination() - ) - - # These attributes are currently accessed directly... 
-    @property
-    def _ws(self) -> AsyncWebSocketSession | None:
-        return self._ws_handler._ws
-
-    @property
-    def _connection_failure(self) -> Exception | None:
-        return self._ws_handler._connection_failure
-
-    @property
-    def _auth_failure(self) -> Any | None:
-        return self._ws_handler._auth_failure
+        return self.run_background_coroutine(self.notify_client_termination())
diff --git a/src/lmstudio/_ws_thread.py b/src/lmstudio/_ws_thread.py
new file mode 100644
index 0000000..acfed06
--- /dev/null
+++ b/src/lmstudio/_ws_thread.py
@@ -0,0 +1,213 @@
+"""Background thread async websocket implementation for the LM Studio remote access API."""
+
+# Sync API
+# Async convenience API (once implemented)
+
+import asyncio
+import threading
+import weakref
+
+from contextlib import contextmanager
+from typing import (
+    Any,
+    Coroutine,
+    Callable,
+    Generator,
+    TypeAlias,
+    TypeVar,
+)
+
+from httpx_ws import AsyncWebSocketSession
+
+from .schemas import DictObject
+
+from ._logging import new_logger, LogEventContext
+from ._ws_impl import AsyncTaskManager, AsyncWebsocketHandler
+
+T = TypeVar("T")
+
+
+class BackgroundThread(threading.Thread):
+    """Background async event loop thread."""
+
+    def __init__(
+        self,
+        task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
+        name: str | None = None,
+    ) -> None:
+        # Accepts the same args as `threading.Thread`, *except*:
+        # * a `task_target` coroutine replaces the `target` function
+        # * No `daemon` option (always runs as a daemon)
+        # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
+        # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
+        self._task_target = task_target
+        self._loop_started = loop_started = threading.Event()
+        self._task_manager = AsyncTaskManager(on_activation=loop_started.set)
+        # Annoyingly, we have to mark the background thread as a daemon thread to
+        # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
+        # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
+        # TODO: skip thread daemonization when running in a subinterpreter
+        # (and also disable the convenience API in subinterpreters to avoid hanging on shutdown)
+        super().__init__(name=name, daemon=True)
+        weakref.finalize(self, self.terminate)
+
+    @property
+    def task_manager(self) -> AsyncTaskManager:
+        return self._task_manager
+
+    def start(self, wait_for_loop: bool = True) -> None:
+        """Start background thread and (optionally) wait for the event loop to be ready."""
+        super().start()
+        if wait_for_loop:
+            self.wait_for_loop()
+
+    def run(self) -> None:
+        """Run an async event loop in the background thread."""
+        # Only public to override threading.Thread.run
+        asyncio.run(self._task_manager.run_until_terminated(self._task_target))
+
+    def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
+        """Wait for the event loop to start from a synchronous foreground thread."""
+        if self._task_manager._event_loop is None and not self._task_manager.activated:
+            self._loop_started.wait()
+        return self._task_manager._event_loop
+
+    async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
+        """Wait for the event loop to start from an asynchronous foreground thread."""
+        return await asyncio.to_thread(self.wait_for_loop)
+
+    def terminate(self) -> bool:
+        """Request termination of the event loop from a synchronous foreground thread."""
+        return self._task_manager.request_termination_threadsafe().result()
+
+    async def terminate_async(self) -> bool:
+        """Request termination of the event loop from an asynchronous foreground thread."""
+        return await asyncio.to_thread(self.terminate)
+
+    def schedule_background_task(self, func: Callable[[], Any]) -> None:
+        """Schedule given task in the event loop from a synchronous foreground thread."""
+        self._task_manager.schedule_task_threadsafe(func)
+
+    async def schedule_background_task_async(self, func: Callable[[], Any]) -> None:
+        """Schedule given task in the event loop from an asynchronous foreground thread."""
+        return await asyncio.to_thread(self.schedule_background_task, func)
+
+    def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
+        """Run given coroutine in the event loop and wait for the result."""
+        return self._task_manager.run_coroutine_threadsafe(coro).result()
+
+    async def run_background_coroutine_async(self, coro: Coroutine[Any, Any, T]) -> T:
+        """Run given coroutine in the event loop and await the result."""
+        return await asyncio.to_thread(self.run_background_coroutine, coro)
+
+    def call_in_background(self, func: Callable[[], Any]) -> None:
+        """Call given non-blocking function in the background event loop."""
+        self._task_manager.call_soon_threadsafe(func)
+
+
+# By default, the weakref finalization atexit hook is registered lazily.
+# This can lead to shutdown sequencing issues if SDK users attempt to access
+# client instances (such as the default sync client) from atexit hooks
+# registered at import time (so they may end up running after the weakref
+# finalization hook has already terminated background threads)
+# Creating this finalizer here ensures the weakref finalization hook is
+# registered at import time, and hence runs *after* any such hooks
+# (assuming the lmstudio SDK is imported before the hooks are registered)
+def _register_weakref_atexit_hook() -> None:
+    class C:
+        pass
+
+    weakref.finalize(C(), int)
+
+
+_register_weakref_atexit_hook()
+del _register_weakref_atexit_hook
+
+
+class AsyncWebsocketThread(BackgroundThread):
+    def __init__(self, log_context: LogEventContext | None = None) -> None:
+        super().__init__(task_target=self._log_thread_execution)
+        self._logger = logger = new_logger(type(self).__name__)
+        logger.update_context(log_context, thread_id=self.name)
+
+    async def _log_thread_execution(self) -> None:
+        self._logger.info("Websocket handling thread started")
+        never_set = asyncio.Event()
+        try:
+            # Run the event loop until termination is requested
+            await never_set.wait()
+        except (asyncio.CancelledError, GeneratorExit):
+            raise
+        except BaseException:
+            err_msg = "Terminating websocket thread due to exception"
+            self._logger.debug(err_msg, exc_info=True)
+        finally:
+            self._logger.info("Websocket thread terminated")
+
+
+SyncChannelInfo: TypeAlias = tuple[int, Callable[[float | None], Any]]
+SyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[float | None], Any]]
+
+
+class SyncToAsyncWebsocketBridge:
+    def __init__(
+        self,
+        ws_thread: AsyncWebsocketThread,
+        ws_url: str,
+        auth_details: DictObject,
+        log_context: LogEventContext,
+    ) -> None:
+        self._ws_handler = AsyncWebsocketHandler(
+            ws_thread.task_manager,
+            ws_url,
+            auth_details,
+            log_context,
+        )
+        self._logger = logger = new_logger(type(self).__name__)
+        logger.update_context(log_context)
+
+    def connect(self) -> bool:
+        return self._ws_handler.connect_threadsafe()
+
+    def disconnect(self) -> None:
+        self._ws_handler.disconnect_threadsafe()
+
+    def send_json(self, message: DictObject) -> None:
+        self._ws_handler.send_json_threadsafe(message)
+
+    @contextmanager
+    def open_channel(self) -> Generator[SyncChannelInfo, None, None]:
+        ws_handler = self._ws_handler
+        rx_queue, getter = ws_handler.new_threadsafe_rx_queue()
+        channel_id = ws_handler.acquire_channel_id_threadsafe(rx_queue)
+        try:
+            yield channel_id, getter
+        finally:
+            ws_handler.release_channel_id_threadsafe(channel_id, rx_queue)
+
+    @contextmanager
+    def start_call(self) -> Generator[SyncRemoteCallInfo, None, None]:
+        ws_handler = self._ws_handler
+        rx_queue, getter = ws_handler.new_threadsafe_rx_queue()
+        call_id = ws_handler.acquire_call_id_threadsafe(rx_queue)
+        try:
+            yield call_id, getter
+        finally:
+            ws_handler.release_call_id_threadsafe(call_id, rx_queue)
+
+    def notify_client_termination_threadsafe(self) -> int:
+        """Send None to all clients with open receive queues (from foreground thread)."""
+        return self._ws_handler.notify_client_termination_threadsafe()
+
+    # These attributes are currently accessed directly...
+    @property
+    def _ws(self) -> AsyncWebSocketSession | None:
+        return self._ws_handler._ws
+
+    @property
+    def _connection_failure(self) -> Exception | None:
+        return self._ws_handler._connection_failure
+
+    @property
+    def _auth_failure(self) -> Any | None:
+        return self._ws_handler._auth_failure
diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py
index 775e9ee..09c7359 100644
--- a/src/lmstudio/async_api.py
+++ b/src/lmstudio/async_api.py
@@ -77,6 +77,7 @@
     ModelSessionTypes,
     ModelTypesEmbedding,
     ModelTypesLlm,
+    MultiplexingManager,  # Temporary until migration to AsyncWebsocketHandler
     PredictionStreamBase,
     PredictionEndpoint,
     PredictionFirstTokenCallback,
@@ -87,6 +88,7 @@
     PromptProcessingCallback,
     RemoteCallHandler,
     ResponseSchema,
+    RxQueue,
    TModelInfo,
     check_model_namespace,
     load_struct,
@@ -133,7 +135,7 @@ class AsyncChannel(Generic[T]):
     def __init__(
         self,
         channel_id: int,
-        rx_queue: asyncio.Queue[Any],
+        rx_queue: RxQueue,
         endpoint: ChannelEndpoint[T, Any, Any],
         send_json: Callable[[DictObject], Awaitable[None]],
         log_context: LogEventContext,
@@ -170,9 +172,8 @@ async def rx_stream(
             # (we can't easily suppress the SDK's own frames for iterators)
             message = await self._rx_queue.get()
             if message is None:
-                contents = None
-            else:
-                contents = self._api_channel.handle_rx_message(message)
+                raise LMStudioRuntimeError("Client unexpectedly disconnected.")
+            contents = self._api_channel.handle_rx_message(message)
             if contents is None:
                 self._is_finished = True
                 break
@@ -194,7 +195,7 @@ class AsyncRemoteCall:
     def __init__(
         self,
         call_id: int,
-        rx_queue: asyncio.Queue[Any],
+        rx_queue: RxQueue,
         log_context: LogEventContext,
         notice_prefix: str = "RPC",
     ) -> None:
@@ -214,13 +215,11 @@ async def receive_result(self) -> Any:
         """Receive call response on the receive queue."""
         message = await self._rx_queue.get()
         if message is None:
-            return None
+            raise LMStudioRuntimeError("Client unexpectedly disconnected.")
         return self._rpc.handle_rx_message(message)
 
 
-class AsyncLMStudioWebsocket(
-    LMStudioWebsocket[AsyncWebSocketSession, asyncio.Queue[Any]]
-):
+class AsyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebSocketSession]):
     """Asynchronous websocket client that handles demultiplexing of reply messages."""
 
     def __init__(
@@ -235,6 +234,7 @@ def __init__(
         rm.push_async_callback(self._notify_client_termination)
         self._rx_task: asyncio.Task[None] | None = None
         self._terminate = asyncio.Event()
+        self._mux = MultiplexingManager(self._logger)
 
     @property
     def _httpx_ws(self) -> AsyncWebSocketSession | None:
@@ -386,7 +386,7 @@ async def open_channel(
         endpoint: ChannelEndpoint[T, Any, Any],
     ) -> AsyncGenerator[AsyncChannel[T], None]:
         """Open a streaming channel over the websocket."""
-        rx_queue: asyncio.Queue[Any] = asyncio.Queue()
+        rx_queue: RxQueue = asyncio.Queue()
         with self._mux.assign_channel_id(rx_queue) as channel_id:
             channel = AsyncChannel(
                 channel_id,
@@ -427,7 +427,7 @@ async def remote_call(
         notice_prefix: str = "RPC",
     ) -> Any:
         """Make a remote procedure call over the websocket."""
-        rx_queue: asyncio.Queue[Any] = asyncio.Queue()
+        rx_queue: RxQueue = asyncio.Queue()
         with self._mux.assign_call_id(rx_queue) as call_id:
             rpc = AsyncRemoteCall(
                 call_id, rx_queue, self._logger.event_context, notice_prefix
diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index 5eb90a2..39f9d02 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -1,6 +1,15 @@
-"""Sans I/O protocol implementation for the LM Studio remote access API."""
+"""Common protocol implementation for the LM Studio remote access API."""
 
-# TODO: Migrate additional protocol details from the [a]sync APIs to the sans I/O API
+# In order to simplify the websocket demultiplexing logic, this is NOT
+# a full sans I/O protocol implementation. Instead, it is an async
+# protocol implementation that supports both async interaction
+# (from the same event loop or from one running in another thread)
+# *and* sync interaction (by blocking on threadsafe futures)
+#
+# The I/O *transport* layer is still abstracted out, but the internal
+# use of asynchronous queues for message demultiplexing is assumed.
+
+import asyncio
 import copy
 import json
 import uuid
@@ -123,7 +132,7 @@
 # From here, we publish everything that might be needed
 # for API type hints, error handling, defining custom
 # structured responses, and other expected activities.
-# The "sans I/O" API itself is *not* automatically exported.
+# The shared API itself is *not* automatically exported.
 # If API consumers want to use that, they need to access it
 # explicitly via `lmstudio.json_api`, it isn't exported
 # implicitly as part of the top-level `lmstudio` API.
@@ -532,25 +541,23 @@ def _redact_json(data: DictObject | None) -> DictObject | None:
     return redacted
 
 
-# TODO: Now that even the sync API uses asyncio.Queue,
-# change the multiplexing manager to no longer be generic
-TQueue = TypeVar("TQueue")
+RxQueue: TypeAlias = asyncio.Queue[Any]
 
 
-class MultiplexingManager(Generic[TQueue]):
+class MultiplexingManager:
     """Helper class to allocate distinct protocol multiplexing IDs."""
 
     def __init__(self, logger: StructuredLogger) -> None:
         """Initialize ID multiplexer."""
-        self._open_channels: dict[int, TQueue] = {}
+        self._open_channels: dict[int, RxQueue] = {}
         self._last_channel_id = 0
-        self._pending_calls: dict[int, TQueue] = {}
+        self._pending_calls: dict[int, RxQueue] = {}
         self._last_call_id = 0
         # `_active_subscriptions` (if we add signal support)
         # `_last_subscriber_id` (if we add signal support)
         self._logger = logger
 
-    def all_queues(self) -> Iterator[TQueue]:
+    def all_queues(self) -> Iterator[asyncio.Queue[Any]]:
         """Iterate over all queues (for example, to send a shutdown message)."""
         yield from self._open_channels.values()
         yield from self._pending_calls.values()
@@ -562,18 +569,30 @@ def _get_next_channel_id(self) -> int:
         self._last_channel_id = next_id
         return next_id
 
-    @contextmanager
-    def assign_channel_id(self, rx_queue: TQueue) -> Generator[int, None, None]:
-        """Assign distinct streaming channel ID to given queue."""
+    def acquire_channel_id(self, rx_queue: RxQueue) -> int:
+        """Acquire a distinct streaming channel ID for the given queue."""
         channel_id = self._get_next_channel_id()
         self._open_channels[channel_id] = rx_queue
+        return channel_id
+
+    def release_channel_id(self, channel_id: int, rx_queue: RxQueue) -> None:
+        """Release a previously acquired streaming channel ID."""
+        open_channels = self._open_channels
+        assigned_queue = open_channels.get(channel_id)
+        if rx_queue is not assigned_queue:
+            raise LMStudioRuntimeError(
+                f"Unexpected change to reply queue for channel ({channel_id} in {self!r})"
+            )
+        del open_channels[channel_id]
+
+    @contextmanager
+    def assign_channel_id(self, rx_queue: RxQueue) -> Generator[int, None, None]:
+        """Assign distinct streaming channel ID to given queue."""
+        channel_id = self.acquire_channel_id(rx_queue)
         try:
             yield channel_id
         finally:
-            dropped_queue = self._open_channels.pop(channel_id, None)
-            assert dropped_queue is rx_queue, (
-                f"Unexpected change to reply queue for channel ({channel_id} in {self!r})"
-            )
+            self.release_channel_id(channel_id, rx_queue)
 
     def _get_next_call_id(self) -> int:
         """Get next distinct RPC ID."""
@@ -581,24 +600,36 @@ def _get_next_call_id(self) -> int:
         self._last_call_id = next_id
         return next_id
 
-    @contextmanager
-    def assign_call_id(self, rx_queue: TQueue) -> Generator[int, None, None]:
-        """Assign distinct remote call ID to given queue."""
+    def acquire_call_id(self, rx_queue: RxQueue) -> int:
+        """Acquire a distinct remote call ID for the given queue."""
         call_id = self._get_next_call_id()
         self._pending_calls[call_id] = rx_queue
+        return call_id
+
+    def release_call_id(self, call_id: int, rx_queue: RxQueue) -> None:
+        """Release a previously acquired remote call ID."""
+        pending_calls = self._pending_calls
+        assigned_queue = pending_calls.get(call_id)
+        if rx_queue is not assigned_queue:
+            raise LMStudioRuntimeError(
+                f"Unexpected change to reply queue for remote call ({call_id} in {self!r})"
+            )
+        del pending_calls[call_id]
+
+    @contextmanager
+    def assign_call_id(self, rx_queue: RxQueue) -> Generator[int, None, None]:
+        """Assign distinct remote call ID to given queue."""
+        call_id = self.acquire_call_id(rx_queue)
         try:
             yield call_id
         finally:
-            dropped_queue = self._pending_calls.pop(call_id, None)
-            assert dropped_queue is rx_queue, (
-                f"Unexpected change to reply queue for remote call ({call_id} in {self!r})"
-            )
+            self.release_call_id(call_id, rx_queue)
 
-    def map_rx_message(self, message: DictObject) -> TQueue | None:
+    def map_rx_message(self, message: DictObject) -> RxQueue | None:
         """Map received message to the relevant demultiplexing queue."""
         # TODO: Define an even-spammier-than-debug trace logging level for this
         # self._logger.trace("Incoming websocket message", json=message)
-        rx_queue: TQueue | None = None
+        rx_queue: RxQueue | None = None
         match message:
             case {"channelId": channel_id}:
                 rx_queue = self._open_channels.get(channel_id, None)
@@ -631,6 +662,37 @@ def map_rx_message(self, message: DictObject) -> TQueue | None:
             raise LMStudioClientError(f"Unexpected message: {unmatched}")
         return rx_queue
 
+    def map_tx_message(self, message: DictObject) -> RxQueue | None:
+        """Map failed message transmission to the relevant demultiplexing queue."""
+        # TODO: Define an even-spammier-than-debug trace logging level for this
+        # self._logger.trace("Failed to send websocket message", json=message)
+        rx_queue: RxQueue | None = None
+        match message:
+            case {"channelId": channel_id}:
+                rx_queue = self._open_channels.get(channel_id, None)
+                if rx_queue is None:
+                    if channel_id <= self._last_channel_id:
+                        self._logger.warn(
+                            "Attempted to send message on already closed channel",
+                            channel_id=channel_id,
+                        )
+                    else:
+                        self._logger.warn(
+                            "Attempted to send message on not yet used channel",
+                            channel_id=channel_id,
+                        )
+            case {"callId": call_id}:
+                rx_queue = self._pending_calls.get(call_id, None)
+                if rx_queue is None:
+                    self._logger.warn(
+                        "Attempted to send remote call with unknown ID", call_id=call_id
+                    )
+            case _:
+                self._logger.warn(
+                    "Attempted to send top level message on closed session"
+                )
+        return rx_queue
+
 
 # Channel events are processed via structural pattern matching, so it would be nice
 # to define them as tuples to make them as lightweight as possible at runtime.
@@ -1772,7 +1834,7 @@ def _format_exc(exc: Exception) -> str:
     return exc_name
 
 
-class LMStudioWebsocket(Generic[TWebsocket, TQueue]):
+class LMStudioWebsocket(Generic[TWebsocket]):
     """Common base class for LM Studio websocket clients."""
 
     # The common websocket API is narrow due to the sync/async split,
@@ -1780,8 +1842,6 @@ class LMStudioWebsocket(Generic[TWebsocket, TQueue]):
 
     # Subclasses will declare a specific underlying websocket type
     _ws: TWebsocket | None
-    # Subclasses will declare a specific receive queue type
-    _mux: MultiplexingManager[TQueue]
 
     def __init__(
         self,
@@ -1794,7 +1854,6 @@ def __init__(
         self._auth_details = auth_details
         self._logger = logger = new_logger(type(self).__name__)
         logger.update_context(log_context, ws_url=ws_url)
-        self._mux = MultiplexingManager(logger)
         # Subclasses handle actually creating a websocket instance
         self._ws = None
 
@@ -1842,7 +1901,7 @@ def _ensure_connected(self, usage: str) -> None | NoReturn:
         return None
 
 
-TLMStudioWebsocket = TypeVar("TLMStudioWebsocket", bound=LMStudioWebsocket[Any, Any])
+TLMStudioWebsocket = TypeVar("TLMStudioWebsocket", bound=LMStudioWebsocket[Any])
 
 
 class ClientBase:
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index ee11ef7..cfe6c40 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -1,6 +1,5 @@
 """Sync I/O protocol implementation for the LM Studio remote access API."""
 
-import asyncio
 import itertools
 import time
 import weakref
@@ -110,7 +109,7 @@
     _model_spec_to_api_dict,
     _redact_json,
 )
-from ._ws_impl import AsyncWebsocketThread, SyncToAsyncWebsocketBridge
+from ._ws_thread import AsyncWebsocketThread, SyncToAsyncWebsocketBridge
 from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
 from ._sdk_models import (
     EmbeddingRpcCountTokensParameter,
@@ -230,6 +229,8 @@ def rx_stream(
                 message = self._get_message(self.timeout)
             except TimeoutError:
                 raise LMStudioTimeoutError from None
+            if message is None:
+                raise LMStudioRuntimeError("Client unexpectedly disconnected.")
             contents = self._api_channel.handle_rx_message(message)
             if contents is None:
                 self._is_finished = True
@@ -284,12 +285,12 @@ def receive_result(self) -> Any:
             message = self._get_message(self.timeout)
         except TimeoutError:
             raise LMStudioTimeoutError from None
+        if message is None:
+            raise LMStudioRuntimeError("Client unexpectedly disconnected.")
         return self._rpc.handle_rx_message(message)
 
 
-class SyncLMStudioWebsocket(
-    LMStudioWebsocket[SyncToAsyncWebsocketBridge, asyncio.Queue[Any]]
-):
+class SyncLMStudioWebsocket(LMStudioWebsocket[SyncToAsyncWebsocketBridge]):
     """Synchronous websocket client that handles demultiplexing of reply messages."""
 
     def __init__(
@@ -328,8 +329,6 @@ def connect(self) -> Self:
             self._ws_thread,
             self._ws_url,
             self._auth_details,
-            self._get_rx_queue,
-            self._mux.all_queues,
             self._logger.event_context,
         )
         if not ws.connect():
@@ -361,13 +360,6 @@ def _send_json(self, message: DictObject) -> None:
             # Background thread handles the exception conversion
             ws.send_json(message)
 
-    def _get_rx_queue(self, message: Any) -> asyncio.Queue[Any] | None:
-        if message is None:
-            self._logger.info(f"Websocket session failed ({self._ws_url})")
-            self._ws = None
-            return None
-        return self._mux.map_rx_message(message)
-
     def _connect_to_endpoint(self, channel: SyncChannel[Any]) -> None:
         """Connect channel to specified endpoint."""
         self._ensure_connected("open channel endpoints")
@@ -383,8 +375,7 @@ def open_channel(
         """Open a streaming channel over the websocket."""
         ws = self._ws
         assert ws is not None
-        rx_queue, getter = ws.new_rx_queue()
-        with self._mux.assign_channel_id(rx_queue) as channel_id:
+        with ws.open_channel() as (channel_id, getter):
             channel = SyncChannel(
                 channel_id,
                 getter,
@@ -423,8 +414,7 @@ def remote_call(
         """Make a remote procedure call over the websocket."""
         ws = self._ws
         assert ws is not None
-        rx_queue, getter = ws.new_rx_queue()
-        with self._mux.assign_call_id(rx_queue) as call_id:
+        with ws.start_call() as (call_id, getter):
             rpc = SyncRemoteCall(
                 call_id, getter, self._logger.event_context, notice_prefix
             )
diff --git a/tests/test_sessions.py b/tests/test_sessions.py
index a592437..6e046ea 100644
--- a/tests/test_sessions.py
+++ b/tests/test_sessions.py
@@ -21,7 +21,7 @@
     SyncSession,
     SyncSessionSystem,
 )
-from lmstudio._ws_impl import AsyncWebsocketThread
+from lmstudio._ws_thread import AsyncWebsocketThread
 
 from .support import LOCAL_API_HOST
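Reviewer's note, not part of the patch: a minimal sketch of the acquire/release contract that MultiplexingManager now exposes alongside its assign_*_id() context managers. It assumes the lmstudio SDK at this revision is importable and that new_logger() from lmstudio._logging returns the StructuredLogger the manager expects; the "mux-demo" logger name and the example reply payload are illustrative only.

# Sketch only: exercises the MultiplexingManager API split out in this diff.
# Assumes the lmstudio SDK at this revision is installed and importable.
import asyncio

from lmstudio._logging import new_logger
from lmstudio.json_api import MultiplexingManager, RxQueue


async def demo() -> None:
    mux = MultiplexingManager(new_logger("mux-demo"))
    rx_queue: RxQueue = asyncio.Queue()

    # Explicit form: what acquire_channel_id_threadsafe()/release_channel_id_threadsafe()
    # schedule on the background event loop on behalf of the sync bridge.
    channel_id = mux.acquire_channel_id(rx_queue)
    assert list(mux.all_queues()) == [rx_queue]

    # Replies carrying the matching channelId are routed back to the queue
    # (the `case {"channelId": ...}` branch of map_rx_message).
    reply = {"channelId": channel_id, "message": {"type": "example"}}
    target = mux.map_rx_message(reply)
    assert target is rx_queue
    await target.put(reply)
    print(await rx_queue.get())

    mux.release_channel_id(channel_id, rx_queue)

    # Context manager form: still used directly by the async-native client.
    with mux.assign_call_id(rx_queue) as call_id:
        print(f"Pending call {call_id} registered")


asyncio.run(demo())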