diff --git a/misc/open_client.py b/misc/open_client.py new file mode 100755 index 0000000..830bd57 --- /dev/null +++ b/misc/open_client.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +"""Open a client instance for link failure testing.""" +import asyncio +import logging +import sys +import time + +from lmstudio import AsyncClient, Client + +LINK_POLLING_INTERVAL = 1 + +async def open_client_async(): + """Start async client, wait for link failure.""" + print("Connecting async client...") + async with AsyncClient() as client: + await client.list_downloaded_models() + print ("Async client connected. Close LM Studio to terminate.") + while True: + await asyncio.sleep(LINK_POLLING_INTERVAL) + await client.list_downloaded_models() + +def open_client_sync(): + """Start sync client, wait for link failure.""" + print("Connecting sync client...") + with Client() as client: + client.list_downloaded_models() + print ("Sync client connected. Close LM Studio to terminate.") + while True: + time.sleep(LINK_POLLING_INTERVAL) + client.list_downloaded_models() + +if __name__ == "__main__": + # Link polling makes debug logging excessively spammy + log_level = logging.DEBUG if "--debug" in sys.argv else logging.INFO + logging.basicConfig(level=log_level) + if "--async" in sys.argv: + asyncio.run(open_client_async()) + else: + open_client_sync() diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 9458b0e..8d33ec1 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -290,6 +290,8 @@ async def _log_thread_execution(self) -> None: try: # Run the event loop until termination is requested await never_set.wait() + except asyncio.CancelledError: + raise except BaseException: err_msg = "Terminating websocket thread due to exception" self._logger.debug(err_msg, exc_info=True) @@ -309,7 +311,7 @@ def __init__( task_manager: AsyncTaskManager, ws_url: str, auth_details: DictObject, - enqueue_message: Callable[[DictObject], bool], + enqueue_message: Callable[[DictObject | None], Awaitable[bool]], log_context: LogEventContext | None = None, ) -> None: self._auth_details = auth_details @@ -357,6 +359,8 @@ async def _logged_ws_handler(self) -> None: self._logger.info("Websocket handling task started") try: await self._handle_ws() + except asyncio.CancelledError: + raise except BaseException: err_msg = "Terminating websocket task due to exception" self._logger.debug(err_msg, exc_info=True) @@ -364,7 +368,7 @@ async def _logged_ws_handler(self) -> None: # Ensure the foreground thread is unblocked even if the # background async task errors out completely self._connection_attempted.set() - self._logger.info("Websocket task terminated") + self._logger.info("Websocket task terminated") async def _handle_ws(self) -> None: assert self._task_manager.check_running_in_task_loop() @@ -396,12 +400,19 @@ def _clear_task_state() -> None: await self._receive_messages() finally: self._logger.info("Websocket demultiplexing task terminated.") + # Notify foreground thread of background thread termination + # (this covers termination due to link failure) + await self._enqueue_message(None) dc_timeout = self.WS_DISCONNECT_TIMEOUT with move_on_after(dc_timeout, shield=True) as cancel_scope: # Workaround an anyio/httpx-ws issue with task cancellation: # https://github.com/frankie567/httpx-ws/issues/107 self._ws = None - await ws.close() + try: + await ws.close() + except Exception: + # Closing may fail if the link is already down + pass if cancel_scope.cancelled_caught: self._logger.warn( f"Failed to close websocket in {dc_timeout} seconds." @@ -413,7 +424,9 @@ async def send_json(self, message: DictObject) -> None: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() ws = self._ws - assert ws is not None + if ws is None: + # Assume app is shutting down and the owning task has already been cancelled + return try: await ws.send_json(message) except Exception as exc: @@ -430,7 +443,9 @@ async def _receive_json(self) -> Any: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() ws = self._ws - assert ws is not None + if ws is None: + # Assume app is shutting down and the owning task has already been cancelled + return try: return await ws.receive_json() except Exception as exc: @@ -443,7 +458,9 @@ async def _authenticate(self) -> bool: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() ws = self._ws - assert ws is not None + if ws is None: + # Assume app is shutting down and the owning task has already been cancelled + return False auth_message = self._auth_details await self.send_json(auth_message) auth_result = await self._receive_json() @@ -461,11 +478,11 @@ async def _process_next_message(self) -> bool: # This is only called if the websocket has been created assert self._task_manager.check_running_in_task_loop() ws = self._ws - assert ws is not None + if ws is None: + # Assume app is shutting down and the owning task has already been cancelled + return False message = await ws.receive_json() - # Enqueueing messages may be a blocking call - # TODO: Require it to return an Awaitable, move to_thread call to the sync bridge - return await asyncio.to_thread(self._enqueue_message, message) + return await self._enqueue_message(message) async def _receive_messages(self) -> None: """Process received messages until task is cancelled.""" @@ -475,7 +492,7 @@ async def _receive_messages(self) -> None: except (LMStudioWebsocketError, HTTPXWSException): if self._ws is not None and not self._ws_disconnected.is_set(): # Websocket failed unexpectedly (rather than due to client shutdown) - self._logger.exception("Websocket failed, terminating session.") + self._logger.error("Websocket failed, terminating session.") break @@ -485,11 +502,14 @@ def __init__( ws_thread: AsyncWebsocketThread, ws_url: str, auth_details: DictObject, - enqueue_message: Callable[[DictObject], bool], + enqueue_message: Callable[[DictObject | None], bool], log_context: LogEventContext, ) -> None: + async def enqueue_async(message: DictObject | None) -> bool: + return await asyncio.to_thread(enqueue_message, message) + self._ws_handler = AsyncWebsocketHandler( - ws_thread.task_manager, ws_url, auth_details, enqueue_message, log_context + ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context ) def connect(self) -> bool: diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index b04e926..ad1ce55 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -1,7 +1,6 @@ """Async I/O protocol implementation for the LM Studio remote access API.""" import asyncio -import asyncio.queues import warnings from abc import abstractmethod @@ -28,6 +27,8 @@ TypeIs, ) +from anyio import create_task_group +from anyio.abc import TaskGroup from httpx import RequestError, HTTPStatusError from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException @@ -163,7 +164,10 @@ async def rx_stream( # Avoid emitting tracebacks that delve into supporting libraries # (we can't easily suppress the SDK's own frames for iterators) message = await self._rx_queue.get() - contents = self._api_channel.handle_rx_message(message) + if message is None: + contents = None + else: + contents = self._api_channel.handle_rx_message(message) if contents is None: self._is_finished = True break @@ -204,6 +208,8 @@ def get_rpc_message( async def receive_result(self) -> Any: """Receive call response on the receive queue.""" message = await self._rx_queue.get() + if message is None: + return None return self._rpc.handle_rx_message(message) @@ -220,8 +226,10 @@ def __init__( ) -> None: """Initialize asynchronous websocket client.""" super().__init__(ws_url, auth_details, log_context) - self._resource_manager = AsyncExitStack() + self._resource_manager = rm = AsyncExitStack() + rm.push_async_callback(self._notify_client_termination) self._rx_task: asyncio.Task[None] | None = None + self._terminate = asyncio.Event() @property def _httpx_ws(self) -> AsyncWebSocketSession | None: @@ -241,7 +249,9 @@ async def __aexit__(self, *args: Any) -> None: async def _send_json(self, message: DictObject) -> None: # Callers are expected to call `_ensure_connected` before this method ws = self._ws - assert ws is not None + if ws is None: + # Assume app is shutting down and the owning task has already been cancelled + return try: await ws.send_json(message) except Exception as exc: @@ -253,7 +263,9 @@ async def _send_json(self, message: DictObject) -> None: async def _receive_json(self) -> Any: # Callers are expected to call `_ensure_connected` before this method ws = self._ws - assert ws is not None + if ws is None: + # Assume app is shutting down and the owning task has already been cancelled + return try: return await ws.receive_json() except Exception as exc: @@ -291,7 +303,7 @@ async def connect(self) -> Self: self._rx_task = rx_task = asyncio.create_task(self._receive_messages()) async def _terminate_rx_task() -> None: - rx_task.cancel() + self._terminate.set() try: await rx_task except asyncio.CancelledError: @@ -305,19 +317,34 @@ async def disconnect(self) -> None: """Drop the LM Studio API connection.""" self._ws = None self._rx_task = None - await self._notify_client_termination() + self._terminate.set() await self._resource_manager.aclose() self._logger.info(f"Websocket session disconnected ({self._ws_url})") aclose = disconnect + async def _cancel_on_termination(self, tg: TaskGroup) -> None: + await self._terminate.wait() + tg.cancel_scope.cancel() + async def _process_next_message(self) -> bool: """Process the next message received on the websocket. Returns True if a message queue was updated. """ self._ensure_connected("receive messages") - message = await self._receive_json() + async with create_task_group() as tg: + tg.start_soon(self._cancel_on_termination, tg) + try: + message = await self._receive_json() + except (LMStudioWebsocketError, HTTPXWSException): + if self._ws is not None and not self._terminate.is_set(): + # Websocket failed unexpectedly (rather than due to client shutdown) + self._logger.error("Websocket failed, terminating session.") + self._terminate.set() + tg.cancel_scope.cancel() + if self._terminate.is_set(): + return (await self._notify_client_termination()) > 0 rx_queue = self._mux.map_rx_message(message) if rx_queue is None: return False @@ -326,18 +353,20 @@ async def _process_next_message(self) -> bool: async def _receive_messages(self) -> None: """Process received messages until connection is terminated.""" - while True: - try: - await self._process_next_message() - except (LMStudioWebsocketError, HTTPXWSException): - self._logger.exception("Websocket failed, terminating session.") - await self.disconnect() - break + while not self._terminate.is_set(): + await self._process_next_message() - async def _notify_client_termination(self) -> None: + async def _notify_client_termination(self) -> int: """Send None to all clients with open receive queues.""" + num_clients = 0 for rx_queue in self._mux.all_queues(): await rx_queue.put(None) + num_clients += 1 + self._logger.info( + f"Notified {num_clients} clients of websocket termination", + num_clients=num_clients, + ) + return num_clients async def _connect_to_endpoint(self, channel: AsyncChannel[Any]) -> None: """Connect channel to specified endpoint.""" @@ -362,6 +391,9 @@ async def open_channel( self._logger.event_context, ) await self._connect_to_endpoint(channel) + if self._terminate.is_set(): + # Link has been terminated, ensure client gets a response + await rx_queue.put(None) yield channel async def _send_call( @@ -396,6 +428,9 @@ async def remote_call( call_id, rx_queue, self._logger.event_context, notice_prefix ) await self._send_call(rpc, endpoint, params) + if self._terminate.is_set(): + # Link has been terminated, ensure client gets a response + await rx_queue.put(None) return await rpc.receive_result() diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py index 029d2c2..b95c2fd 100644 --- a/src/lmstudio/json_api.py +++ b/src/lmstudio/json_api.py @@ -352,10 +352,10 @@ def _format_server_error(details: SerializedLMSExtendedError) -> str: lines.extend(_get_data_lines(details.error_data, " ")) if details.cause is not None: lines.extend(("", " Reported cause:")) - lines.extend(f" {details.cause}") + lines.append(f" {details.cause}") if details.suggestion is not None: lines.extend(("", " Suggested potential remedy:")) - lines.extend(f" {details.suggestion}") + lines.append(f" {details.suggestion}") # Only use the multi-line format if at least one # of the extended error fields is populated if lines: diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index d2623ce..bf15038 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -291,7 +291,6 @@ def disconnect(self) -> None: """Drop the LM Studio API connection.""" ws = self._ws self._ws = None - self._rx_task = None if ws is not None: self._logger.debug(f"Disconnecting websocket session ({self._ws_url})") self._notify_client_termination() @@ -301,16 +300,27 @@ def disconnect(self) -> None: close = disconnect def _enqueue_message(self, message: Any) -> bool: + if message is None: + self._logger.info(f"Websocket session failed ({self._ws_url})") + self._ws = None + return self._notify_client_termination() > 0 rx_queue = self._mux.map_rx_message(message) if rx_queue is None: return False rx_queue.put(message) return True - def _notify_client_termination(self) -> None: + def _notify_client_termination(self) -> int: """Send None to all clients with open receive queues.""" + num_clients = 0 for rx_queue in self._mux.all_queues(): rx_queue.put(None) + num_clients += 1 + self._logger.info( + f"Notified {num_clients} clients of websocket termination", + num_clients=num_clients, + ) + return num_clients def _send_json(self, message: DictObject) -> None: # Callers are expected to call `_ensure_connected` before this method @@ -342,6 +352,9 @@ def open_channel( self._logger.event_context, ) self._connect_to_endpoint(channel) + if self._ws is None: + # Link has been terminated, ensure client gets a response + rx_queue.put(None) yield channel def _send_call( @@ -376,6 +389,9 @@ def remote_call( call_id, rx_queue, self._logger.event_context, notice_prefix ) self._send_call(rpc, endpoint, params) + if self._ws is None: + # Link has been terminated, ensure client gets a response + rx_queue.put(None) return rpc.receive_result()