12 changes: 9 additions & 3 deletions src/lmstudio/_ws_impl.py
@@ -464,9 +464,15 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run given coroutine in the event loop and wait for the result."""
return self._task_manager.run_coroutine_threadsafe(coro).result()

def rx_queue_get_threadsafe(self, rx_queue: asyncio.Queue[Any]) -> Any:
def rx_queue_get_threadsafe(
self, rx_queue: asyncio.Queue[Any], timeout: float | None
) -> Any:
future = self._task_manager.run_coroutine_threadsafe(rx_queue.get())
return future.result()
try:
return future.result(timeout)
except TimeoutError:
future.cancel()
raise

def rx_queue_put_threadsafe(
self, rx_queue: asyncio.Queue[Any], message: Any
@@ -562,7 +568,7 @@ def disconnect(self) -> None:
def send_json(self, message: DictObject) -> None:
self._ws_handler.send_json_threadsafe(message)

def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[], Any]]:
def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[float | None], Any]]:
rx_queue: asyncio.Queue[Any] = asyncio.Queue()
return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue)

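The pattern above — scheduling the queue read on the event loop thread, waiting on the returned `concurrent.futures.Future` with a timeout, and cancelling the pending read when the wait expires — can be shown in isolation. A minimal sketch, not SDK code, assuming Python 3.11+ (where `concurrent.futures` raises the builtin `TimeoutError`):

```python
import asyncio
import threading

# Run an event loop in a background thread, as the SDK's websocket handler does.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

rx_queue: asyncio.Queue[str] = asyncio.Queue()

def rx_queue_get_threadsafe(timeout: float | None) -> str:
    # Schedule the queue read on the loop thread and wait from this thread.
    future = asyncio.run_coroutine_threadsafe(rx_queue.get(), loop)
    try:
        return future.result(timeout)
    except TimeoutError:
        # Don't leave the queue read pending on the event loop.
        future.cancel()
        raise

try:
    rx_queue_get_threadsafe(0.1)  # nothing queued, so this times out
except TimeoutError:
    print("timed out as expected")
```

Cancelling the future matters: without it, the abandoned `rx_queue.get()` would stay scheduled on the loop and could consume the next message that arrives.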
9 changes: 8 additions & 1 deletion src/lmstudio/json_api.py
@@ -153,6 +153,7 @@
"LMStudioPredictionError",
"LMStudioPresetNotFoundError",
"LMStudioServerError",
"LMStudioTimeoutError",
"LMStudioUnknownMessageWarning",
"LMStudioWebsocketError",
"ModelInfo",
@@ -424,9 +425,14 @@ class LMStudioCancelledError(LMStudioClientError):
"""Requested operation was cancelled via the SDK client session."""


@sdk_public_type
class LMStudioTimeoutError(LMStudioError, TimeoutError):
"""Client failed to receive a message from the server in the expected time."""


@sdk_public_type
class LMStudioWebsocketError(LMStudioClientError):
"""Client websocket sessiqqon has terminated (or was never opened)."""
"""Client websocket session has terminated (or was never opened)."""


# dataclass vs LMStudioStruct:
@@ -689,6 +695,7 @@ def is_finished(self) -> bool:
def _set_result(self, result: T) -> ChannelFinishedEvent:
# Note: errors are raised immediately when handling the relevant message
# rather than only being reported when the result is accessed
self._logger.debug("Channel result received, closing channel")
self._is_finished = True
self._result = result
return ChannelFinishedEvent(None)
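Because `LMStudioTimeoutError` derives from both `LMStudioError` and the builtin `TimeoutError`, existing handlers for either base class keep working. A hedged usage sketch (the model identifier is illustrative only):

```python
import lmstudio as lms

try:
    model = lms.llm("llama-3.2-1b-instruct")  # illustrative model identifier
    print(model.respond("Summarise the release notes"))
except lms.LMStudioTimeoutError:
    # Also catchable as TimeoutError or lms.LMStudioError
    print("No reply from the server within the configured sync API timeout")
```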
2 changes: 2 additions & 0 deletions src/lmstudio/plugin/hooks/prompt_preprocessor.py
@@ -215,6 +215,8 @@ async def process_requests(
"""Create plugin channel and wait for server requests."""
logger = self._logger
endpoint = PromptPreprocessingEndpoint()
# Async API expects timeouts to be handled via task groups,
# so there's no default timeout to override when creating the channel
async with session._create_channel(endpoint) as channel:
notify_ready()
logger.info("Opened channel to receive prompt preprocessing requests...")
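For contrast with the sync changes, the convention the comment refers to: async callers impose deadlines with standard anyio cancellation scopes rather than a per-channel timeout option. A rough sketch under that assumption (model identifier illustrative):

```python
import anyio
from lmstudio import AsyncClient

async def respond_with_deadline(prompt: str) -> str:
    async with AsyncClient() as client:
        model = await client.llm.model("llama-3.2-1b-instruct")  # illustrative
        # Cancel the pending request if the server takes longer than 30 seconds
        with anyio.fail_after(30):
            result = await model.respond(prompt)
        return str(result)

anyio.run(respond_with_deadline, "Hello")
```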
61 changes: 55 additions & 6 deletions src/lmstudio/sync_api.py
@@ -11,7 +11,7 @@
contextmanager,
ExitStack,
)
from types import TracebackType
from types import EllipsisType, TracebackType
from typing import (
Any,
ContextManager,
@@ -80,6 +80,7 @@
LMStudioCancelledError,
LMStudioClientError,
LMStudioPredictionError,
LMStudioTimeoutError,
LMStudioWebsocket,
LoadModelEndpoint,
ModelDownloadOptionBase,
@@ -141,15 +142,37 @@
"PredictionStream",
"configure_default_client",
"get_default_client",
"get_sync_api_timeout",
"embedding_model",
"list_downloaded_models",
"list_loaded_models",
"llm",
"prepare_image",
"set_sync_api_timeout",
]

# Default timeout (in seconds) when waiting for server messages via the sync API
_DEFAULT_TIMEOUT: float | None = 60.0


@sdk_public_api()
def get_sync_api_timeout() -> float | None:
"""Return the current default sync API timeout when waiting for server messages."""
return _DEFAULT_TIMEOUT


@sdk_public_api()
def set_sync_api_timeout(timeout: float | None) -> None:
"""Set the default sync API timeout when waiting for server messages."""
global _DEFAULT_TIMEOUT
if timeout is not None:
timeout = float(timeout)
_DEFAULT_TIMEOUT = timeout


T = TypeVar("T")
CallWithTimeout: TypeAlias = Callable[[float | None], Any]
TimeoutOption: TypeAlias = float | None | EllipsisType


class SyncChannel(Generic[T]):
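A note on the `TimeoutOption` alias above: the Ellipsis literal `...` serves as the sentinel for "defer to the module-level default", which leaves `None` free to mean "no timeout at all". A standalone sketch of that resolution logic (names mirror the diff but are not the SDK's public API):

```python
from types import EllipsisType

_DEFAULT_TIMEOUT: float | None = 60.0

def resolve_timeout(timeout: float | None | EllipsisType = ...) -> float | None:
    # ... defers to the current default; None explicitly disables the timeout
    if timeout is ...:
        return _DEFAULT_TIMEOUT
    return timeout

assert resolve_timeout() == 60.0      # defaulted
assert resolve_timeout(None) is None  # wait indefinitely
assert resolve_timeout(5.0) == 5.0    # explicit deadline
```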
@@ -158,16 +181,18 @@ class SyncChannel(Generic[T]):
def __init__(
self,
channel_id: int,
get_message: Callable[[], Any],
get_message: CallWithTimeout,
endpoint: ChannelEndpoint[T, Any, Any],
send_json: Callable[[DictObject], None],
log_context: LogEventContext,
timeout: TimeoutOption = ...,
) -> None:
"""Initialize synchronous websocket streaming channel."""
self._is_finished = False
self._get_message = get_message
self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
self._send_json = send_json
self._timeout = timeout
self._api_channel = ChannelHandler(channel_id, endpoint, log_context)

def get_creation_message(self) -> DictObject:
"""Get the message to send to create this channel."""
@@ -185,6 +210,14 @@ def cancel(self) -> None:
cancel_message = self._api_channel.get_cancel_message()
self._send_json(cancel_message)

@property
def timeout(self) -> float | None:
"""Permitted time between received messages for this channel."""
timeout = self._timeout
if timeout is ...:
return _DEFAULT_TIMEOUT
return timeout

def rx_stream(
self,
) -> Iterator[DictObject | None]:
@@ -193,7 +226,10 @@ def rx_stream(
with sdk_public_api():
# Avoid emitting tracebacks that delve into supporting libraries
# (we can't easily suppress the SDK's own frames for iterators)
message = self._get_message()
try:
message = self._get_message(self.timeout)
except TimeoutError:
raise LMStudioTimeoutError from None
contents = self._api_channel.handle_rx_message(message)
if contents is None:
self._is_finished = True
@@ -216,12 +252,14 @@ class SyncRemoteCall:
def __init__(
self,
call_id: int,
get_message: Callable[[], Any],
get_message: CallWithTimeout,
log_context: LogEventContext,
notice_prefix: str = "RPC",
timeout: TimeoutOption = ...,
) -> None:
"""Initialize synchronous remote procedure call."""
self._get_message = get_message
self._timeout = timeout
self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix)
self._logger = logger = new_logger(type(self).__name__)
logger.update_context(log_context, call_id=call_id)
@@ -232,9 +270,20 @@ def get_rpc_message(
"""Get the message to send to initiate this remote procedure call."""
return self._rpc.get_rpc_message(endpoint, params)

@property
def timeout(self) -> float | None:
"""Permitted time to wait for a reply to this call."""
timeout = self._timeout
if timeout is ...:
return _DEFAULT_TIMEOUT
return timeout

def receive_result(self) -> Any:
"""Receive call response on the receive queue."""
message = self._get_message()
try:
message = self._get_message(self.timeout)
except TimeoutError:
raise LMStudioTimeoutError from None
return self._rpc.handle_rx_message(message)


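For callers, the module-level knobs are the main new surface: `get_sync_api_timeout()` reports the current default and `set_sync_api_timeout()` replaces it, with `None` disabling the timeout. A usage sketch that temporarily tightens the default and restores it afterwards, mirroring the helper in the new test module (model identifier illustrative):

```python
import lmstudio as lms
from lmstudio import get_sync_api_timeout, set_sync_api_timeout

previous = get_sync_api_timeout()
set_sync_api_timeout(10.0)  # seconds; pass None to wait indefinitely
try:
    model = lms.llm("llama-3.2-1b-instruct")  # illustrative model identifier
    print(model.respond("Ping"))
except lms.LMStudioTimeoutError:
    print("Server took longer than 10 seconds to reply")
finally:
    set_sync_api_timeout(previous)
```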
5 changes: 3 additions & 2 deletions tests/README.md
@@ -61,8 +61,9 @@ $ tox -m test -- -k test_download_model

Test files should follow the following naming conventions:

- `test_XYZ.py`: mix of async and sync test cases for `XYZ` that aren't amenable to automated conversion
(for whatever reason; for example, `anyio.fail_after` has no sync counterpart)
- `test_XYZ.py`: either a mix of async and sync test cases for `XYZ` that aren't amenable to
automated conversion (for whatever reason; for example, `anyio.fail_after` has no sync counterpart),
or else test cases for a behaviour which currently only exists in one API or the other
- `async/test_XYZ_async.py` : async test cases for `XYZ` that are amenable to automated sync conversion;
all test method names should also end in `_async`.
- `sync/test_XYZ_sync.py` : sync test cases auto-generated from `test_XYZ_async.py`
73 changes: 73 additions & 0 deletions tests/test_timeouts.py
@@ -0,0 +1,73 @@
"""Test client timeout behaviour."""

import logging

from contextlib import contextmanager
from typing import Generator

import pytest
from pytest import LogCaptureFixture as LogCap

from lmstudio import (
Client,
LMStudioTimeoutError,
get_sync_api_timeout,
set_sync_api_timeout,
)
from lmstudio.sync_api import _DEFAULT_TIMEOUT

from .support import EXPECTED_LLM_ID

# Sync only, as async API uses standard async timeout constructs like anyio.move_on_after


@contextmanager
def sync_api_timeout(timeout: float | None) -> Generator[float | None, None, None]:
previous_timeout = get_sync_api_timeout()
set_sync_api_timeout(timeout)
try:
yield previous_timeout
finally:
set_sync_api_timeout(previous_timeout)


def test_default_timeout() -> None:
# Ensure default timeout is defined, but is not excessively short or long
# (the bounds that are considered reasonable may change over time)
assert _DEFAULT_TIMEOUT is not None
assert _DEFAULT_TIMEOUT >= 60
assert _DEFAULT_TIMEOUT <= 600


@pytest.mark.parametrize("timeout", (None, 0, 1.5, 3600, 3600 * 24 * 7))
def test_timeout_updates_sync(timeout: float | None) -> None:
with sync_api_timeout(timeout) as previous_timeout:
assert previous_timeout == _DEFAULT_TIMEOUT
assert get_sync_api_timeout() == timeout
assert get_sync_api_timeout() == previous_timeout


@pytest.mark.lmstudio
def test_timeout_rpc_sync(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)

with Client() as client:
model = client.llm.model(EXPECTED_LLM_ID)
with sync_api_timeout(0):
assert get_sync_api_timeout() == 0
with pytest.raises(LMStudioTimeoutError):
response = model.get_info()
logging.error(f"Unexpected response: {response}")


@pytest.mark.lmstudio
def test_timeout_channel_sync(caplog: LogCap) -> None:
caplog.set_level(logging.DEBUG)

with Client() as client:
model = client.llm.model(EXPECTED_LLM_ID)
with sync_api_timeout(0):
assert get_sync_api_timeout() == 0
with pytest.raises(LMStudioTimeoutError):
response = model.respond("This will time out")
logging.error(f"Unexpected response: {response}")