Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 135 additions & 67 deletions src/lmstudio/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

# Synchronous API still uses an async websocket (just in a background thread)
from anyio import create_task_group
from anyio import create_task_group, get_cancelled_exc_class
from exceptiongroup import suppress
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException

Expand All @@ -47,43 +47,11 @@
T = TypeVar("T")


class BackgroundThread(threading.Thread):
"""Background async event loop thread."""

def __init__(
self,
task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
name: str | None = None,
) -> None:
# Accepts the same args as `threading.Thread`, *except*:
# * a `task_target` coroutine replaces the `target` function
# * No `daemon` option (always runs as a daemon)
# Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
# Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
self._task_target = task_target
self._loop_started = threading.Event()
self._terminate = asyncio.Event()
self._event_loop: asyncio.AbstractEventLoop | None = None
# Annoyingly, we have to mark the background thread as a daemon thread to
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
super().__init__(name=name, daemon=True)
weakref.finalize(self, self.terminate)

def run(self) -> None:
"""Run an async event loop in the background thread."""
# Only public to override threading.Thread.run
asyncio.run(self._run_until_terminated())

def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
"""Wait for the event loop to start from a synchronous foreground thread."""
if self._event_loop is None and not self._loop_started.is_set():
self._loop_started.wait()
return self._event_loop

async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
"""Wait for the event loop to start from an asynchronous foreground thread."""
return await asyncio.to_thread(self.wait_for_loop)
class _BackgroundTaskHandlerMixin:
# Subclasses need to handle providing these instance attributes
_event_loop: asyncio.AbstractEventLoop | None
_task_target: Callable[[], Coroutine[Any, Any, Any]] | None
_terminate: asyncio.Event

def called_in_background_loop(self) -> bool:
"""Returns true if currently running in this thread's event loop, false otherwise."""
Expand Down Expand Up @@ -123,10 +91,12 @@ async def terminate_async(self) -> bool:
"""Request termination of the event loop from an asynchronous foreground thread."""
return await asyncio.to_thread(self.terminate)

def _init_event_loop(self) -> None:
self._event_loop = asyncio.get_running_loop()

