Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 57 additions & 26 deletions src/lmstudio/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
# Python 3.10 compatibility: use concurrent.futures.TimeoutError instead of the builtin
# In 3.11+, these are the same type, in 3.10 futures have their own timeout exception
from concurrent.futures import Future as SyncFuture, TimeoutError as SyncFutureTimeout
from contextvars import ContextVar
from contextlib import AsyncExitStack, contextmanager
from functools import partial
from typing import (
Any,
Awaitable,
Coroutine,
Callable,
ClassVar,
Generator,
TypeAlias,
TypeVar,
Expand All @@ -35,6 +37,7 @@
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException

from .schemas import DictObject
from .sdk_api import LMStudioRuntimeError
from .json_api import (
LMStudioWebsocket,
LMStudioWebsocketError,
Expand All @@ -58,6 +61,8 @@


class AsyncTaskManager:
_LMS_TASK_MANAGER: ClassVar[ContextVar[Self]] = ContextVar("_LMS_TASK_MANAGER")

def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None:
self._activated = False
self._event_loop: asyncio.AbstractEventLoop | None = None
Expand Down Expand Up @@ -98,15 +103,19 @@ async def __aexit__(self, *args: Any) -> None:
with move_on_after(self.TERMINATION_TIMEOUT):
await self._terminated.wait()

def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
"""Returns if running in this manager's event loop, raises RuntimeError otherwise."""
@classmethod
def get_running_task_manager(cls) -> Self:
try:
return cls._LMS_TASK_MANAGER.get()
except LookupError:
err_msg = "No async task manager active in current context"
raise LMStudioRuntimeError(err_msg) from None

def ensure_running_in_task_loop(self) -> None:
this_loop = self._event_loop
if this_loop is None:
# Task manager isn't active -> no coroutine can be running in it
if allow_inactive:
# No exception, but indicate the task manager isn't actually running
return False
raise RuntimeError(f"{self!r} is currently inactive.")
raise LMStudioRuntimeError(f"{self!r} is currently inactive.")
try:
running_loop = asyncio.get_running_loop()
except RuntimeError:
Expand All @@ -116,12 +125,27 @@ def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
if running_loop is not this_loop:
err_details = f"Expected: {this_loop!r} Running: {running_loop!r}"
err_msg = f"{self!r} is running in a different event loop ({err_details})."
raise RuntimeError(err_msg)
raise LMStudioRuntimeError(err_msg)

def is_running_in_task_loop(self) -> bool:
try:
self.ensure_running_in_task_loop()
except LMStudioRuntimeError:
return False
return True

def ensure_running_in_task_manager(self) -> None:
# Task manager must be active in the running event loop
self.ensure_running_in_task_loop()
running_tm = self.get_running_task_manager()
if running_tm is not self:
err_details = f"Expected: {self!r} Running: {running_tm!r}"
err_msg = f"Task is running in a different task manager ({err_details})."
raise LMStudioRuntimeError(err_msg)

async def request_termination(self) -> bool:
"""Request termination of the task manager from the same thread."""
if not self.check_running_in_task_loop(allow_inactive=True):
if not self.is_running_in_task_loop():
return False
if self._terminate.is_set():
return False
Expand All @@ -139,7 +163,7 @@ def request_termination_threadsafe(self) -> SyncFuture[bool]:

async def wait_for_termination(self) -> None:
"""Wait in the same thread for the task manager to indicate it has terminated."""
if not self.check_running_in_task_loop(allow_inactive=True):
if not self.is_running_in_task_loop():
return
await self._terminated.wait()

Expand All @@ -163,11 +187,13 @@ def terminate_threadsafe(self) -> None:
if self.request_termination_threadsafe().result():
self.wait_for_termination_threadsafe()

def _init_event_loop(self) -> None:
def _mark_as_running(self: Self) -> None:
# Explicit type hint to work around https://github.com/python/mypy/issues/16871
if self._event_loop is not None:
raise RuntimeError()
raise LMStudioRuntimeError("Async task manager is already running")
self._event_loop = asyncio.get_running_loop()
self._activated = True
self._LMS_TASK_MANAGER.set(self)
notify = self._on_activation
if notify is not None:
notify()
Expand All @@ -177,7 +203,7 @@ async def run_until_terminated(
self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None
) -> None:
"""Run task manager until termination is requested."""
self._init_event_loop()
self._mark_as_running()
# Use anyio and exceptiongroup to handle the lack of native task
# and exception groups prior to Python 3.11
try:
Expand Down Expand Up @@ -206,18 +232,15 @@ async def schedule_task(self, func: Callable[[], Awaitable[Any]]) -> None:

Important: task must NOT access any scoped resources from the scheduling scope.
"""
self.check_running_in_task_loop()
self.ensure_running_in_task_loop()
await self._task_queue.put(func)

def schedule_task_threadsafe(self, func: Callable[[], Awaitable[Any]]) -> None:
"""Schedule given task in the task manager's base coroutine from any thread.

Important: task must NOT access any scoped resources from the scheduling scope.
"""
loop = self._event_loop
if loop is None:
raise RuntimeError(f"{self!r} is currently inactive.")
asyncio.run_coroutine_threadsafe(self.schedule_task(func), loop)
self.run_coroutine_threadsafe(self.schedule_task(func))

def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T]:
"""Call given coroutine in the task manager's event loop from any thread.
Expand Down Expand Up @@ -280,7 +303,6 @@ def __init__(
async def connect(self) -> bool:
"""Connect websocket from the task manager's event loop."""
task_manager = self._task_manager
assert task_manager.check_running_in_task_loop()
await task_manager.schedule_task(self._logged_ws_handler)
await self._connection_attempted.wait()
return self._ws is not None
Expand All @@ -293,7 +315,9 @@ def connect_threadsafe(self) -> bool:

async def disconnect(self) -> None:
"""Disconnect websocket from the task manager's event loop."""
assert self._task_manager.check_running_in_task_loop()
self._task_manager.ensure_running_in_task_loop()
# Websocket handler task may already have been cancelled,
# but the closure can be requested multiple times without issue
self._ws_disconnected.set()
ws = self._ws
if ws is None:
Expand Down Expand Up @@ -321,9 +345,10 @@ async def _logged_ws_handler(self) -> None:
self._logger.debug("Websocket task terminated")

async def _handle_ws(self) -> None:
assert self._task_manager.check_running_in_task_loop()
resources = AsyncExitStack()
try:
# For reliable shutdown, handler must run entirely inside the task manager
self._task_manager.ensure_running_in_task_manager()
ws: AsyncWebSocketSession = await resources.enter_async_context(
aconnect_ws(self._ws_url)
)
Expand Down Expand Up @@ -370,7 +395,7 @@ def _clear_task_state() -> None:

async def send_json(self, message: DictObject) -> None:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
self._task_manager.ensure_running_in_task_loop()
ws = self._ws
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
Expand All @@ -396,14 +421,14 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:

@contextmanager
def open_channel(self) -> Generator[AsyncChannelInfo, None, None]:
assert self._task_manager.check_running_in_task_loop()
self._task_manager.ensure_running_in_task_loop()
rx_queue: RxQueue = asyncio.Queue()
with self._mux.assign_channel_id(rx_queue) as call_id:
yield call_id, rx_queue.get

@contextmanager
def start_call(self) -> Generator[AsyncRemoteCallInfo, None, None]:
assert self._task_manager.check_running_in_task_loop()
self._task_manager.ensure_running_in_task_loop()
rx_queue: RxQueue = asyncio.Queue()
with self._mux.assign_call_id(rx_queue) as call_id:
yield call_id, rx_queue.get
Expand Down Expand Up @@ -444,7 +469,9 @@ def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) ->

async def _receive_json(self) -> Any:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
if __debug__:
# This should only be called as part of the self._handle_ws task
self._task_manager.ensure_running_in_task_manager()
ws = self._ws
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
Expand All @@ -459,7 +486,9 @@ async def _receive_json(self) -> Any:

async def _authenticate(self) -> bool:
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
if __debug__:
# This should only be called as part of the self._handle_ws task
self._task_manager.ensure_running_in_task_manager()
ws = self._ws
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
Expand All @@ -479,7 +508,9 @@ async def _process_next_message(self) -> bool:
Returns True if a message queue was updated.
"""
# This is only called if the websocket has been created
assert self._task_manager.check_running_in_task_loop()
if __debug__:
# This should only be called as part of the self._handle_ws task
self._task_manager.ensure_running_in_task_manager()
ws = self._ws
if ws is None:
# Assume app is shutting down and the owning task has already been cancelled
Expand Down