diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 34217f0..80dba6f 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -23,6 +23,10 @@ TypeAlias, TypeVar, ) +from typing_extensions import ( + # Native in 3.11+ + Self, +) from anyio import create_task_group, move_on_after from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException @@ -45,13 +49,19 @@ class AsyncTaskManager: - def __init__(self, *, on_activation: Callable[[], Any] | None) -> None: + def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None: self._activated = False self._event_loop: asyncio.AbstractEventLoop | None = None self._on_activation = on_activation self._task_queue: asyncio.Queue[Callable[[], Awaitable[Any]]] = asyncio.Queue() self._terminate = asyncio.Event() self._terminated = asyncio.Event() + # For the case where the task manager is run via its context manager + self._tm_started = asyncio.Event() + self._tm_task: asyncio.Task[Any] | None = None + + ACTIVATION_TIMEOUT = 5 # Just starts an async task, should be fast + TERMINATION_TIMEOUT = 20 # May have to shut down TCP links @property def activated(self) -> bool: @@ -65,6 +75,20 @@ def active(self) -> bool: and not self._terminated.is_set() ) + async def __aenter__(self) -> Self: + # Handle reentrancy the same way files do: + # allow nested use as a CM, but close on the first exit + if self._tm_task is None: + self._tm_task = asyncio.create_task(self.run_until_terminated()) + with move_on_after(self.ACTIVATION_TIMEOUT): + await self._tm_started.wait() + return self + + async def __aexit__(self, *args: Any) -> None: + await self.request_termination() + with move_on_after(self.TERMINATION_TIMEOUT): + await self._terminated.wait() + def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool: """Returns if running in this manager's event loop, raises RuntimeError otherwise.""" this_loop = self._event_loop @@ -138,6 +162,7 @@ def _init_event_loop(self) -> None: notify = self._on_activation if notify is not None: notify() + self._tm_started.set() async def run_until_terminated( self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None @@ -218,9 +243,6 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle: AsyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]] -# TODO: Improve code sharing between AsyncWebsocketHandler and -# the async-native AsyncLMStudioWebsocket implementation -# (likely by migrating the websocket over to using the handler) class AsyncWebsocketHandler: """Async task handler for a single websocket connection.""" @@ -243,7 +265,7 @@ def __init__( self._ws_disconnected = asyncio.Event() self._rx_task: asyncio.Task[None] | None = None self._logger = logger = new_logger(type(self).__name__) - logger.update_context(log_context) + logger.update_context(log_context, ws_url=ws_url) self._mux = MultiplexingManager(logger) async def connect(self) -> bool: @@ -275,7 +297,7 @@ def disconnect_threadsafe(self) -> None: task_manager.run_coroutine_threadsafe(self.disconnect()).result() async def _logged_ws_handler(self) -> None: - self._logger.info("Websocket handling task started") + self._logger.debug("Websocket handling task started") try: await self._handle_ws() except (asyncio.CancelledError, GeneratorExit): @@ -287,7 +309,7 @@ async def _logged_ws_handler(self) -> None: # Ensure connections attempt are unblocked even if the # background async task errors out completely self._connection_attempted.set() - self._logger.info("Websocket task terminated") + self._logger.debug("Websocket task terminated") async def _handle_ws(self) -> None: assert self._task_manager.check_running_in_task_loop() @@ -311,12 +333,12 @@ def _clear_task_state() -> None: if not await self._authenticate(): return self._connection_attempted.set() - self._logger.info(f"Websocket session established ({self._ws_url})") + self._logger.info("Websocket session established") # Task will run until message reception fails or is cancelled try: await self._receive_messages() finally: - self._logger.info("Websocket demultiplexing task terminated.") + self._logger.debug("Websocket demultiplexing task terminated.") # Notify foreground thread of background thread termination # (this covers termination due to link failure) await self.notify_client_termination() diff --git a/src/lmstudio/async_api.py b/src/lmstudio/async_api.py index 09c7359..25605a4 100644 --- a/src/lmstudio/async_api.py +++ b/src/lmstudio/async_api.py @@ -27,10 +27,7 @@ TypeIs, ) -from anyio import create_task_group -from anyio.abc import TaskGroup -from httpx import RequestError, HTTPStatusError -from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException +from httpx_ws import AsyncWebSocketSession from .sdk_api import LMStudioRuntimeError, sdk_public_api, sdk_public_api_async from .schemas import AnyLMStudioStruct, DictObject @@ -77,7 +74,6 @@ ModelSessionTypes, ModelTypesEmbedding, ModelTypesLlm, - MultiplexingManager, # Temporary until migration to AsyncWebsocketHandler PredictionStreamBase, PredictionEndpoint, PredictionFirstTokenCallback, @@ -88,7 +84,6 @@ PromptProcessingCallback, RemoteCallHandler, ResponseSchema, - RxQueue, TModelInfo, check_model_namespace, load_struct, @@ -105,6 +100,7 @@ LlmRpcApplyPromptTemplateParameter, ModelCompatibilityType, ) +from ._ws_impl import AsyncTaskManager, AsyncWebsocketHandler from ._logging import new_logger, LogEventContext @@ -135,14 +131,14 @@ class AsyncChannel(Generic[T]): def __init__( self, channel_id: int, - rx_queue: RxQueue, + get_message: Callable[[], Awaitable[Any]], endpoint: ChannelEndpoint[T, Any, Any], send_json: Callable[[DictObject], Awaitable[None]], log_context: LogEventContext, ) -> None: """Initialize asynchronous websocket streaming channel.""" self._is_finished = False - self._rx_queue = rx_queue + self._get_message = get_message self._api_channel = ChannelHandler(channel_id, endpoint, log_context) self._send_json = send_json @@ -170,9 +166,9 @@ async def rx_stream( with sdk_public_api(): # Avoid emitting tracebacks that delve into supporting libraries # (we can't easily suppress the SDK's own frames for iterators) - message = await self._rx_queue.get() + message = await self._get_message() if message is None: - raise LMStudioRuntimeError("Client unexpectedly disconnected.") + raise LMStudioWebsocketError("Client unexpectedly disconnected.") contents = self._api_channel.handle_rx_message(message) if contents is None: self._is_finished = True @@ -195,12 +191,12 @@ class AsyncRemoteCall: def __init__( self, call_id: int, - rx_queue: RxQueue, + get_message: Callable[[], Awaitable[Any]], log_context: LogEventContext, notice_prefix: str = "RPC", ) -> None: """Initialize asynchronous remote procedure call.""" - self._rx_queue = rx_queue + self._get_message = get_message self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix) self._logger = logger = new_logger(type(self).__name__) logger.update_context(log_context, call_id=call_id) @@ -213,9 +209,9 @@ def get_rpc_message( async def receive_result(self) -> Any: """Receive call response on the receive queue.""" - message = await self._rx_queue.get() + message = await self._get_message() if message is None: - raise LMStudioRuntimeError("Client unexpectedly disconnected.") + raise LMStudioWebsocketError("Client unexpectedly disconnected.") return self._rpc.handle_rx_message(message) @@ -224,17 +220,16 @@ class AsyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebSocketSession]): def __init__( self, + task_manager: AsyncTaskManager, ws_url: str, auth_details: DictObject, log_context: LogEventContext | None = None, ) -> None: """Initialize asynchronous websocket client.""" super().__init__(ws_url, auth_details, log_context) - self._resource_manager = rm = AsyncExitStack() - rm.push_async_callback(self._notify_client_termination) - self._rx_task: asyncio.Task[None] | None = None - self._terminate = asyncio.Event() - self._mux = MultiplexingManager(self._logger) + self._ws_handler = AsyncWebsocketHandler( + task_manager, ws_url, auth_details, log_context + ) @property def _httpx_ws(self) -> AsyncWebSocketSession | None: @@ -251,127 +246,32 @@ async def __aenter__(self) -> Self: async def __aexit__(self, *args: Any) -> None: await self.disconnect() - async def _send_json(self, message: DictObject) -> None: - # Callers are expected to call `_ensure_connected` before this method - ws = self._ws - if ws is None: - # Assume app is shutting down and the owning task has already been cancelled - return - try: - await ws.send_json(message) - except Exception as exc: - err = self._get_tx_error(message, exc) - # Log the underlying exception info, but simplify the raised traceback - self._logger.debug(str(err), exc_info=True) - raise err from None - - async def _receive_json(self) -> Any: - # Callers are expected to call `_ensure_connected` before this method - ws = self._ws - if ws is None: - # Assume app is shutting down and the owning task has already been cancelled - return - try: - return await ws.receive_json() - except Exception as exc: - err = self._get_rx_error(exc) - # Log the underlying exception info, but simplify the raised traceback - self._logger.debug(str(err), exc_info=True) - raise err from None - async def connect(self) -> Self: """Connect to and authenticate with the LM Studio API.""" self._fail_if_connected("Attempted to connect already connected websocket") - resources = self._resource_manager - try: - ws: AsyncWebSocketSession = await resources.enter_async_context( - aconnect_ws(self._ws_url) - ) - except (RequestError, HTTPStatusError, HTTPXWSException) as exc: - err = self._get_connection_failure_error(exc) - # Log the underlying exception info, but simplify the raised traceback - self._logger.debug(str(err), exc_info=True) - raise err from None - self._ws = ws - self._logger.debug("Websocket connected") - # Authenticate - auth_message = self._auth_details - await self._send_json(auth_message) - auth_result = await self._receive_json() - self._logger.debug("Websocket authenticated", json=auth_result) - if not auth_result["success"]: - raise self._get_auth_failure_error(auth_result["error"]) - # Start the websocket demultiplexing task - # The websocket manages a task group internally and will complain - # if those aren't managed correctly, so we don't worry about - # creating a task group of our own here - self._rx_task = rx_task = asyncio.create_task(self._receive_messages()) - - async def _terminate_rx_task() -> None: - self._terminate.set() - try: - await rx_task - except asyncio.CancelledError: - pass - - self._resource_manager.push_async_callback(_terminate_rx_task) - self._logger.info(f"Websocket session established ({self._ws_url})") + self._logger.info("Connecting websocket session") + ws_handler = self._ws_handler + if not await self._ws_handler.connect(): + if ws_handler._connection_failure is not None: + raise self._get_connection_failure_error(ws_handler._connection_failure) + if ws_handler._auth_failure is not None: + raise self._get_auth_failure_error(ws_handler._auth_failure) + self._logger.error("Connection failed, but no failure reason reported.") + raise self._get_connection_failure_error() + self._ws = ws_handler._ws return self async def disconnect(self) -> None: """Drop the LM Studio API connection.""" self._ws = None - self._rx_task = None - self._terminate.set() - await self._resource_manager.aclose() - self._logger.info(f"Websocket session disconnected ({self._ws_url})") + await self._ws_handler.disconnect() + self._logger.info("Websocket session disconnected") aclose = disconnect - async def _cancel_on_termination(self, tg: TaskGroup) -> None: - await self._terminate.wait() - tg.cancel_scope.cancel() - - async def _process_next_message(self) -> bool: - """Process the next message received on the websocket. - - Returns True if a message queue was updated. - """ - self._ensure_connected("receive messages") - async with create_task_group() as tg: - tg.start_soon(self._cancel_on_termination, tg) - try: - message = await self._receive_json() - except (LMStudioWebsocketError, HTTPXWSException): - if self._ws is not None and not self._terminate.is_set(): - # Websocket failed unexpectedly (rather than due to client shutdown) - self._logger.error("Websocket failed, terminating session.") - self._terminate.set() - tg.cancel_scope.cancel() - if self._terminate.is_set(): - return (await self._notify_client_termination()) > 0 - rx_queue = self._mux.map_rx_message(message) - if rx_queue is None: - return False - await rx_queue.put(message) - return True - - async def _receive_messages(self) -> None: - """Process received messages until connection is terminated.""" - while not self._terminate.is_set(): - await self._process_next_message() - - async def _notify_client_termination(self) -> int: - """Send None to all clients with open receive queues.""" - num_clients = 0 - for rx_queue in self._mux.all_queues(): - await rx_queue.put(None) - num_clients += 1 - self._logger.debug( - f"Notified {num_clients} clients of websocket termination", - num_clients=num_clients, - ) - return num_clients + async def _send_json(self, message: DictObject) -> None: + # Callers are expected to call `_ensure_connected` before this method + await self._ws_handler.send_json(message) async def _connect_to_endpoint(self, channel: AsyncChannel[Any]) -> None: """Connect channel to specified endpoint.""" @@ -386,19 +286,15 @@ async def open_channel( endpoint: ChannelEndpoint[T, Any, Any], ) -> AsyncGenerator[AsyncChannel[T], None]: """Open a streaming channel over the websocket.""" - rx_queue: RxQueue = asyncio.Queue() - with self._mux.assign_channel_id(rx_queue) as channel_id: + with self._ws_handler.open_channel() as (channel_id, getter): channel = AsyncChannel( channel_id, - rx_queue, + getter, endpoint, self._send_json, self._logger.event_context, ) await self._connect_to_endpoint(channel) - if self._terminate.is_set(): - # Link has been terminated, ensure client gets a response - await rx_queue.put(None) yield channel async def _send_call( @@ -427,15 +323,11 @@ async def remote_call( notice_prefix: str = "RPC", ) -> Any: """Make a remote procedure call over the websocket.""" - rx_queue: RxQueue = asyncio.Queue() - with self._mux.assign_call_id(rx_queue) as call_id: + with self._ws_handler.start_call() as (call_id, getter): rpc = AsyncRemoteCall( - call_id, rx_queue, self._logger.event_context, notice_prefix + call_id, getter, self._logger.event_context, notice_prefix ) await self._send_call(rpc, endpoint, params) - if self._terminate.is_set(): - # Link has been terminated, ensure client gets a response - await rx_queue.put(None) return await rpc.receive_result() @@ -473,8 +365,11 @@ async def connect(self) -> AsyncLMStudioWebsocket: ) session_url = f"ws://{api_host}/{namespace}" resources = self._resource_manager + client = self._client self._lmsws = lmsws = await resources.enter_async_context( - AsyncLMStudioWebsocket(session_url, self._client._auth_details) + AsyncLMStudioWebsocket( + client._task_manager, session_url, client._auth_details + ) ) return lmsws @@ -1040,10 +935,6 @@ class AsyncSessionLlm( API_NAMESPACE = "llm" _API_TYPES = ModelTypesLlm - def __init__(self, client: "AsyncClient") -> None: - """Initialize API client session for LLM interaction.""" - super().__init__(client) - def _create_handle(self, model_identifier: str) -> "AsyncLLM": """Create a symbolic handle to the specified LLM model.""" return AsyncLLM(model_identifier, self) @@ -1150,10 +1041,6 @@ class AsyncSessionEmbedding( API_NAMESPACE = "embedding" _API_TYPES = ModelTypesEmbedding - def __init__(self, client: "AsyncClient") -> None: - """Initialize API client session for embedding model interaction.""" - super().__init__(client) - def _create_handle(self, model_identifier: str) -> "AsyncEmbeddingModel": """Create a symbolic handle to the specified embedding model.""" return AsyncEmbeddingModel(model_identifier, self) @@ -1402,6 +1289,7 @@ def __init__(self, api_host: str | None = None) -> None: super().__init__(api_host) self._resources = AsyncExitStack() self._sessions: dict[str, AsyncSession] = {} + self._task_manager = AsyncTaskManager() # Unlike the sync API, we don't support GC-based resource # management in the async API. Structured concurrency # is required to reliably offer graceful termination in @@ -1410,7 +1298,9 @@ def __init__(self, api_host: str | None = None) -> None: # The async API can't implicitly perform network I/O in properties. # However, lazy connections also don't work due to structured concurrency. # For now, all sessions are opened eagerly by the client - # TODO: provide a way to selectively exclude unnecessary client sessions + # TODO: revisit lazy connections given the task manager implementation + # (for example, eagerly start tasks for all sessions, and lazily + # trigger events that allow them to initiate their connection) _ALL_SESSIONS: tuple[Type[AsyncSession], ...] = ( AsyncSessionEmbedding, _AsyncSessionFiles, @@ -1423,12 +1313,14 @@ async def __aenter__(self) -> Self: # Handle reentrancy the same way files do: # allow nested use as a CM, but close on the first exit if not self._sessions: - for cls in self._ALL_SESSIONS: - namespace = cls.API_NAMESPACE + rm = self._resources + await rm.enter_async_context(self._task_manager) + for session_cls in self._ALL_SESSIONS: + namespace = session_cls.API_NAMESPACE assert namespace is not None - session = cls(self) + session = session_cls(self) self._sessions[namespace] = session - await self._resources.enter_async_context(session) + await rm.enter_async_context(session) return self async def __aexit__(self, *args: Any) -> None: diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py index cfe6c40..b168eb8 100644 --- a/src/lmstudio/sync_api.py +++ b/src/lmstudio/sync_api.py @@ -81,6 +81,7 @@ LMStudioPredictionError, LMStudioTimeoutError, LMStudioWebsocket, + LMStudioWebsocketError, LoadModelEndpoint, ModelDownloadOptionBase, ModelHandleBase, @@ -230,7 +231,7 @@ def rx_stream( except TimeoutError: raise LMStudioTimeoutError from None if message is None: - raise LMStudioRuntimeError("Client unexpectedly disconnected.") + raise LMStudioWebsocketError("Client unexpectedly disconnected.") contents = self._api_channel.handle_rx_message(message) if contents is None: self._is_finished = True @@ -286,7 +287,7 @@ def receive_result(self) -> Any: except TimeoutError: raise LMStudioTimeoutError from None if message is None: - raise LMStudioRuntimeError("Client unexpectedly disconnected.") + raise LMStudioWebsocketError("Client unexpectedly disconnected.") return self._rpc.handle_rx_message(message) diff --git a/tests/test_session_errors.py b/tests/test_session_errors.py index 31c5603..ec15b3e 100644 --- a/tests/test_session_errors.py +++ b/tests/test_session_errors.py @@ -68,8 +68,9 @@ async def test_session_not_started_async(caplog: LogCap) -> None: @pytest.mark.asyncio async def test_session_disconnected_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = AsyncSessionSystem(AsyncClient()) - async with session: + client = AsyncClient() + session = AsyncSessionSystem(client) + async with client._task_manager, session: assert session.connected # Session is disconnected after use assert not session.connected @@ -80,12 +81,14 @@ async def test_session_disconnected_async(caplog: LogCap) -> None: @pytest.mark.asyncio async def test_session_closed_port_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = AsyncSessionSystem(AsyncClient(closed_api_host())) + client = AsyncClient(closed_api_host()) + session = AsyncSessionSystem(client) # Sessions start out disconnected assert not session.connected # Should get an SDK exception rather than the underlying exception - with pytest.raises(LMStudioWebsocketError, match="is not reachable"): - await session.connect() + async with client._task_manager: + with pytest.raises(LMStudioWebsocketError, match="is not reachable"): + await session.connect() # Session should still be considered disconnected assert not session.connected # Check server call errors are reported as expected @@ -97,12 +100,14 @@ async def test_session_closed_port_async(caplog: LogCap) -> None: async def test_session_nonresponsive_port_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) with nonresponsive_api_host() as api_host: - session = AsyncSessionSystem(AsyncClient(api_host)) + client = AsyncClient(api_host) + session = AsyncSessionSystem(client) # Sessions start out disconnected assert not session.connected # Should get an SDK exception rather than the underlying exception - with pytest.raises(LMStudioWebsocketError, match="is not reachable"): - await session.connect() + async with client._task_manager: + with pytest.raises(LMStudioWebsocketError, match="is not reachable"): + await session.connect() # Session should still be considered disconnected assert not session.connected # Check server call errors are reported as expected diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 6e046ea..9dce3c1 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -21,6 +21,7 @@ SyncSession, SyncSessionSystem, ) +from lmstudio._ws_impl import AsyncTaskManager from lmstudio._ws_thread import AsyncWebsocketThread from .support import LOCAL_API_HOST @@ -48,13 +49,14 @@ async def check_connected_async_session(session: AsyncSession) -> None: @pytest.mark.lmstudio async def test_session_cm_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) - session = AsyncSessionSystem(AsyncClient()) + client = AsyncClient() + session = AsyncSessionSystem(client) # Sessions start out disconnected assert not session.connected # Disconnecting should run without error await session.disconnect() # Entering a session opens the websocket if it isn't already open - async with session as entry_result: + async with client._task_manager, session as entry_result: # Sessions are their own entry result assert entry_result is session # Check connected session behaves as expected @@ -153,11 +155,12 @@ def test_implicit_reconnection_sync(caplog: LogCap) -> None: async def test_websocket_cm_async(caplog: LogCap) -> None: caplog.set_level(logging.DEBUG) auth_details = AsyncClient._format_auth_message() - lmsws = AsyncLMStudioWebsocket(f"http://{LOCAL_API_HOST}/system", auth_details) + tm = AsyncTaskManager(on_activation=None) + lmsws = AsyncLMStudioWebsocket(tm, f"http://{LOCAL_API_HOST}/system", auth_details) # SDK client websockets start out disconnected assert not lmsws.connected # Entering the CM opens the websocket if it isn't already open - async with lmsws as entry_result: + async with tm, lmsws as entry_result: assert lmsws.connected httpx_ws = lmsws._httpx_ws assert httpx_ws is not None