async def _run_until_terminated(self) -> None:
"""Run task in the background thread until termination is requested."""
self._event_loop = asyncio.get_running_loop()
self._loop_started.set()
self._init_event_loop()
# Use anyio and exceptiongroup to handle the lack of native task
# and exception groups prior to Python 3.11
raise_on_termination, terminated_exc = self._raise_on_termination()
Expand Down Expand Up @@ -163,6 +133,49 @@ def schedule_background_task(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T
assert loop is not None
return asyncio.run_coroutine_threadsafe(coro, loop)


class BackgroundThread(_BackgroundTaskHandlerMixin, threading.Thread):
"""Background async event loop thread."""

def __init__(
self,
task_target: Callable[[], Coroutine[Any, Any, Any]] | None = None,
name: str | None = None,
) -> None:
# Accepts the same args as `threading.Thread`, *except*:
# * a `task_target` coroutine replaces the `target` function
# * No `daemon` option (always runs as a daemon)
# Variant: accept `debug` and `loop_factory` options to forward to `asyncio.run`
# Alternative: accept a `task_runner` callback, defaulting to `asyncio.run`
self._task_target = task_target
self._loop_started = threading.Event()
self._terminate = asyncio.Event()
self._event_loop: asyncio.AbstractEventLoop | None = None
# Annoyingly, we have to mark the background thread as a daemon thread to
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
super().__init__(name=name, daemon=True)
weakref.finalize(self, self.terminate)

def run(self) -> None:
"""Run an async event loop in the background thread."""
# Only public to override threading.Thread.run
asyncio.run(self._run_until_terminated())

def _init_event_loop(self) -> None:
super()._init_event_loop()
self._loop_started.set()

def wait_for_loop(self) -> asyncio.AbstractEventLoop | None:
"""Wait for the event loop to start from a synchronous foreground thread."""
if self._event_loop is None and not self._loop_started.is_set():
self._loop_started.wait()
return self._event_loop

async def wait_for_loop_async(self) -> asyncio.AbstractEventLoop | None:
"""Wait for the event loop to start from an asynchronous foreground thread."""
return await asyncio.to_thread(self.wait_for_loop)

def run_background_task(self, coro: Coroutine[Any, Any, T]) -> T:
"""Run given coroutine in the background event loop and wait for the result."""
return self.schedule_background_task(coro).result()
Expand All @@ -178,62 +191,83 @@ def call_in_background(self, callback: Callable[[], Any]) -> None:
loop.call_soon_threadsafe(callback)


# TODO: Allow multiple websockets to share a single event loop thread
# (reduces thread usage in sync API, blocker for async API migration)
class AsyncWebsocketThread(BackgroundThread):
def __init__(self, log_context: LogEventContext | None = None) -> None:
super().__init__(task_target=self._run_main_task)
self._logger = logger = get_logger(type(self).__name__)
logger.update_context(log_context, thread_id=self.name)

async def _run_main_task(self) -> None:
self._logger.info("Websocket handling thread started")
never_set = asyncio.Event()
try:
# Run the event loop until termination is requested
await never_set.wait()
except get_cancelled_exc_class():
pass
except BaseException:
err_msg = "Terminating websocket thread due to exception"
self._logger.debug(err_msg, exc_info=True)
self._logger.info("Websocket thread terminated")


# TODO: Improve code sharing between AsyncWebsocketHandler and
# the async-native AsyncLMStudioWebsocket implementation
class AsyncWebsocketHandler(_BackgroundTaskHandlerMixin):
"""Async task handler for a single websocket connection."""

def __init__(
self,
ws_thread: AsyncWebsocketThread,
ws_url: str,
auth_details: DictObject,
enqueue_message: Callable[[DictObject], bool],
log_context: LogEventContext,
log_context: LogEventContext | None = None,
) -> None:
self._auth_details = auth_details
self._connection_attempted = asyncio.Event()
self._connection_failure: Exception | None = None
self._auth_failure: Any | None = None
self._terminate = asyncio.Event()
self._ws_thread = ws_thread
self._ws_url = ws_url
self._ws: AsyncWebSocketSession | None = None
self._rx_task: asyncio.Task[None] | None = None
self._queue_message = enqueue_message
super().__init__(task_target=self._run_main_task)
self._logger = get_logger(type(self).__name__)
self._logger = logger = get_logger(type(self).__name__)
logger.update_context(log_context, thread_id=self.name)
logger.update_context(log_context)

def connect(self) -> bool:
if not self.is_alive():
self.start()
loop = self.wait_for_loop() # Block until connection has been attempted
ws_thread = self._ws_thread
if not ws_thread.is_alive():
raise RuntimeError("Websocket handling thread has failed unexpectedly")
loop = ws_thread.wait_for_loop() # Block until loop is available
if loop is None:
return False
raise RuntimeError("Websocket handling thread has no event loop")
ws_thread.schedule_background_task(self._run_until_terminated())
asyncio.run_coroutine_threadsafe(
self._connection_attempted.wait(), loop
).result()
return self._ws is not None

def disconnect(self) -> None:
if self._ws is not None:
self.terminate()
# Ensure thread has terminated
self.join()

async def _run_main_task(self) -> None:
self._logger.info("Websocket thread started")
async def _task_target(self) -> None:
self._logger.info("Websocket handling task started")
self._init_event_loop()
try:
await self._main_task()
await self._handle_ws()
except get_cancelled_exc_class():
pass
except BaseException:
err_msg = "Terminating websocket thread due to exception"
err_msg = "Terminating websocket task due to exception"
self._logger.debug(err_msg, exc_info=True)
finally:
# Ensure the foreground thread is unblocked even if the
# background async task errors out completely
self._connection_attempted.set()
self._logger.info("Websocket thread terminated")
self._logger.info("Websocket task terminated")

# TODO: Improve code sharing between this background thread async websocket
# and the async-native AsyncLMStudioWebsocket implementation
async def _main_task(self) -> None:
async def _handle_ws(self) -> None:
resources = AsyncExitStack()
try:
ws: AsyncWebSocketSession = await resources.enter_async_context(
Expand Down Expand Up @@ -274,6 +308,10 @@ async def _send_json(self, message: DictObject) -> None:
self._logger.debug(str(err), exc_info=True)
raise err from None

def send_json(self, message: DictObject) -> None:
future = self.schedule_background_task(self._send_json(message))
future.result() # Block until the message is sent

async def _receive_json(self) -> Any:
# This is only called if the websocket has been created
assert self.called_in_background_loop()
Expand Down Expand Up @@ -335,8 +373,6 @@ async def _demultiplexing_task(self) -> None:
finally:
self._logger.info("Websocket closed, terminating demultiplexing task.")

raise_on_termination, terminated_exc = self._raise_on_termination()

async def _receive_messages(self) -> None:
"""Process received messages until task is cancelled."""
while True:
Expand All @@ -349,6 +385,38 @@ async def _receive_messages(self) -> None:
self._terminate.set()
break


class SyncToAsyncWebsocketBridge:
def __init__(
self,
ws_thread: AsyncWebsocketThread,
ws_url: str,
auth_details: DictObject,
enqueue_message: Callable[[DictObject], bool],
log_context: LogEventContext,
) -> None:
self._ws_handler = AsyncWebsocketHandler(
ws_thread, ws_url, auth_details, enqueue_message, log_context
)

def connect(self) -> bool:
return self._ws_handler.connect()

def disconnect(self) -> None:
self._ws_handler.terminate()

def send_json(self, message: DictObject) -> None:
# Block until message has been sent
self.run_background_task(self._send_json(message))
self._ws_handler.send_json(message)

# These attributes are currently accessed directly...
@property
def _ws(self) -> AsyncWebSocketSession | None:
return self._ws_handler._ws

@property
def _connection_failure(self) -> Exception | None:
return self._ws_handler._connection_failure

@property
def _auth_failure(self) -> Any | None:
return self._ws_handler._auth_failure
20 changes: 15 additions & 5 deletions src/lmstudio/sync_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
_model_spec_to_api_dict,
_redact_json,
)
from ._ws_impl import AsyncWebsocketThread
from ._ws_impl import AsyncWebsocketThread, SyncToAsyncWebsocketBridge
from ._kv_config import TLoadConfig, TLoadConfigDict, parse_server_config
from ._sdk_models import (
EmbeddingRpcCountTokensParameter,
Expand Down Expand Up @@ -233,17 +233,21 @@ def receive_result(self) -> Any:
return self._rpc.handle_rx_message(message)


class SyncLMStudioWebsocket(LMStudioWebsocket[AsyncWebsocketThread, queue.Queue[Any]]):
class SyncLMStudioWebsocket(
LMStudioWebsocket[SyncToAsyncWebsocketBridge, queue.Queue[Any]]
):
"""Synchronous websocket client that handles demultiplexing of reply messages."""

def __init__(
self,
ws_thread: AsyncWebsocketThread,
ws_url: str,
auth_details: DictObject,
log_context: LogEventContext | None = None,
) -> None:
"""Initialize synchronous websocket client."""
super().__init__(ws_url, auth_details, log_context)
self._ws_thread = ws_thread

@property
def _httpx_ws(self) -> AsyncWebSocketSession | None:
Expand All @@ -266,7 +270,8 @@ def __exit__(self, *args: Any) -> None:
def connect(self) -> Self:
"""Connect to and authenticate with the LM Studio API."""
self._fail_if_connected("Attempted to connect already connected websocket")
ws = AsyncWebsocketThread(
ws = SyncToAsyncWebsocketBridge(
self._ws_thread,
self._ws_url,
self._auth_details,
self._enqueue_message,
Expand Down Expand Up @@ -409,7 +414,9 @@ def connect(self) -> SyncLMStudioWebsocket:
session_url = f"ws://{api_host}/{namespace}"
resources = self._resources
self._lmsws = lmsws = resources.enter_context(
SyncLMStudioWebsocket(session_url, self._client._auth_details)
SyncLMStudioWebsocket(
self._client._ws_thread, session_url, self._client._auth_details
)
)
return lmsws

Expand Down Expand Up @@ -1482,8 +1489,11 @@ def __init__(self, api_host: str | None = None) -> None:
"""Initialize API client."""
super().__init__(api_host)
self._resources = rm = ExitStack()
self._ws_thread = ws_thread = AsyncWebsocketThread(dict(client=repr(self)))
ws_thread.start()
rm.callback(ws_thread.terminate)
self._sessions: dict[str, SyncSession] = {}
# Suport GC-based resource management in the sync API by
# Support GC-based resource management in the sync API by
# finalizing at the client layer, and letting its resource
# manager handle clearing up everything else
rm.callback(self._sessions.clear)
Expand Down
Loading