diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py
index 1d68cf8..a9d917c 100644
--- a/src/lmstudio/_ws_impl.py
+++ b/src/lmstudio/_ws_impl.py
@@ -464,9 +464,15 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
         """Run given coroutine in the event loop and wait for the result."""
         return self._task_manager.run_coroutine_threadsafe(coro).result()
 
-    def rx_queue_get_threadsafe(self, rx_queue: asyncio.Queue[Any]) -> Any:
+    def rx_queue_get_threadsafe(
+        self, rx_queue: asyncio.Queue[Any], timeout: float | None
+    ) -> Any:
         future = self._task_manager.run_coroutine_threadsafe(rx_queue.get())
-        return future.result()
+        try:
+            return future.result(timeout)
+        except TimeoutError:
+            future.cancel()
+            raise
 
     def rx_queue_put_threadsafe(
         self, rx_queue: asyncio.Queue[Any], message: Any
@@ -562,7 +568,7 @@ def disconnect(self) -> None:
 
     def send_json(self, message: DictObject) -> None:
         self._ws_handler.send_json_threadsafe(message)
 
-    def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[], Any]]:
+    def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[float | None], Any]]:
         rx_queue: asyncio.Queue[Any] = asyncio.Queue()
         return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue)
diff --git a/src/lmstudio/json_api.py b/src/lmstudio/json_api.py
index 00b1365..5eb90a2 100644
--- a/src/lmstudio/json_api.py
+++ b/src/lmstudio/json_api.py
@@ -153,6 +153,7 @@
     "LMStudioPredictionError",
     "LMStudioPresetNotFoundError",
     "LMStudioServerError",
+    "LMStudioTimeoutError",
     "LMStudioUnknownMessageWarning",
     "LMStudioWebsocketError",
     "ModelInfo",
@@ -424,9 +425,14 @@ class LMStudioCancelledError(LMStudioClientError):
     """Requested operation was cancelled via the SDK client session."""
 
 
+@sdk_public_type
+class LMStudioTimeoutError(LMStudioError, TimeoutError):
+    """Client failed to receive a message from the server in the expected time."""
+
+
 @sdk_public_type
 class LMStudioWebsocketError(LMStudioClientError):
-    """Client websocket sessiqqon has terminated (or was never opened)."""
+    """Client websocket session has terminated (or was never opened)."""
 
 
 # dataclass vs LMStudioStruct:
@@ -689,6 +695,7 @@ def is_finished(self) -> bool:
     def _set_result(self, result: T) -> ChannelFinishedEvent:
         # Note: errors are raised immediately when handling the relevant message
         # rather than only being reported when the result is accessed
+        self._logger.debug("Channel result received, closing channel")
         self._is_finished = True
         self._result = result
         return ChannelFinishedEvent(None)
diff --git a/src/lmstudio/plugin/hooks/prompt_preprocessor.py b/src/lmstudio/plugin/hooks/prompt_preprocessor.py
index 1432194..73a4910 100644
--- a/src/lmstudio/plugin/hooks/prompt_preprocessor.py
+++ b/src/lmstudio/plugin/hooks/prompt_preprocessor.py
@@ -215,6 +215,8 @@ async def process_requests(
         """Create plugin channel and wait for server requests."""
         logger = self._logger
         endpoint = PromptPreprocessingEndpoint()
+        # Async API expects timeouts to be handled via task groups,
+        # so there's no default timeout to override when creating the channel
        async with session._create_channel(endpoint) as channel:
             notify_ready()
             logger.info("Opened channel to receive prompt preprocessing requests...")
diff --git a/src/lmstudio/sync_api.py b/src/lmstudio/sync_api.py
index b7d8f28..ee11ef7 100644
--- a/src/lmstudio/sync_api.py
+++ b/src/lmstudio/sync_api.py
@@ -11,7 +11,7 @@
     contextmanager,
     ExitStack,
 )
