Skip to content

Commit f4af8ec

Browse files
committed
Merge remote-tracking branch 'origin/main' into plugin-dev-support
2 parents 05b80b7 + f0fd18b commit f4af8ec

File tree

5 files changed

+143
-33
lines changed

5 files changed

+143
-33
lines changed

misc/open_client.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#!/usr/bin/env python
2+
"""Open a client instance for link failure testing."""
3+
import asyncio
4+
import logging
5+
import sys
6+
import time
7+
8+
from lmstudio import AsyncClient, Client
9+
10+
LINK_POLLING_INTERVAL = 1
11+
12+
async def open_client_async():
13+
"""Start async client, wait for link failure."""
14+
print("Connecting async client...")
15+
async with AsyncClient() as client:
16+
await client.list_downloaded_models()
17+
print ("Async client connected. Close LM Studio to terminate.")
18+
while True:
19+
await asyncio.sleep(LINK_POLLING_INTERVAL)
20+
await client.list_downloaded_models()
21+
22+
def open_client_sync():
23+
"""Start sync client, wait for link failure."""
24+
print("Connecting sync client...")
25+
with Client() as client:
26+
client.list_downloaded_models()
27+
print ("Sync client connected. Close LM Studio to terminate.")
28+
while True:
29+
time.sleep(LINK_POLLING_INTERVAL)
30+
client.list_downloaded_models()
31+
32+
if __name__ == "__main__":
33+
# Link polling makes debug logging excessively spammy
34+
log_level = logging.DEBUG if "--debug" in sys.argv else logging.INFO
35+
logging.basicConfig(level=log_level)
36+
if "--async" in sys.argv:
37+
asyncio.run(open_client_async())
38+
else:
39+
open_client_sync()

src/lmstudio/_ws_impl.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ async def _log_thread_execution(self) -> None:
290290
try:
291291
# Run the event loop until termination is requested
292292
await never_set.wait()
293+
except asyncio.CancelledError:
294+
raise
293295
except BaseException:
294296
err_msg = "Terminating websocket thread due to exception"
295297
self._logger.debug(err_msg, exc_info=True)
@@ -309,7 +311,7 @@ def __init__(
309311
task_manager: AsyncTaskManager,
310312
ws_url: str,
311313
auth_details: DictObject,
312-
enqueue_message: Callable[[DictObject], bool],
314+
enqueue_message: Callable[[DictObject | None], Awaitable[bool]],
313315
log_context: LogEventContext | None = None,
314316
) -> None:
315317
self._auth_details = auth_details
@@ -357,14 +359,16 @@ async def _logged_ws_handler(self) -> None:
357359
self._logger.info("Websocket handling task started")
358360
try:
359361
await self._handle_ws()
362+
except asyncio.CancelledError:
363+
raise
360364
except BaseException:
361365
err_msg = "Terminating websocket task due to exception"
362366
self._logger.debug(err_msg, exc_info=True)
363367
finally:
364368
# Ensure the foreground thread is unblocked even if the
365369
# background async task errors out completely
366370
self._connection_attempted.set()
367-
self._logger.info("Websocket task terminated")
371+
self._logger.info("Websocket task terminated")
368372

