diff --git a/src/lmstudio/_ws_impl.py b/src/lmstudio/_ws_impl.py index 2565f29..47eb7ff 100644 --- a/src/lmstudio/_ws_impl.py +++ b/src/lmstudio/_ws_impl.py @@ -15,6 +15,7 @@ # Python 3.10 compatibility: use concurrent.futures.TimeoutError instead of the builtin # In 3.11+, these are the same type, in 3.10 futures have their own timeout exception from concurrent.futures import Future as SyncFuture, TimeoutError as SyncFutureTimeout +from contextvars import ContextVar from contextlib import AsyncExitStack, contextmanager from functools import partial from typing import ( @@ -22,6 +23,7 @@ Awaitable, Coroutine, Callable, + ClassVar, Generator, TypeAlias, TypeVar, @@ -35,6 +37,7 @@ from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException from .schemas import DictObject +from .sdk_api import LMStudioRuntimeError from .json_api import ( LMStudioWebsocket, LMStudioWebsocketError, @@ -58,6 +61,8 @@ class AsyncTaskManager: + _LMS_TASK_MANAGER: ClassVar[ContextVar[Self]] = ContextVar("_LMS_TASK_MANAGER") + def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None: self._activated = False self._event_loop: asyncio.AbstractEventLoop | None = None @@ -98,15 +103,19 @@ async def __aexit__(self, *args: Any) -> None: with move_on_after(self.TERMINATION_TIMEOUT): await self._terminated.wait() - def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool: - """Returns if running in this manager's event loop, raises RuntimeError otherwise.""" + @classmethod + def get_running_task_manager(cls) -> Self: + try: + return cls._LMS_TASK_MANAGER.get() + except LookupError: + err_msg = "No async task manager active in current context" + raise LMStudioRuntimeError(err_msg) from None + + def ensure_running_in_task_loop(self) -> None: this_loop = self._event_loop if this_loop is None: # Task manager isn't active -> no coroutine can be running in it - if allow_inactive: - # No exception, but indicate the task manager isn't actually running - return False - raise RuntimeError(f"{self!r} is currently inactive.") + raise LMStudioRuntimeError(f"{self!r} is currently inactive.") try: running_loop = asyncio.get_running_loop() except RuntimeError: @@ -116,12 +125,27 @@ def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool: if running_loop is not this_loop: err_details = f"Expected: {this_loop!r} Running: {running_loop!r}" err_msg = f"{self!r} is running in a different event loop ({err_details})." - raise RuntimeError(err_msg) + raise LMStudioRuntimeError(err_msg) + + def is_running_in_task_loop(self) -> bool: + try: + self.ensure_running_in_task_loop() + except LMStudioRuntimeError: + return False return True + def ensure_running_in_task_manager(self) -> None: + # Task manager must be active in the running event loop + self.ensure_running_in_task_loop() + running_tm = self.get_running_task_manager() + if running_tm is not self: + err_details = f"Expected: {self!r} Running: {running_tm!r}" + err_msg = f"Task is running in a different task manager ({err_details})." + raise LMStudioRuntimeError(err_msg) + async def request_termination(self) -> bool: """Request termination of the task manager from the same thread.""" - if not self.check_running_in_task_loop(allow_inactive=True): + if not self.is_running_in_task_loop(): return False if self._terminate.is_set(): return False @@ -139,7 +163,7 @@ def request_termination_threadsafe(self) -> SyncFuture[bool]: async def wait_for_termination(self) -> None: """Wait in the same thread for the task manager to indicate it has terminated.""" - if not self.check_running_in_task_loop(allow_inactive=True): + if not self.is_running_in_task_loop(): return await self._terminated.wait() @@ -163,11 +187,13 @@ def terminate_threadsafe(self) -> None: if self.request_termination_threadsafe().result(): self.wait_for_termination_threadsafe() - def _init_event_loop(self) -> None: + def _mark_as_running(self: Self) -> None: + # Explicit type hint to work around https://github.com/python/mypy/issues/16871 if self._event_loop is not None: - raise RuntimeError() + raise LMStudioRuntimeError("Async task manager is already running") self._event_loop = asyncio.get_running_loop() self._activated = True + self._LMS_TASK_MANAGER.set(self) notify = self._on_activation if notify is not None: notify() @@ -177,7 +203,7 @@ async def run_until_terminated( self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None ) -> None: """Run task manager until termination is requested.""" - self._init_event_loop() + self._mark_as_running() # Use anyio and exceptiongroup to handle the lack of native task # and exception groups prior to Python 3.11 try: @@ -206,7 +232,7 @@ async def schedule_task(self, func: Callable[[], Awaitable[Any]]) -> None: Important: task must NOT access any scoped resources from the scheduling scope. """ - self.check_running_in_task_loop() + self.ensure_running_in_task_loop() await self._task_queue.put(func) def schedule_task_threadsafe(self, func: Callable[[], Awaitable[Any]]) -> None: @@ -214,10 +240,7 @@ def schedule_task_threadsafe(self, func: Callable[[], Awaitable[Any]]) -> None: Important: task must NOT access any scoped resources from the scheduling scope. """ - loop = self._event_loop - if loop is None: - raise RuntimeError(f"{self!r} is currently inactive.") - asyncio.run_coroutine_threadsafe(self.schedule_task(func), loop) + self.run_coroutine_threadsafe(self.schedule_task(func)) def run_coroutine_threadsafe(self, coro: Coroutine[Any, Any, T]) -> SyncFuture[T]: """Call given coroutine in the task manager's event loop from any thread. @@ -280,7 +303,6 @@ def __init__( async def connect(self) -> bool: """Connect websocket from the task manager's event loop.""" task_manager = self._task_manager - assert task_manager.check_running_in_task_loop() await task_manager.schedule_task(self._logged_ws_handler) await self._connection_attempted.wait() return self._ws is not None @@ -293,7 +315,9 @@ def connect_threadsafe(self) -> bool: async def disconnect(self) -> None: """Disconnect websocket from the task manager's event loop.""" - assert self._task_manager.check_running_in_task_loop() + self._task_manager.ensure_running_in_task_loop() + # Websocket handler task may already have been cancelled, + # but the closure can be requested multiple times without issue self._ws_disconnected.set() ws = self._ws if ws is None: @@ -321,9 +345,10 @@ async def _logged_ws_handler(self) -> None: self._logger.debug("Websocket task terminated") async def _handle_ws(self) -> None: - assert self._task_manager.check_running_in_task_loop() resources = AsyncExitStack() try: + # For reliable shutdown, handler must run entirely inside the task manager + self._task_manager.ensure_running_in_task_manager() ws: AsyncWebSocketSession = await resources.enter_async_context( aconnect_ws(self._ws_url) ) @@ -370,7 +395,7 @@ def _clear_task_state() -> None: async def send_json(self, message: DictObject) -> None: # This is only called if the websocket has been created - assert self._task_manager.check_running_in_task_loop() + self._task_manager.ensure_running_in_task_loop() ws = self._ws if ws is None: # Assume app is shutting down and the owning task has already been cancelled @@ -396,14 +421,14 @@ def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T: @contextmanager def open_channel(self) -> Generator[AsyncChannelInfo, None, None]: - assert self._task_manager.check_running_in_task_loop() + self._task_manager.ensure_running_in_task_loop() rx_queue: RxQueue = asyncio.Queue() with self._mux.assign_channel_id(rx_queue) as call_id: yield call_id, rx_queue.get @contextmanager def start_call(self) -> Generator[AsyncRemoteCallInfo, None, None]: - assert self._task_manager.check_running_in_task_loop() + self._task_manager.ensure_running_in_task_loop() rx_queue: RxQueue = asyncio.Queue() with self._mux.assign_call_id(rx_queue) as call_id: yield call_id, rx_queue.get @@ -444,7 +469,9 @@ def _rx_queue_get_threadsafe(self, rx_queue: RxQueue, timeout: float | None) -> async def _receive_json(self) -> Any: # This is only called if the websocket has been created - assert self._task_manager.check_running_in_task_loop() + if __debug__: + # This should only be called as part of the self._handle_ws task + self._task_manager.ensure_running_in_task_manager() ws = self._ws if ws is None: # Assume app is shutting down and the owning task has already been cancelled @@ -459,7 +486,9 @@ async def _receive_json(self) -> Any: async def _authenticate(self) -> bool: # This is only called if the websocket has been created - assert self._task_manager.check_running_in_task_loop() + if __debug__: + # This should only be called as part of the self._handle_ws task + self._task_manager.ensure_running_in_task_manager() ws = self._ws if ws is None: # Assume app is shutting down and the owning task has already been cancelled @@ -479,7 +508,9 @@ async def _process_next_message(self) -> bool: Returns True if a message queue was updated. """ # This is only called if the websocket has been created - assert self._task_manager.check_running_in_task_loop() + if __debug__: + # This should only be called as part of the self._handle_ws task + self._task_manager.ensure_running_in_task_manager() ws = self._ws if ws is None: # Assume app is shutting down and the owning task has already been cancelled