Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions misc/open_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python
"""Open a client instance for link failure testing."""
import asyncio
import logging
import sys
import time

from lmstudio import AsyncClient, Client

LINK_POLLING_INTERVAL = 1

async def open_client_async():
"""Start async client, wait for link failure."""
print("Connecting async client...")
async with AsyncClient() as client:
await client.list_downloaded_models()
print ("Async client connected. Close LM Studio to terminate.")
while True:
await asyncio.sleep(LINK_POLLING_INTERVAL)
await client.list_downloaded_models()

def open_client_sync():
"""Start sync client, wait for link failure."""
print("Connecting sync client...")
with Client() as client:
client.list_downloaded_models()
print ("Sync client connected. Close LM Studio to terminate.")
while True:
time.sleep(LINK_POLLING_INTERVAL)
client.list_downloaded_models()

if __name__ == "__main__":
# Link polling makes debug logging excessively spammy
log_level = logging.DEBUG if "--debug" in sys.argv else logging.INFO
logging.basicConfig(level=log_level)
if "--async" in sys.argv:
asyncio.run(open_client_async())
else:
open_client_sync()
46 changes: 33 additions & 13 deletions src/lmstudio/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ async def _log_thread_execution(self) -> None:
try:
# Run the event loop until termination is requested
await never_set.wait()
except asyncio.CancelledError:
raise
except BaseException:
err_msg = "Terminating websocket thread due to exception"
self._logger.debug(err_msg, exc_info=True)
Expand All @@ -309,7 +311,7 @@ def __init__(
task_manager: AsyncTaskManager,
ws_url: str,
auth_details: DictObject,
enqueue_message: Callable[[DictObject], bool],
enqueue_message: Callable[[DictObject | None], Awaitable[bool]],
log_context: LogEventContext | None = None,
) -> None:
self._auth_details = auth_details
Expand Down Expand Up @@ -357,14 +359,16 @@ async def _logged_ws_handler(self) -> None:
self._logger.info("Websocket handling task started")
try:
await self._handle_ws()
except asyncio.CancelledError:
raise
except BaseException:
err_msg = "Terminating websocket task due to exception"
self._logger.debug(err_msg, exc_info=True)
finally:
# Ensure the foreground thread is unblocked even if the
# background async task errors out completely
self._connection_attempted.set()
self._logger.info("Websocket task terminated")
self._logger.info("Websocket task terminated")

