diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 8fa8bb3..1bb19ae 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -29,7 +29,7 @@ ) # Synchronous API still uses an async websocket (just in a background thread) -from anyio import create_task_group +from anyio import create_task_group, get_cancelled_exc_class from exceptiongroup import suppress from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException @@ -47,43 +47,11 @@ T = TypeVar("T") -class BackgroundThread(threading.Thread): - """Background async event loop thread.""" - - def __init__( - self, - task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None, - name: str | None = None, - ) -> None: - # Accepts the same args as `threading.Thread`, *except*: - # * a `task_target` coroutine replaces the `target` function - # * No `daemon` option (always runs as a daemon) - # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run` - # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run` - self._task_target = task_target - self._loop_started = threading.Event() - self._terminate = asyncio.Event() - self._event_loop: asyncio.AbstractEventLoop | None = None - # Annoyingly, we have to mark the background thread as a daemon thread to - # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate - # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695 - super().__init__(name=name, daemon=True) - weakref.finalize(self, self.terminate) - - def run(self) -> None: - """Run an async event loop in the background thread.""" - # Only public to override threading.Thread.run - asyncio.run(self._run_until_terminated()) - - def wait_for_loop(self) -> asyncio.AbstractEventLoop | None: - """Wait for the event loop to start from a synchronous foreground thread.""" - if self._event_loop is None and not self._loop_started.is_set(): - self._loop_started.wait() - return self._event_loop - - async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None: - """Wait for the event loop to start from an asynchronous foreground thread.""" - return await asyncio.to_thread(self.wait_for_loop) +class _BackgroundTaskHandlerMixin: + # Subclasses need to handle providing these instance attributes + _event_loop: asyncio.AbstractEventLoop | None + _task_target: Callable[[], Coroutine[Any, Any, Any]] | None + _terminate: asyncio.Event def called_in_background_loop(self) -> bool: """Returns true if currently running in this thread's event loop, false otherwise.""" @@ -123,10 +91,12 @@ async def terminate_async(self) -> bool: """Request termination of the event loop from an asynchronous foreground thread.""" return await asyncio.to_thread(self.terminate) + def _init_event_loop(self) -> None: + self._event_loop = asyncio.get_running_loop() + async def _run_until_terminated(self) -> None: """Run task in the background thread until termination is requested.""" - self._event_loop = asyncio.get_running_loop() - self._loop_started.set() + self._init_event_loop() # Use anyio and exceptiongroup to handle the lack of native task # and exception groups prior to Python 3.11 raise_on_termination, terminated_exc = self._raise_on_termination() @@ -163,6 +133,49 @@ def schedule_background_task(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T assert loop is not None return asyncio.run_coroutine_threadsafe(coro, loop) + +class BackgroundThread(_BackgroundTaskHandlerMixin, threading.Thread): + """Background async event loop thread.""" + + def __init__( + self, + task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None, + name: str | None = None, + ) -> None: + # Accepts the same args as `threading.Thread`, *except*: + # * a `task_target` coroutine replaces the `target` function + # * No `daemon` option (always runs as a daemon) + # Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run` + # Alternative: accept a `task_runner` callback, defaulting to `asyncio.run` + self._task_target = task_target + self._loop_started = threading.Event() + self._terminate = asyncio.Event() + self._event_loop: asyncio.AbstractEventLoop | None = None + # Annoyingly, we have to mark the background thread as a daemon thread to + # prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate + # https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695 + super().__init__(name=name, daemon=True) + weakref.finalize(self, self.terminate) + + def run(self) -> None: + """Run an async event loop in the background thread.""" + # Only public to override threading.Thread.run + asyncio.run(self._run_until_terminated()) + + def _init_event_loop(self) -> None: + super()._init_event_loop() + self._loop_started.set() + + def wait_for_loop(self) -> asyncio.AbstractEventLoop | None: + """Wait for the event loop to start from a synchronous foreground thread.""" + if self._event_loop is None and not self._loop_started.is_set(): + self._loop_started.wait() + return self._event_loop + + async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None: + """Wait for the event loop to start from an asynchronous foreground thread.""" + return await asyncio.to_thread(self.wait_for_loop) + def run_background_task(self, coro: Coroutine[Any, Any, T]) -> T: """Run given coroutine in the background event loop and wait for the result.""" return self.schedule_background_task(coro).result() @@ -178,62 +191,83 @@ def call_in_background(self, callback: Callable[[], Any]) -> None: loop.call_soon_threadsafe(callback) -# TODO: Allow multiple websockets to share a single event loop thread -# (reduces thread usage in sync API, blocker for async API migration) class AsyncWebsocketThread(BackgroundThread): + def __init__(self, log_context: LogEventContext | None = None) -> None: + super().__init__(task_target=self._run_main_task) + self._logger = logger = get_logger(type(self).__name__) + logger.update_context(log_context, thread_id=self.name) + + async def _run_main_task(self) -> None: + self._logger.info("Websocket handling thread started") + never_set = asyncio.Event() + try: + # Run the event loop until termination is requested + await never_set.wait() + except get_cancelled_exc_class(): + pass + except BaseException: + err_msg = "Terminating websocket thread due to exception" + self._logger.debug(err_msg, exc_info=True) + self._logger.info("Websocket thread terminated") + + +# TODO: Improve code sharing between AsyncWebsocketHandler and +# the async-native AsyncLMStudioWebsocket implementation +class AsyncWebsocketHandler(_BackgroundTaskHandlerMixin): + """Async task handler for a single websocket connection.""" + def __init__( self, + ws_thread: AsyncWebsocketThread, ws_url: str, auth_details: DictObject, enqueue_message: Callable[[DictObject], bool], - log_context: LogEventContext, + log_context: LogEventContext | None = None, ) -> None: self._auth_details = auth_details self._connection_attempted = asyncio.Event() self._connection_failure: Exception | None = None self._auth_failure: Any | None = None self._terminate = asyncio.Event() + self._ws_thread = ws_thread self._ws_url = ws_url self._ws: AsyncWebSocketSession | None = None self._rx_task: asyncio.Task[None] | None = None self._queue_message = enqueue_message - super().__init__(task_target=self._run_main_task) + self._logger = get_logger(type(self).__name__) self._logger = logger = get_logger(type(self).__name__) - logger.update_context(log_context, thread_id=self.name) + logger.update_context(log_context) def connect(self) -> bool: - if not self.is_alive(): - self.start() - loop = self.wait_for_loop() # Block until connection has been attempted + ws_thread = self._ws_thread + if not ws_thread.is_alive(): + raise RuntimeError("Websocket handling thread has failed unexpectedly") + loop = ws_thread.wait_for_loop() # Block until loop is available if loop is None: - return False + raise RuntimeError("Websocket handling thread has no event loop") + ws_thread.schedule_background_task(self._run_until_terminated()) asyncio.run_coroutine_threadsafe( self._connection_attempted.wait(), loop ).result() return self._ws is not None - def disconnect(self) -> None: - if self._ws is not None: - self.terminate() - # Ensure thread has terminated - self.join() - - async def _run_main_task(self) -> None: - self._logger.info("Websocket thread started") + async def _task_target(self) -> None: + self._logger.info("Websocket handling task started") + self._init_event_loop() try: - await self._main_task() + await self._handle_ws() + except get_cancelled_exc_class(): + pass except BaseException: - err_msg = "Terminating websocket thread due to exception" + err_msg = "Terminating websocket task due to exception" self._logger.debug(err_msg, exc_info=True) finally: # Ensure the foreground thread is unblocked even if the # background async task errors out completely self._connection_attempted.set() - self._logger.info("Websocket thread terminated") + self._logger.info("Websocket task terminated") - # TODO: Improve code sharing between this background thread async websocket - # and the async-native AsyncLMStudioWebsocket implementation - async def _main_task(self) -> None: + async def _handle_ws(self) -> None: resources = AsyncExitStack() try: ws: AsyncWebSocketSession = await resources.enter_async_context( @@ -274,6 +308,10 @@ async def _send_json(self, message: DictObject) -> None: self._logger.debug(str(err), exc_info=True) raise err from None + def send_json(self, message: DictObject) -> None: + future = self.schedule_background_task(self._send_json(message)) + future.result() # Block until the message is sent + async def _receive_json(self) -> Any: # This is only called if the websocket has been created assert self.called_in_background_loop() @@ -335,8 +373,6 @@ async def _demultiplexing_task(self) -> None: finally: self._logger.info("Websocket closed, terminating demultiplexing task.") - raise_on_termination, terminated_exc = self._raise_on_termination() - async def _receive_messages(self) -> None: """Process received messages until task is cancelled.""" while True: @@ -349,6 +385,38 @@ async def _receive_messages(self) -> None: self._terminate.set() break + +class SyncToAsyncWebsocketBridge: + def __init__( + self, + ws_thread: AsyncWebsocketThread, + ws_url: str, + auth_details: DictObject, + enqueue_message: Callable[[DictObject], bool], + log_context: LogEventContext, + ) -> None: + self._ws_handler = AsyncWebsocketHandler( + ws_thread, ws_url, auth_details, enqueue_message, log_context + ) + + def connect(self) -> bool: + return self._ws_handler.connect() + + def disconnect(self) -> None: + self._ws_handler.terminate() + def send_json(self, message: DictObject) -> None: - # Block until message has been sent - self.run_background_task(self._send_json(message)) + self._ws_handler.send_json(message) + + # These attributes are currently accessed directly... + @property + def _ws(self) -> AsyncWebSocketSession | None: + return self._ws_handler._ws + + @property + def _connection_failure(self) -> Exception | None: + return self._ws_handler._connection_failure + + @property + def _auth_failure(self) -> Any | None: + return self._ws_handler._auth_failure diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index d801c80..7b382bf 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -109,7 +109,7 @@ _model_spec_to_api_dict, _redact_json, ) -from ._ws_impl import AsyncWebsocketThread +from ._ws_impl import AsyncWebsocketThread, SyncToAsyncWebsocketBridge from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config from ._sdk_models import ( EmbeddingRpcCountTokensParameter, @@ -233,17 +233,21 @@ def receive_result(self) -> Any: return self._rpc.handle_rx_message(message) -class SyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebsocketThread, queue.Queue[Any]]): +class SyncLMStudioWebsocket( + LMStudioWebsocket[SyncToAsyncWebsocketBridge, queue.Queue[Any]] +): """Synchronous websocket client that handles demultiplexing of reply messages.""" def __init__( self, + ws_thread: AsyncWebsocketThread, ws_url: str, auth_details: DictObject, log_context: LogEventContext | None = None, ) -> None: """Initialize synchronous websocket client.""" super().__init__(ws_url, auth_details, log_context) + self._ws_thread = ws_thread @property def _httpx_ws(self) -> AsyncWebSocketSession | None: @@ -266,7 +270,8 @@ def __exit__(self, *args: Any) -> None: def connect(self) -> Self: """Connect to and authenticate with the LM Studio API.""" self._fail_if_connected("Attempted to connect already connected websocket") - ws = AsyncWebsocketThread( + ws = SyncToAsyncWebsocketBridge( + self._ws_thread, self._ws_url, self._auth_details, self._enqueue_message, @@ -409,7 +414,9 @@ def connect(self) -> SyncLMStudioWebsocket: session_url = f"ws://{api_host}/{namespace}" resources = self._resources self._lmsws = lmsws = resources.enter_context( - SyncLMStudioWebsocket(session_url, self._client._auth_details) + SyncLMStudioWebsocket( + self._client._ws_thread, session_url, self._client._auth_details + ) ) return lmsws @@ -1482,8 +1489,11 @@ def __init__(self, api_host: str | None = None) -> None: """Initialize API client.""" super().__init__(api_host) self._resources = rm = ExitStack() + self._ws_thread = ws_thread = AsyncWebsocketThread(dict(client=repr(self))) + ws_thread.start() + rm.callback(ws_thread.terminate) self._sessions: dict[str, SyncSession] = {} - # Suport GC-based resource management in the sync API by + # Support GC-based resource management in the sync API by # finalizing at the client layer, and letting its resource # manager handle clearing up everything else rm.callback(self._sessions.clear) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index d6209fd..495f4d1 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -1,6 +1,7 @@ """Test common client session behaviour.""" import logging +from typing import Generator import pytest from pytest import LogCaptureFixture as LogCap @@ -11,15 +12,16 @@ LMStudioWebsocketError, ) from lmstudio.async_api import ( + AsyncLMStudioWebsocket, AsyncSession, AsyncSessionSystem, ) from lmstudio.sync_api import ( + SyncLMStudioWebsocket, SyncSession, SyncSessionSystem, ) -from lmstudio.async_api import AsyncLMStudioWebsocket -from lmstudio.sync_api import SyncLMStudioWebsocket +from lmstudio._ws_impl import AsyncWebsocketThread from .support import LOCAL_API_HOST @@ -175,11 +177,23 @@ async def test_websocket_cm_async(caplog: LogCap) -> None: assert httpx_ws.connection.state.value in WS_CLOSING_STATES +@pytest.fixture +def ws_thread() -> Generator[AsyncWebsocketThread, None, None]: + ws_thread = AsyncWebsocketThread() + ws_thread.start() + try: + yield ws_thread + finally: + ws_thread.terminate() + + @pytest.mark.lmstudio -def test_websocket_cm_sync(caplog: LogCap) -> None: +def test_websocket_cm_sync(ws_thread: AsyncWebsocketThread, caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) auth_details = Client._create_auth_message() - lmsws = SyncLMStudioWebsocket(f"http://{LOCAL_API_HOST}/system", auth_details) + lmsws = SyncLMStudioWebsocket( + ws_thread, f"http://{LOCAL_API_HOST}/system", auth_details + ) # SDK client websockets start out disconnected assert not lmsws.connected # Entering the CM opens the websocket if it isn't already open