28 changes: 28 additions & 0 deletions misc/atexit-example.py
@@ -0,0 +1,28 @@
#!/usr/bin/env python
"""Check accessing the default client from an atexit hook."""
import atexit

# Intentionally named with a hyphen to ensure this can't be imported

import lmstudio as lms

# TODO: Turn this into a CI test case (perhaps via subprocess invocation?)

# Prior to lmstudio-python 1.5.0, the atexit hook below would hang on shutdown.
# By the time atexit hooks run, asyncio.to_thread no longer works due to all
# concurrent.futures managed thread pools (including those used by asyncio)
# being shut down before the interpreter waits for non-daemon threads to terminate.
# Since the synchronous client relied on asyncio.to_thread to deliver
# messages from the async background comms thread to the blocking
# foreground thread, the sync API didn't work in this scenario.
# In 1.5.0, sync message reception was reworked to queue messages entirely
# in the async background thread (with the sync side blocking on async queue
# reads), eliminating the blocking queue writes via asyncio.to_thread and
# allowing the client to keep working in atexit hooks.
@atexit.register
def access_default_client():
"""Ensure default client can be accessed from an atexit hook."""
print("During shutdown:", end=" ", flush=True)
print(lms.list_loaded_models())

print("Prior to shutdown:", lms.list_loaded_models())
88 changes: 80 additions & 8 deletions src/lmstudio/_ws_impl.py
@@ -14,14 +14,14 @@
import weakref

from concurrent.futures import Future as SyncFuture
from contextlib import (
AsyncExitStack,
)
from contextlib import AsyncExitStack
from functools import partial
from typing import (
Any,
Awaitable,
Coroutine,
Callable,
Iterable,
TypeVar,
)

@@ -221,6 +221,8 @@ def __init__(
# Annoyingly, we have to mark the background thread as a daemon thread to
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
# TODO: skip thread daemonization when running in a subinterpreter
# (and also disable the convenience API in subinterpreters to avoid hanging on shutdown)
super().__init__(name=name, daemon=True)
weakref.finalize(self, self.terminate)

@@ -278,6 +280,25 @@ def call_in_background(self, func: Callable[[], Any]) -> None:
self._task_manager.call_soon_threadsafe(func)


# By default, the weakref finalization atexit hook is registered lazily.
# This can lead to shutdown sequencing issues if SDK users attempt to access
# client instances (such as the default sync client) from atexit hooks
# registered at import time (those hooks may end up running after the weakref
# finalization hook has already terminated the background threads).
# Creating this finalizer here ensures the weakref finalization hook is
# registered at import time, and hence runs *after* any such hooks
# (assuming the lmstudio SDK is imported before the hooks are registered).
def _register_weakref_atexit_hook() -> None:
class C:
pass

weakref.finalize(C(), int)


_register_weakref_atexit_hook()
del _register_weakref_atexit_hook


class AsyncWebsocketThread(BackgroundThread):
def __init__(self, log_context: LogEventContext | None = None) -> None:
super().__init__(task_target=self._log_thread_execution)
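As background for the _register_weakref_atexit_hook trick above: atexit hooks run in last-in, first-out order, so a hook registered when lmstudio is imported runs after any hooks its users register later. A standalone sketch (not SDK code) of that ordering:

# Illustration of atexit's LIFO ordering
import atexit

atexit.register(lambda: print("registered first (e.g. at lmstudio import), runs last"))
atexit.register(lambda: print("registered second (e.g. a user hook), runs first"))

# Expected output at interpreter shutdown:
#   registered second (e.g. a user hook), runs first
#   registered first (e.g. at lmstudio import), runs last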
@@ -439,6 +460,20 @@ def send_json_threadsafe(self, message: DictObject) -> None:
future = self._task_manager.run_coroutine_threadsafe(self.send_json(message))
future.result() # Block until the message is sent

def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run given coroutine in the event loop and wait for the result."""
return self._task_manager.run_coroutine_threadsafe(coro).result()

def rx_queue_get_threadsafe(self, rx_queue: asyncio.Queue[Any]) -> Any:
"""Read from the given queue in the event loop and block waiting for the result."""
future = self._task_manager.run_coroutine_threadsafe(rx_queue.get())
return future.result()

def rx_queue_put_threadsafe(
self, rx_queue: asyncio.Queue[Any], message: Any
) -> None:
"""Write to the given queue in the event loop and block until the write completes."""
future = self._task_manager.run_coroutine_threadsafe(rx_queue.put(message))
return future.result()

async def _receive_json(self) -> Any:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
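For context, rx_queue_get_threadsafe and rx_queue_put_threadsafe above follow the standard run_coroutine_threadsafe pattern: the queue operation executes inside the background event loop while the foreground thread only blocks on the returned concurrent.futures future. A standalone sketch of that pattern (independent of the SDK's task manager, and assuming Python 3.10+ where asyncio.Queue binds lazily to the running loop):

# Blocking queue access across threads (not SDK code)
import asyncio
import threading

def demo() -> None:
    loop = asyncio.new_event_loop()
    background = threading.Thread(target=loop.run_forever, daemon=True)
    background.start()

    rx_queue: asyncio.Queue[str] = asyncio.Queue()

    # Both operations run inside the background loop; the foreground thread
    # only blocks on the returned concurrent.futures.Future
    asyncio.run_coroutine_threadsafe(rx_queue.put("hello"), loop).result()
    message = asyncio.run_coroutine_threadsafe(rx_queue.get(), loop).result()
    print(message)  # -> hello

    loop.call_soon_threadsafe(loop.stop)
    background.join()

demo()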
@@ -502,15 +537,21 @@ def __init__(
ws_thread: AsyncWebsocketThread,
ws_url: str,
auth_details: DictObject,
enqueue_message: Callable[[DictObject | None], bool],
get_queue: Callable[[DictObject | None], asyncio.Queue[Any] | None],
iter_queues: Callable[[], Iterable[asyncio.Queue[Any]]],
log_context: LogEventContext,
) -> None:
async def enqueue_async(message: DictObject | None) -> bool:
return await asyncio.to_thread(enqueue_message, message)

self._get_queue = get_queue
self._iter_queues = iter_queues
self._ws_handler = AsyncWebsocketHandler(
ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context
ws_thread.task_manager,
ws_url,
auth_details,
self._enqueue_message,
log_context,
)
self._logger = logger = new_logger(type(self).__name__)
logger.update_context(log_context)

def connect(self) -> bool:
return self._ws_handler.connect_threadsafe()
@@ -521,6 +562,37 @@ def disconnect(self) -> None:
def send_json(self, message: DictObject) -> None:
self._ws_handler.send_json_threadsafe(message)

def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[], Any]]:
"""Create a receive queue and a blocking reader for use from the foreground thread."""
rx_queue: asyncio.Queue[Any] = asyncio.Queue()
return rx_queue, partial(self._ws_handler.rx_queue_get_threadsafe, rx_queue)

async def _enqueue_message(self, message: Any) -> bool:
rx_queue = self._get_queue(message)
# The queue lookup runs even for None messages: the sync client's lookup
# hook uses that call to log the link failure and drop its websocket reference
if message is None:
return await self.notify_client_termination() > 0
if rx_queue is None:
return False
await rx_queue.put(message)
return True

async def notify_client_termination(self) -> int:
"""Send None to all clients with open receive queues (from background thread)."""
num_clients = 0
for rx_queue in self._iter_queues():
await rx_queue.put(None)
num_clients += 1
self._logger.debug(
f"Notified {num_clients} clients of websocket termination",
num_clients=num_clients,
)
return num_clients

def notify_client_termination_threadsafe(self) -> int:
"""Send None to all clients with open receive queues (from foreground thread)."""
return self._ws_handler.run_background_coroutine(
self.notify_client_termination()
)

# These attributes are currently accessed directly...
@property
def _ws(self) -> AsyncWebSocketSession | None:
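The _enqueue_message / notify_client_termination pair above amounts to a small routing contract: deliver a message to the matching receive queue, report False when no client is waiting, and fan a None sentinel out to every open queue when the link goes down. A standalone sketch of that contract (illustrative only; the "channelId" key and the plain dict registry are assumptions, not the SDK's actual demultiplexing scheme):

# Routing contract sketch (not SDK code)
import asyncio
from typing import Any

async def demo() -> None:
    queues: dict[int, asyncio.Queue[Any]] = {1: asyncio.Queue(), 2: asyncio.Queue()}

    async def enqueue_message(message: dict[str, Any] | None) -> bool:
        if message is None:
            # Link terminated: wake every client with an open receive queue
            for rx_queue in queues.values():
                await rx_queue.put(None)
            return bool(queues)
        rx_queue = queues.get(message.get("channelId", -1))
        if rx_queue is None:
            return False  # No client is waiting on this channel
        await rx_queue.put(message)
        return True

    assert await enqueue_message({"channelId": 1, "data": "hi"})
    assert not await enqueue_message({"channelId": 99})
    assert await enqueue_message(None)
    assert await queues[1].get() == {"channelId": 1, "data": "hi"}
    assert await queues[1].get() is None  # Termination sentinel
    assert await queues[2].get() is None

asyncio.run(demo())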
2 changes: 2 additions & 0 deletions src/lmstudio/json_api.py
@@ -526,6 +526,8 @@ def _redact_json(data: DictObject | None) -> DictObject | None:
return redacted


# TODO: Now that even the sync API uses asyncio.Queue,
# change the multiplexing manager to no longer be generic
TQueue = TypeVar("TQueue")


69 changes: 26 additions & 43 deletions src/lmstudio/sync_api.py
@@ -1,8 +1,8 @@
"""Sync I/O protocol implementation for the LM Studio remote access API."""

import asyncio
import itertools
import time
import queue
import weakref

from abc import abstractmethod
@@ -158,14 +158,14 @@ class SyncChannel(Generic[T]):
def __init__(
self,
channel_id: int,
rx_queue: queue.Queue[Any],
get_message: Callable[[], Any],
endpoint: ChannelEndpoint[T, Any, Any],
send_json: Callable[[DictObject], None],
log_context: LogEventContext,
) -> None:
"""Initialize synchronous websocket streaming channel."""
self._is_finished = False
self._rx_queue = rx_queue
self._get_message = get_message
self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
self._send_json = send_json

@@ -193,7 +193,7 @@ def rx_stream(
with sdk_public_api():
# Avoid emitting tracebacks that delve into supporting libraries
# (we can't easily suppress the SDK's own frames for iterators)
message = self._rx_queue.get()
message = self._get_message()
contents = self._api_channel.handle_rx_message(message)
if contents is None:
self._is_finished = True
@@ -216,12 +216,12 @@ class SyncRemoteCall:
def __init__(
self,
call_id: int,
rx_queue: queue.Queue[Any],
get_message: Callable[[], Any],
log_context: LogEventContext,
notice_prefix: str = "RPC",
) -> None:
"""Initialize synchronous remote procedure call."""
self._rx_queue = rx_queue
self._get_message = get_message
self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix)
self._logger = logger = new_logger(type(self).__name__)
logger.update_context(log_context, call_id=call_id)
@@ -234,12 +234,12 @@ def get_rpc_message(

def receive_result(self) -> Any:
"""Receive call response on the receive queue."""
message = self._rx_queue.get()
message = self._get_message()
return self._rpc.handle_rx_message(message)


class SyncLMStudioWebsocket(
LMStudioWebsocket[SyncToAsyncWebsocketBridge, queue.Queue[Any]]
LMStudioWebsocket[SyncToAsyncWebsocketBridge, asyncio.Queue[Any]]
):
"""Synchronous websocket client that handles demultiplexing of reply messages."""

@@ -279,7 +279,8 @@ def connect(self) -> Self:
self._ws_thread,
self._ws_url,
self._auth_details,
self._enqueue_message,
self._get_rx_queue,
self._mux.all_queues,
self._logger.event_context,
)
if not ws.connect():
@@ -298,42 +299,26 @@ def disconnect(self) -> None:
self._ws = None
if ws is not None:
self._logger.debug(f"Disconnecting websocket session ({self._ws_url})")
self._notify_client_termination()
ws.notify_client_termination_threadsafe()
ws.disconnect()
self._logger.info(f"Websocket session disconnected ({self._ws_url})")

close = disconnect

def _enqueue_message(self, message: Any) -> bool:
if message is None:
self._logger.info(f"Websocket session failed ({self._ws_url})")
self._ws = None
return self._notify_client_termination() > 0
rx_queue = self._mux.map_rx_message(message)
if rx_queue is None:
return False
rx_queue.put(message)
return True

def _notify_client_termination(self) -> int:
"""Send None to all clients with open receive queues."""
num_clients = 0
for rx_queue in self._mux.all_queues():
rx_queue.put(None)
num_clients += 1
self._logger.debug(
f"Notified {num_clients} clients of websocket termination",
num_clients=num_clients,
)
return num_clients

def _send_json(self, message: DictObject) -> None:
# Callers are expected to call `_ensure_connected` before this method
ws = self._ws
assert ws is not None
# Background thread handles the exception conversion
ws.send_json(message)

def _get_rx_queue(self, message: Any) -> asyncio.Queue[Any] | None:
if message is None:
self._logger.info(f"Websocket session failed ({self._ws_url})")
self._ws = None
return None
return self._mux.map_rx_message(message)

def _connect_to_endpoint(self, channel: SyncChannel[Any]) -> None:
"""Connect channel to specified endpoint."""
self._ensure_connected("open channel endpoints")
@@ -347,19 +332,18 @@ def open_channel(
endpoint: ChannelEndpoint[T, Any, Any],
) -> Generator[SyncChannel[T], None, None]:
"""Open a streaming channel over the websocket."""
rx_queue: queue.Queue[Any] = queue.Queue()
ws = self._ws
assert ws is not None
rx_queue, getter = ws.new_rx_queue()
with self._mux.assign_channel_id(rx_queue) as channel_id:
channel = SyncChannel(
channel_id,
rx_queue,
getter,
endpoint,
self._send_json,
self._logger.event_context,
)
self._connect_to_endpoint(channel)
if self._ws is None:
# Link has been terminated, ensure client gets a response
rx_queue.put(None)
yield channel

def _send_call(
@@ -388,15 +372,14 @@ def remote_call(
notice_prefix: str = "RPC",
) -> Any:
"""Make a remote procedure call over the websocket."""
rx_queue: queue.Queue[Any] = queue.Queue()
ws = self._ws
assert ws is not None
rx_queue, getter = ws.new_rx_queue()
with self._mux.assign_call_id(rx_queue) as call_id:
rpc = SyncRemoteCall(
call_id, rx_queue, self._logger.event_context, notice_prefix
call_id, getter, self._logger.event_context, notice_prefix
)
self._send_call(rpc, endpoint, params)
if self._ws is None:
# Link has been terminated, ensure client gets a response
rx_queue.put(None)
return rpc.receive_result()