async def _handle_ws(self) -> None:
assert self._task_manager.check_running_in_task_loop()
Expand Down Expand Up @@ -396,12 +400,19 @@ def _clear_task_state() -> None:
await self._receive_messages()
finally:
self._logger.info("Websocket demultiplexing task terminated.")
# Notify foreground thread of background thread termination
# (this covers termination due to link failure)
await self._enqueue_message(None)
dc_timeout = self.WS_DISCONNECT_TIMEOUT
with move_on_after(dc_timeout, shield=True) as cancel_scope:
# Workaround an anyio/httpx-ws issue with task cancellation:
# https://github.com/frankie567/httpx-ws/issues/107
self._ws = None
await ws.close()
try:
await ws.close()
except Exception:
# Closing may fail if the link is already down
pass
if cancel_scope.cancelled_caught:
self._logger.warn(
f"Failed to close websocket in {dc_timeout} seconds."
Expand All @@ -413,7 +424,9 @@ async def send_json(self, message: DictObject) -> None:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
ws = self._ws
assert ws is not None
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
return
try:
await ws.send_json(message)
except Exception as exc:
Expand All @@ -430,7 +443,9 @@ async def _receive_json(self) -> Any:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
ws = self._ws
assert ws is not None
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
return
try:
return await ws.receive_json()
except Exception as exc:
Expand All @@ -443,7 +458,9 @@ async def _authenticate(self) -> bool:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
ws = self._ws
assert ws is not None
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
return False
auth_message = self._auth_details
await self.send_json(auth_message)
auth_result = await self._receive_json()
Expand All @@ -461,11 +478,11 @@ async def _process_next_message(self) -> bool:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
ws = self._ws
assert ws is not None
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
return False
message = await ws.receive_json()
# Enqueueing messages may be a blocking call
# TODO: Require it to return an Awaitable, move to_thread call to the sync bridge
return await asyncio.to_thread(self._enqueue_message, message)
return await self._enqueue_message(message)

async def _receive_messages(self) -> None:
"""Process received messages until task is cancelled."""
Expand All @@ -475,7 +492,7 @@ async def _receive_messages(self) -> None:
except (LMStudioWebsocketError, HTTPXWSException):
if self._ws is not None and not self._ws_disconnected.is_set():
# Websocket failed unexpectedly (rather than due to client shutdown)
self._logger.exception("Websocket failed, terminating session.")
self._logger.error("Websocket failed, terminating session.")
break


Expand All @@ -485,11 +502,14 @@ def __init__(
ws_thread: AsyncWebsocketThread,
ws_url: str,
auth_details: DictObject,
enqueue_message: Callable[[DictObject], bool],
enqueue_message: Callable[[DictObject | None], bool],
log_context: LogEventContext,
) -> None:
async def enqueue_async(message: DictObject | None) -> bool:
return await asyncio.to_thread(enqueue_message, message)

self._ws_handler = AsyncWebsocketHandler(
ws_thread.task_manager, ws_url, auth_details, enqueue_message, log_context
ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context
)

def connect(self) -> bool:
Expand Down
67 changes: 51 additions & 16 deletions src/lmstudio/async_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Async I/O protocol implementation for the LM Studio remote access API."""

import asyncio
import asyncio.queues
import warnings

from abc import abstractmethod
Expand All @@ -28,6 +27,8 @@
TypeIs,
)

from anyio import create_task_group
from anyio.abc import TaskGroup
from httpx import RequestError, HTTPStatusError
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException

Expand Down Expand Up @@ -163,7 +164,10 @@ async def rx_stream(
# Avoid emitting tracebacks that delve into supporting libraries
# (we can't easily suppress the SDK's own frames for iterators)
message = await self._rx_queue.get()
contents = self._api_channel.handle_rx_message(message)
if message is None:
contents = None
else:
contents = self._api_channel.handle_rx_message(message)
if contents is None:
self._is_finished = True
break
Expand Down Expand Up @@ -204,6 +208,8 @@ def get_rpc_message(
async def receive_result(self) -> Any:
"""Receive call response on the receive queue."""
message = await self._rx_queue.get()
if message is None:
return None
return self._rpc.handle_rx_message(message)


Expand All @@ -220,8 +226,10 @@ def __init__(
) -> None:
"""Initialize asynchronous websocket client."""
super().__init__(ws_url, auth_details, log_context)
self._resource_manager = AsyncExitStack()
self._resource_manager = rm = AsyncExitStack()
rm.push_async_callback(self._notify_client_termination)
self._rx_task: asyncio.Task[None] | None = None
self._terminate = asyncio.Event()

@property
def _httpx_ws(self) -> AsyncWebSocketSession | None:
Expand All @@ -241,7 +249,9 @@ async def __aexit__(self, *args: Any) -> None:
async def _send_json(self, message: DictObject) -> None:
# Callers are expected to call `_ensure_connected` before this method
ws = self._ws
assert ws is not None
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
return
try:
await ws.send_json(message)
except Exception as exc:
Expand All @@ -253,7 +263,9 @@ async def _send_json(self, message: DictObject) -> None:
async def _receive_json(self) -> Any:
# Callers are expected to call `_ensure_connected` before this method
ws = self._ws
assert ws is not None
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
return
try:
return await ws.receive_json()
except Exception as exc:
Expand Down Expand Up @@ -291,7 +303,7 @@ async def connect(self) -> Self:
self._rx_task = rx_task = asyncio.create_task(self._receive_messages())

async def _terminate_rx_task() -> None:
rx_task.cancel()
self._terminate.set()
try:
await rx_task
except asyncio.CancelledError:
Expand All @@ -305,19 +317,34 @@ async def disconnect(self) -> None:
"""Drop the LM Studio API connection."""
self._ws = None
self._rx_task = None
await self._notify_client_termination()
self._terminate.set()
await self._resource_manager.aclose()
self._logger.info(f"Websocket session disconnected ({self._ws_url})")

aclose = disconnect

async def _cancel_on_termination(self, tg: TaskGroup) -> None:
await self._terminate.wait()
tg.cancel_scope.cancel()

async def _process_next_message(self) -> bool:
"""Process the next message received on the websocket.

Returns True if a message queue was updated.
"""
self._ensure_connected("receive messages")
message = await self._receive_json()
async with create_task_group() as tg:
tg.start_soon(self._cancel_on_termination, tg)
try:
message = await self._receive_json()
except (LMStudioWebsocketError, HTTPXWSException):
if self._ws is not None and not self._terminate.is_set():
# Websocket failed unexpectedly (rather than due to client shutdown)
self._logger.error("Websocket failed, terminating session.")
self._terminate.set()
tg.cancel_scope.cancel()
if self._terminate.is_set():
return (await self._notify_client_termination()) > 0
rx_queue = self._mux.map_rx_message(message)
if rx_queue is None:
return False
Expand All @@ -326,18 +353,20 @@ async def _process_next_message(self) -> bool:

async def _receive_messages(self) -> None:
"""Process received messages until connection is terminated."""
while True:
try:
await self._process_next_message()
except (LMStudioWebsocketError, HTTPXWSException):
self._logger.exception("Websocket failed, terminating session.")
await self.disconnect()
break
while not self._terminate.is_set():
await self._process_next_message()

async def _notify_client_termination(self) -> None:
async def _notify_client_termination(self) -> int:
"""Send None to all clients with open receive queues."""
num_clients = 0
for rx_queue in self._mux.all_queues():
await rx_queue.put(None)
num_clients += 1
self._logger.info(
f"Notified {num_clients} clients of websocket termination",
num_clients=num_clients,
)
return num_clients

async def _connect_to_endpoint(self, channel: AsyncChannel[Any]) -> None:
"""Connect channel to specified endpoint."""
Expand All @@ -362,6 +391,9 @@ async def open_channel(
self._logger.event_context,
)
await self._connect_to_endpoint(channel)
if self._terminate.is_set():
# Link has been terminated, ensure client gets a response
await rx_queue.put(None)
yield channel

async def _send_call(
Expand Down Expand Up @@ -396,6 +428,9 @@ async def remote_call(
call_id, rx_queue, self._logger.event_context, notice_prefix
)
await self._send_call(rpc, endpoint, params)
if self._terminate.is_set():
# Link has been terminated, ensure client gets a response
await rx_queue.put(None)
return await rpc.receive_result()


Expand Down
4 changes: 2 additions & 2 deletions src/lmstudio/json_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,10 @@ def _format_server_error(details: SerializedLMSExtendedError) -> str:
lines.extend(_get_data_lines(details.error_data, " "))
if details.cause is not None:
lines.extend(("", " Reported cause:"))
lines.extend(f" {details.cause}")
lines.append(f" {details.cause}")
if details.suggestion is not None:
lines.extend(("", " Suggested potential remedy:"))
lines.extend(f" {details.suggestion}")
lines.append(f" {details.suggestion}")
# Only use the multi-line format if at least one
# of the extended error fields is populated
if lines:
Expand Down
Loading