Commit 5ec24d1

WIP: Add default timeouts to sync API
1 parent e2066fe commit 5ec24d1

File tree: 6 files changed (+135, −12 lines)

src/lmstudio/_ws_impl.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -464,9 +464,15 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
         """Run given coroutine in the event loop and wait for the result."""
         return self._task_manager.run_coroutine_threadsafe(coro).result()
 
-    def rx_queue_get_threadsafe(self, rx_queue: asyncio.Queue[Any]) -> Any:
+    def rx_queue_get_threadsafe(
+        self, rx_queue: asyncio.Queue[Any], timeout: float | None
+    ) -> Any:
         future = self._task_manager.run_coroutine_threadsafe(rx_queue.get())
-        return future.result()
+        try:
+            return future.result(timeout)
+        except TimeoutError:
+            future.cancel()
+            raise
 
     def rx_queue_put_threadsafe(
         self, rx_queue: asyncio.Queue[Any], message: Any
@@ -562,7 +568,7 @@ def disconnect(self) -> None:
     def send_json(self, message: DictObject) -> None:
         self._ws_handler.send_json_threadsafe(message)
 
-    def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[], Any]]:
+    def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[float | None], Any]]:
         rx_queue: asyncio.Queue[Any] = asyncio.Queue()
         return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue)
```
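
The change above relies on the standard `concurrent.futures` pattern: `Future.result(timeout)` raises the builtin `TimeoutError` (on Python 3.11+, where `concurrent.futures.TimeoutError` is an alias of it), and the pending `rx_queue.get()` is cancelled so it cannot swallow a later message. A minimal standalone sketch of that pattern follows; the `fetch_with_timeout` helper and the background-loop wiring are illustrative only, not part of the SDK:

```python
import asyncio
import threading


def fetch_with_timeout(
    loop: asyncio.AbstractEventLoop,
    rx_queue: asyncio.Queue[str],
    timeout: float | None,
) -> str:
    """Block the calling thread until the queue yields an item or the timeout expires."""
    future = asyncio.run_coroutine_threadsafe(rx_queue.get(), loop)
    try:
        return future.result(timeout)
    except TimeoutError:
        # Cancel the pending queue.get() so it can't consume a later message
        future.cancel()
        raise


# Illustrative wiring: event loop in a background thread, timeout in the main thread
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()
rx_queue: asyncio.Queue[str] = asyncio.Queue()
try:
    fetch_with_timeout(loop, rx_queue, timeout=0.1)
except TimeoutError:
    print("No message within 0.1 seconds")
loop.call_soon_threadsafe(loop.stop)
```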

src/lmstudio/json_api.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -153,6 +153,7 @@
     "LMStudioPredictionError",
     "LMStudioPresetNotFoundError",
     "LMStudioServerError",
+    "LMStudioTimeoutError",
     "LMStudioUnknownMessageWarning",
     "LMStudioWebsocketError",
     "ModelInfo",
@@ -424,9 +425,14 @@ class LMStudioCancelledError(LMStudioClientError):
     """Requested operation was cancelled via the SDK client session."""
 
 
+@sdk_public_type
+class LMStudioTimeoutError(LMStudioError, TimeoutError):
+    """Client failed to receive a message from the server in the expected time."""
+
+
 @sdk_public_type
 class LMStudioWebsocketError(LMStudioClientError):
-    """Client websocket sessiqqon has terminated (or was never opened)."""
+    """Client websocket session has terminated (or was never opened)."""
 
 
 # dataclass vs LMStudioStruct:
```
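
Deriving the new exception from both `LMStudioError` and the builtin `TimeoutError` means callers can catch it either as an SDK error or with a plain `except TimeoutError:` clause. A small illustrative check (not part of the diff):

```python
from lmstudio import LMStudioTimeoutError

# LMStudioTimeoutError is also a builtin TimeoutError, so generic timeout
# handling written without the SDK in mind still catches it.
assert issubclass(LMStudioTimeoutError, TimeoutError)

try:
    raise LMStudioTimeoutError()
except TimeoutError:
    print("sync API timeout caught via the builtin exception")
```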

src/lmstudio/plugin/hooks/prompt_preprocessor.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -215,6 +215,8 @@ async def process_requests(
         """Create plugin channel and wait for server requests."""
         logger = self._logger
         endpoint = PromptPreprocessingEndpoint()
+        # Async API expects timeouts to be handled via task groups,
+        # so there's no default timeout to override when creating the channel
         async with session._create_channel(endpoint) as channel:
             notify_ready()
             logger.info("Opened channel to receive prompt preprocessing requests...")
```
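
As the new comment notes, the async API leaves timeouts to the caller's own task-group style constructs rather than a baked-in default. A minimal sketch of that style using `anyio.fail_after`; the `slow_server` coroutine is a stand-in for an awaitable SDK call, not part of lmstudio:

```python
import anyio


async def slow_server() -> str:
    # Stand-in for an awaitable operation that never replies in time
    await anyio.sleep(10)
    return "too late"


async def main() -> None:
    try:
        # Caller-supplied deadline, equivalent in spirit to the sync API's default timeout
        with anyio.fail_after(0.1):
            await slow_server()
    except TimeoutError:
        print("timed out, as the sync API would via LMStudioTimeoutError")


anyio.run(main)
```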

src/lmstudio/sync_api.py

Lines changed: 53 additions & 6 deletions
```diff
@@ -11,7 +11,7 @@
     contextmanager,
     ExitStack,
 )
-from types import TracebackType
+from types import EllipsisType, TracebackType
 from typing import (
     Any,
     ContextManager,
@@ -80,6 +80,7 @@
     LMStudioCancelledError,
     LMStudioClientError,
     LMStudioPredictionError,
+    LMStudioTimeoutError,
     LMStudioWebsocket,
     LoadModelEndpoint,
     ModelDownloadOptionBase,
@@ -141,15 +142,35 @@
     "PredictionStream",
     "configure_default_client",
     "get_default_client",
+    "get_sync_api_timeout",
     "embedding_model",
     "list_downloaded_models",
    "list_loaded_models",
     "llm",
     "prepare_image",
+    "set_sync_api_timeout",
 ]
 
+#
+_DEFAULT_TIMEOUT: float | None = 60.0
+
+
+def get_sync_api_timeout() -> float | None:
+    """Return the current default sync API timeout when waiting for server messages."""
+    return _DEFAULT_TIMEOUT
+
+
+def set_sync_api_timeout(timeout: float | None) -> None:
+    """Set the default sync API timeout when waiting for server messages."""
+    global _DEFAULT_TIMEOUT
+    if timeout is not None:
+        timeout = float(timeout)
+    _DEFAULT_TIMEOUT = timeout
+
 
 T = TypeVar("T")
+CallWithTimeout: TypeAlias = Callable[[float | None], Any]
+TimeoutOption: TypeAlias = float | None | EllipsisType
 
 
 class SyncChannel(Generic[T]):
@@ -158,16 +179,18 @@ class SyncChannel(Generic[T]):
     def __init__(
         self,
         channel_id: int,
-        get_message: Callable[[], Any],
+        get_message: CallWithTimeout,
         endpoint: ChannelEndpoint[T, Any, Any],
         send_json: Callable[[DictObject], None],
         log_context: LogEventContext,
+        timeout: TimeoutOption = ...,
     ) -> None:
         """Initialize synchronous websocket streaming channel."""
         self._is_finished = False
         self._get_message = get_message
-        self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
         self._send_json = send_json
+        self._timeout = timeout
+        self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
 
     def get_creation_message(self) -> DictObject:
         """Get the message to send to create this channel."""
@@ -185,6 +208,14 @@ def cancel(self) -> None:
         cancel_message = self._api_channel.get_cancel_message()
         self._send_json(cancel_message)
 
+    @property
+    def timeout(self) -> float | None:
+        """Permitted time between received messages for this channel."""
+        timeout = self._timeout
+        if timeout is ...:
+            return _DEFAULT_TIMEOUT
+        return timeout
+
     def rx_stream(
         self,
     ) -> Iterator[DictObject | None]:
@@ -193,7 +224,10 @@ def rx_stream(
         with sdk_public_api():
             # Avoid emitting tracebacks that delve into supporting libraries
             # (we can't easily suppress the SDK's own frames for iterators)
-            message = self._get_message()
+            try:
+                message = self._get_message(self.timeout)
+            except TimeoutError:
+                raise LMStudioTimeoutError from None
             contents = self._api_channel.handle_rx_message(message)
             if contents is None:
                 self._is_finished = True
@@ -216,12 +250,14 @@ class SyncRemoteCall:
     def __init__(
         self,
         call_id: int,
-        get_message: Callable[[], Any],
+        get_message: CallWithTimeout,
         log_context: LogEventContext,
         notice_prefix: str = "RPC",
+        timeout: TimeoutOption = ...,
     ) -> None:
         """Initialize synchronous remote procedure call."""
         self._get_message = get_message
+        self._timeout = timeout
         self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix)
         self._logger = logger = new_logger(type(self).__name__)
         logger.update_context(log_context, call_id=call_id)
@@ -232,9 +268,20 @@ def get_rpc_message(
         """Get the message to send to initiate this remote procedure call."""
         return self._rpc.get_rpc_message(endpoint, params)
 
+    @property
+    def timeout(self) -> float | None:
+        """Permitted time to wait for a reply to this call."""
+        timeout = self._timeout
+        if timeout is ...:
+            return _DEFAULT_TIMEOUT
+        return timeout
+
     def receive_result(self) -> Any:
         """Receive call response on the receive queue."""
-        message = self._get_message()
+        try:
+            message = self._get_message(self.timeout)
+        except TimeoutError:
+            raise LMStudioTimeoutError from None
         return self._rpc.handle_rx_message(message)
```
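
Taken together, these additions give sync callers a process-wide default timeout (60 seconds) that can be inspected and changed, plus an ellipsis sentinel (`timeout: TimeoutOption = ...`) so individual channels and RPC calls can still be created with an explicit value instead of the default. A hedged usage sketch follows; the model identifier is a placeholder, and `Client()`, `llm.model()`, and `get_info()` mirror the calls used in the tests added below:

```python
from lmstudio import (
    Client,
    LMStudioTimeoutError,
    get_sync_api_timeout,
    set_sync_api_timeout,
)

print(get_sync_api_timeout())  # 60.0 (the shipped default)

set_sync_api_timeout(5.0)  # tighter default for all subsequent sync calls
try:
    with Client() as client:
        client.llm.model("some-model-id").get_info()
except LMStudioTimeoutError:
    print("No reply from the server within 5 seconds")
finally:
    set_sync_api_timeout(60.0)  # restore the shipped default
```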

tests/README.md

Lines changed: 3 additions & 2 deletions
```diff
@@ -61,8 +61,9 @@ $ tox -m test -- -k test_download_model
 
 Test files should follow the following naming conventions:
 
-- `test_XYZ.py`: mix of async and sync test cases for `XYZ` that aren't amenable to automated conversion
-  (for whatever reason; for example, `anyio.fail_after` has no sync counterpart)
+- `test_XYZ.py`: either a mix of async and sync test cases for `XYZ` that aren't amenable to
+  automated conversion (for whatever reason; for example, `anyio.fail_after` has no sync counterpart),
+  or else test cases for a behaviour which currently only exists in one API or the other
 - `async/test_XYZ_async.py` : async test cases for `XYZ` that are amenable to automated sync conversion;
   all test method names should also end in `_async`.
 - `sync/test_XYZ_sync.py` : sync test cases auto-generated from `test_XYZ_async.py`
```

tests/test_timeouts.py

Lines changed: 61 additions & 0 deletions
```diff
@@ -0,0 +1,61 @@
+"""Test client timeout behaviour."""
+
+import logging
+
+from contextlib import contextmanager
+from typing import Generator
+
+import pytest
+from pytest import LogCaptureFixture as LogCap
+
+from lmstudio import (
+    Client,
+    LMStudioTimeoutError,
+    get_sync_api_timeout,
+    set_sync_api_timeout,
+)
+
+from .support import EXPECTED_LLM_ID
+
+# Sync only, as async API uses standard async timeout constructs like anyio.move_on_after
+
+
+@contextmanager
+def sync_api_timeout(timeout: float | None) -> Generator[float | None, None, None]:
+    default_timeout = get_sync_api_timeout()
+    try:
+        yield default_timeout
+    finally:
+        set_sync_api_timeout(default_timeout)
+
+
+@pytest.mark.parametrize("timeout", (None, 0, 1.5, 3600, 3600 * 24 * 7))
+def test_timeout_updates_sync(timeout: float | None) -> None:
+    with sync_api_timeout(timeout) as previous_timeout:
+        assert previous_timeout is not None
+        assert previous_timeout > 0
+        set_sync_api_timeout(timeout)
+        assert get_sync_api_timeout() == timeout
+    assert get_sync_api_timeout() == previous_timeout
+
+
+@pytest.mark.lmstudio
+def test_rpc_timeout_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+
+    with Client() as client:
+        model = client.llm.model(EXPECTED_LLM_ID)
+        with sync_api_timeout(0):
+            with pytest.raises(LMStudioTimeoutError):
+                model.get_info()
+
+
+@pytest.mark.lmstudio
+def test_channel_timeout_sync(caplog: LogCap) -> None:
+    caplog.set_level(logging.DEBUG)
+
+    with Client() as client:
+        model = client.llm.model(EXPECTED_LLM_ID)
+        with sync_api_timeout(0):
+            with pytest.raises(LMStudioTimeoutError):
+                model.respond("This will time out")
```

0 commit comments