-from types import TracebackType
+from types import EllipsisType, TracebackType
 from typing import (
     Any,
     ContextManager,
@@ -80,6 +80,7 @@
     LMStudioCancelledError,
     LMStudioClientError,
     LMStudioPredictionError,
+    LMStudioTimeoutError,
     LMStudioWebsocket,
     LoadModelEndpoint,
     ModelDownloadOptionBase,
@@ -141,15 +142,37 @@
     "PredictionStream",
     "configure_default_client",
     "get_default_client",
+    "get_sync_api_timeout",
     "embedding_model",
     "list_downloaded_models",
     "list_loaded_models",
     "llm",
    "prepare_image",
+    "set_sync_api_timeout",
 ]
 
+# Default timeout (in seconds) when waiting for server messages in the sync API
+_DEFAULT_TIMEOUT: float | None = 60.0
+
+
+@sdk_public_api()
+def get_sync_api_timeout() -> float | None:
+    """Return the current default sync API timeout when waiting for server messages."""
+    return _DEFAULT_TIMEOUT
+
+
+@sdk_public_api()
+def set_sync_api_timeout(timeout: float | None) -> None:
+    """Set the default sync API timeout when waiting for server messages."""
+    global _DEFAULT_TIMEOUT
+    if timeout is not None:
+        timeout = float(timeout)
+    _DEFAULT_TIMEOUT = timeout
+
 
 T = TypeVar("T")
+CallWithTimeout: TypeAlias = Callable[[float | None], Any]
+TimeoutOption: TypeAlias = float | None | EllipsisType
 
 
 class SyncChannel(Generic[T]):
@@ -158,16 +181,18 @@ class SyncChannel(Generic[T]):
     def __init__(
         self,
         channel_id: int,
-        get_message: Callable[[], Any],
+        get_message: CallWithTimeout,
         endpoint: ChannelEndpoint[T, Any, Any],
         send_json: Callable[[DictObject], None],
         log_context: LogEventContext,
+        timeout: TimeoutOption = ...,
     ) -> None:
         """Initialize synchronous websocket streaming channel."""
         self._is_finished = False
         self._get_message = get_message
-        self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
         self._send_json = send_json
+        self._timeout = timeout
+        self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
 
     def get_creation_message(self) -> DictObject:
         """Get the message to send to create this channel."""
@@ -185,6 +210,14 @@ def cancel(self) -> None:
             cancel_message = self._api_channel.get_cancel_message()
             self._send_json(cancel_message)
 
+    @property
+    def timeout(self) -> float | None:
+        """Permitted time between received messages for this channel."""
+        timeout = self._timeout
+        if timeout is ...:
+            return _DEFAULT_TIMEOUT
+        return timeout
+
     def rx_stream(
         self,
     ) -> Iterator[DictObject | None]:
@@ -193,7 +226,10 @@ def rx_stream(
             with sdk_public_api():
                 # Avoid emitting tracebacks that delve into supporting libraries
                 # (we can't easily suppress the SDK's own frames for iterators)
-                message = self._get_message()
+                try:
+                    message = self._get_message(self.timeout)
+                except TimeoutError:
+                    raise LMStudioTimeoutError from None
             contents = self._api_channel.handle_rx_message(message)
             if contents is None:
                 self._is_finished = True
@@ -216,12 +252,14 @@ class SyncRemoteCall:
     def __init__(
         self,
         call_id: int,
-        get_message: Callable[[], Any],
+        get_message: CallWithTimeout,
         log_context: LogEventContext,
         notice_prefix: str = "RPC",
+        timeout: TimeoutOption = ...,
     ) -> None:
         """Initialize synchronous remote procedure call."""
         self._get_message = get_message
+        self._timeout = timeout
         self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix)
         self._logger = logger = new_logger(type(self).__name__)
         logger.update_context(log_context, call_id=call_id)