369373
async def _handle_ws(self) -> None:
370374
assert self._task_manager.check_running_in_task_loop()
@@ -396,12 +400,19 @@ def _clear_task_state() -> None:
396400
await self._receive_messages()
397401
finally:
398402
self._logger.info("Websocket demultiplexing task terminated.")
403+
# Notify foreground thread of background thread termination
404+
# (this covers termination due to link failure)
405+
await self._enqueue_message(None)
399406
dc_timeout = self.WS_DISCONNECT_TIMEOUT
400407
with move_on_after(dc_timeout, shield=True) as cancel_scope:
401408
# Workaround an anyio/httpx-ws issue with task cancellation:
402409
# https://github.com/frankie567/httpx-ws/issues/107
403410
self._ws = None
404-
await ws.close()
411+
try:
412+
await ws.close()
413+
except Exception:
414+
# Closing may fail if the link is already down
415+
pass
405416
if cancel_scope.cancelled_caught:
406417
self._logger.warn(
407418
f"Failed to close websocket in {dc_timeout} seconds."
@@ -413,7 +424,9 @@ async def send_json(self, message: DictObject) -> None:
413424
# This is only called if the websocket has been created
414425
assert self._task_manager.check_running_in_task_loop()
415426
ws = self._ws
416-
assert ws is not None
427+
if ws is None:
428+
# Assume app is shutting down and the owning task has already been cancelled
429+
return
417430
try:
418431
await ws.send_json(message)
419432
except Exception as exc:
@@ -430,7 +443,9 @@ async def _receive_json(self) -> Any:
430443
# This is only called if the websocket has been created
431444
assert self._task_manager.check_running_in_task_loop()
432445
ws = self._ws
433-
assert ws is not None
446+
if ws is None:
447+
# Assume app is shutting down and the owning task has already been cancelled
448+
return
434449
try:
435450
return await ws.receive_json()
436451
except Exception as exc:
@@ -443,7 +458,9 @@ async def _authenticate(self) -> bool:
443458
# This is only called if the websocket has been created
444459
assert self._task_manager.check_running_in_task_loop()
445460
ws = self._ws
446-
assert ws is not None
461+
if ws is None:
462+
# Assume app is shutting down and the owning task has already been cancelled
463+
return False
447464
auth_message = self._auth_details
448465
await self.send_json(auth_message)
449466
auth_result = await self._receive_json()
@@ -461,11 +478,11 @@ async def _process_next_message(self) -> bool:
461478
# This is only called if the websocket has been created
462479
assert self._task_manager.check_running_in_task_loop()
463480
ws = self._ws
464-
assert ws is not None
481+
if ws is None:
482+
# Assume app is shutting down and the owning task has already been cancelled
483+
return False
465484
message = await ws.receive_json()
466-
# Enqueueing messages may be a blocking call
467-
# TODO: Require it to return an Awaitable, move to_thread call to the sync bridge
468-
return await asyncio.to_thread(self._enqueue_message, message)
485+
return await self._enqueue_message(message)
469486

470487
async def _receive_messages(self) -> None:
471488
"""Process received messages until task is cancelled."""
@@ -475,7 +492,7 @@ async def _receive_messages(self) -> None:
475492
except (LMStudioWebsocketError, HTTPXWSException):
476493
if self._ws is not None and not self._ws_disconnected.is_set():
477494
# Websocket failed unexpectedly (rather than due to client shutdown)
478-
self._logger.exception("Websocket failed, terminating session.")
495+
self._logger.error("Websocket failed, terminating session.")
479496
break
480497

481498

@@ -485,11 +502,14 @@ def __init__(
485502
ws_thread: AsyncWebsocketThread,
486503
ws_url: str,
487504
auth_details: DictObject,
488-
enqueue_message: Callable[[DictObject], bool],
505+
enqueue_message: Callable[[DictObject | None], bool],
489506
log_context: LogEventContext,
490507
) -> None:
508+
async def enqueue_async(message: DictObject | None) -> bool:
509+
return await asyncio.to_thread(enqueue_message, message)
510+
491511
self._ws_handler = AsyncWebsocketHandler(
492-
ws_thread.task_manager, ws_url, auth_details, enqueue_message, log_context
512+
ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context
493513
)
494514

495515
def connect(self) -> bool:

src/lmstudio/async_api.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Async I/O protocol implementation for the LM Studio remote access API."""
22

33
import asyncio
4-
import asyncio.queues
54
import warnings
65

76
from abc import abstractmethod
@@ -28,6 +27,8 @@
2827
TypeIs,
2928
)
3029

30+
from anyio import create_task_group
31+
from anyio.abc import TaskGroup
3132
from httpx import RequestError, HTTPStatusError
3233
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
3334

@@ -168,7 +169,10 @@ async def rx_stream(
168169
# Avoid emitting tracebacks that delve into supporting libraries
169170
# (we can't easily suppress the SDK's own frames for iterators)
170171
message = await self._rx_queue.get()
171-
contents = self._api_channel.handle_rx_message(message)
172+
if message is None:
173+
contents = None
174+
else:
175+
contents = self._api_channel.handle_rx_message(message)
172176
if contents is None:
173177
self._is_finished = True
174178
break
@@ -209,6 +213,8 @@ def get_rpc_message(
209213
async def receive_result(self) -> Any:
210214
"""Receive call response on the receive queue."""
211215
message = await self._rx_queue.get()
216+
if message is None:
217+
return None
212218
return self._rpc.handle_rx_message(message)
213219

214220

@@ -225,8 +231,10 @@ def __init__(
225231
) -> None:
226232
"""Initialize asynchronous websocket client."""
227233
super().__init__(ws_url, auth_details, log_context)
228-
self._resource_manager = AsyncExitStack()
234+
self._resource_manager = rm = AsyncExitStack()
235+
rm.push_async_callback(self._notify_client_termination)
229236
self._rx_task: asyncio.Task[None] | None = None
237+
self._terminate = asyncio.Event()
230238

231239
@property
232240
def _httpx_ws(self) -> AsyncWebSocketSession | None:
@@ -246,7 +254,9 @@ async def __aexit__(self, *args: Any) -> None:
246254
async def _send_json(self, message: DictObject) -> None:
247255
# Callers are expected to call `_ensure_connected` before this method
248256
ws = self._ws
249-
assert ws is not None
257+
if ws is None:
258+
# Assume app is shutting down and the owning task has already been cancelled
259+
return
250260
try:
251261
await ws.send_json(message)
252262
except Exception as exc:
@@ -258,7 +268,9 @@ async def _send_json(self, message: DictObject) -> None:
258268
async def _receive_json(self) -> Any:
259269
# Callers are expected to call `_ensure_connected` before this method
260270
ws = self._ws
261-
assert ws is not None
271+
if ws is None:
272+
# Assume app is shutting down and the owning task has already been cancelled
273+
return
262274
try:
263275
return await ws.receive_json()
264276
except Exception as exc:
@@ -296,7 +308,7 @@ async def connect(self) -> Self:
296308
self._rx_task = rx_task = asyncio.create_task(self._receive_messages())
297309

298310
async def _terminate_rx_task() -> None:
299-
rx_task.cancel()
311+
self._terminate.set()
300312
try:
301313
await rx_task
302314
except asyncio.CancelledError:
@@ -310,19 +322,34 @@ async def disconnect(self) -> None:
310322
"""Drop the LM Studio API connection."""
311323
self._ws = None
312324
self._rx_task = None
313-
await self._notify_client_termination()
325+
self._terminate.set()
314326
await self._resource_manager.aclose()
315327
self._logger.info(f"Websocket session disconnected ({self._ws_url})")
316328

317329
aclose = disconnect
318330

331+
async def _cancel_on_termination(self, tg: TaskGroup) -> None:
332+
await self._terminate.wait()
333+
tg.cancel_scope.cancel()
334+
319335
async def _process_next_message(self) -> bool:
320336
"""Process the next message received on the websocket.
321337
322338
Returns True if a message queue was updated.
323339
"""
324340
self._ensure_connected("receive messages")
325-
message = await self._receive_json()
341+
async with create_task_group() as tg:
342+
tg.start_soon(self._cancel_on_termination, tg)
343+
try:
344+
message = await self._receive_json()
345+
except (LMStudioWebsocketError, HTTPXWSException):
346+
if self._ws is not None and not self._terminate.is_set():
347+
# Websocket failed unexpectedly (rather than due to client shutdown)
348+
self._logger.error("Websocket failed, terminating session.")
349+
self._terminate.set()
350+
tg.cancel_scope.cancel()
351+
if self._terminate.is_set():
352+
return (await self._notify_client_termination()) > 0
326353
rx_queue = self._mux.map_rx_message(message)
327354
if rx_queue is None:
328355
return False
@@ -331,18 +358,20 @@ async def _process_next_message(self) -> bool:
331358

332359
async def _receive_messages(self) -> None:
333360
"""Process received messages until connection is terminated."""
334-
while True:
335-
try:
336-
await self._process_next_message()
337-
except (LMStudioWebsocketError, HTTPXWSException):
338-
self._logger.exception("Websocket failed, terminating session.")
339-
await self.disconnect()
340-
break
361+
while not self._terminate.is_set():
362+
await self._process_next_message()
341363

342-
async def _notify_client_termination(self) -> None:
364+
async def _notify_client_termination(self) -> int:
343365
"""Send None to all clients with open receive queues."""
366+
num_clients = 0
344367
for rx_queue in self._mux.all_queues():
345368
await rx_queue.put(None)
369+
num_clients += 1
370+
self._logger.info(
371+
f"Notified {num_clients} clients of websocket termination",
372+
num_clients=num_clients,
373+
)
374+
return num_clients
346375

347376
async def _connect_to_endpoint(self, channel: AsyncChannel[Any]) -> None:
348377
"""Connect channel to specified endpoint."""
@@ -367,6 +396,9 @@ async def open_channel(
367396
self._logger.event_context,
368397
)
369398
await self._connect_to_endpoint(channel)
399+
if self._terminate.is_set():
400+
# Link has been terminated, ensure client gets a response
401+
await rx_queue.put(None)
370402
yield channel
371403

372404
async def _send_call(
@@ -401,6 +433,9 @@ async def remote_call(
401433
call_id, rx_queue, self._logger.event_context, notice_prefix
402434
)
403435
await self._send_call(rpc, endpoint, params)
436+
if self._terminate.is_set():
437+
# Link has been terminated, ensure client gets a response
438+
await rx_queue.put(None)
404439
return await rpc.receive_result()
405440

406441

src/lmstudio/json_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,10 @@ def _format_server_error(details: SerializedLMSExtendedError) -> str:
352352
lines.extend(_get_data_lines(details.error_data, " "))
353353
if details.cause is not None:
354354
lines.extend(("", " Reported cause:"))
355-
lines.extend(f" {details.cause}")
355+
lines.append(f" {details.cause}")
356356
if details.suggestion is not None:
357357
lines.extend(("", " Suggested potential remedy:"))
358-
lines.extend(f" {details.suggestion}")
358+
lines.append(f" {details.suggestion}")
359359
# Only use the multi-line format if at least one
360360
# of the extended error fields is populated
361361
if lines:

0 commit comments

Comments
 (0)