@@ -232,9 +270,20 @@ def get_rpc_message(
         """Get the message to send to initiate this remote procedure call."""
         return self._rpc.get_rpc_message(endpoint, params)
 
+    @property
+    def timeout(self) -> float | None:
+        """Permitted time to wait for a reply to this call."""
+        timeout = self._timeout
+        if timeout is ...:
+            return _DEFAULT_TIMEOUT
+        return timeout
+
     def receive_result(self) -> Any:
         """Receive call response on the receive queue."""
-        message = self._get_message()
+        try:
+            message = self._get_message(self.timeout)
+        except TimeoutError:
+            raise LMStudioTimeoutError from None
         return self._rpc.handle_rx_message(message)
 
 
diff --git a/tests/README.md b/tests/README.md
index e319d97..296de15 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -61,8 +61,9 @@ $ tox -m test -- -k test_download_model
 
 Test files should follow the following naming conventions:
 
-- `test_XYZ.py`: mix of async and sync test cases for `XYZ` that aren't amenable to automated conversion
-  (for whatever reason; for example, `anyio.fail_after` has no sync counterpart)
+- `test_XYZ.py`: either a mix of async and sync test cases for `XYZ` that aren't amenable to
+  automated conversion (for whatever reason; for example, `anyio.fail_after` has no sync counterpart),
+  or else test cases for a behaviour which currently only exists in one API or the other
 - `async/test_XYZ_async.py` : async test cases for `XYZ` that are amenable to automated sync conversion;
   all test method names should also end in `_async`.
 - `sync/test_XYZ_sync.py` : sync test cases auto-generated from `test_XYZ_async.py`
diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py
new file mode 100644
index 0000000..d27a370
--- /dev/null
+++ b/tests/test_timeouts.py
@@ -0,0 +1,73 @@
+"""Test client timeout behaviour."""
+
+import logging
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from pytest import LogCaptureFixture as LogCap
+
+from lmstudio import (
+    Client,
+    LMStudioTimeoutError,
+    get_sync_api_timeout,
+    set_sync_api_timeout,
+)
+from lmstudio.sync_api import _DEFAULT_TIMEOUT
+
+from .support import EXPECTED_LLM_ID
+
+# Sync only, as async API uses standard async timeout constructs like anyio.move_on_after
+
+
+@contextmanager
+def sync_api_timeout(timeout: float | None) -> Generator[float | None, None, None]:
+    previous_timeout = get_sync_api_timeout()
+    set_sync_api_timeout(timeout)
+    try:
+        yield previous_timeout
+    finally:
+        set_sync_api_timeout(previous_timeout)
+
+
+def test_default_timeout() -> None:
+    # Ensure default timeout is defined, but is not excessively short or long
+    # (the bounds that are considered reasonable may change over time)
+    assert _DEFAULT_TIMEOUT is not None
+    assert _DEFAULT_TIMEOUT >= 60
+    assert _DEFAULT_TIMEOUT <= 600
+
+
+@pytest.mark.parametrize("timeout", (None, 0, 1.5, 3600, 3600 * 24 * 7))
+def test_timeout_updates_sync(timeout: float | None) -> None:
+    with sync_api_timeout(timeout) as previous_timeout:
+        assert previous_timeout == _DEFAULT_TIMEOUT
+        assert get_sync_api_timeout() == timeout
+    assert get_sync_api_timeout() == previous_timeout
+
+
+@pytest.mark.lmstudio
+def test_timeout_rpc_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+
+    with Client() as client:
+        model = client.llm.model(EXPECTED_LLM_ID)
+        with sync_api_timeout(0):
+            assert get_sync_api_timeout() == 0
+            with pytest.raises(LMStudioTimeoutError):
+                response = model.get_info()
+                logging.error(f"Unexpected response: {response}")
+
+
+@pytest.mark.lmstudio
+def test_timeout_channel_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+
+    with Client() as client:
+        model = client.llm.model(EXPECTED_LLM_ID)
+        with sync_api_timeout(0):
+            assert get_sync_api_timeout() == 0
+            with pytest.raises(LMStudioTimeoutError):
+                response = model.respond("This will time out")
+                logging.error(f"Unexpected response: {response}")