Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/lmstudio/_ws_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
TypeAlias,
TypeVar,
)
from typing_extensions import (
# Native in 3.11+
Self,
)

from anyio import create_task_group, move_on_after
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
Expand All @@ -45,13 +49,19 @@


class AsyncTaskManager:
def __init__(self, *, on_activation: Callable[[], Any] | None) -> None:
def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None:
self._activated = False
self._event_loop: asyncio.AbstractEventLoop | None = None
self._on_activation = on_activation
self._task_queue: asyncio.Queue[Callable[[], Awaitable[Any]]] = asyncio.Queue()
self._terminate = asyncio.Event()
self._terminated = asyncio.Event()
# For the case where the task manager is run via its context manager
self._tm_started = asyncio.Event()
self._tm_task: asyncio.Task[Any] | None = None

ACTIVATION_TIMEOUT = 5 # Just starts an async task, should be fast
TERMINATION_TIMEOUT = 20 # May have to shut down TCP links

@property
def activated(self) -> bool:
Expand All @@ -65,6 +75,20 @@ def active(self) -> bool:
and not self._terminated.is_set()
)

async def __aenter__(self) -> Self:
# Handle reentrancy the same way files do:
# allow nested use as a CM, but close on the first exit
if self._tm_task is None:
self._tm_task = asyncio.create_task(self.run_until_terminated())
with move_on_after(self.ACTIVATION_TIMEOUT):
await self._tm_started.wait()
return self

async def __aexit__(self, *args: Any) -> None:
await self.request_termination()
with move_on_after(self.TERMINATION_TIMEOUT):
await self._terminated.wait()

def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
"""Returns if running in this manager's event loop, raises RuntimeError otherwise."""
this_loop = self._event_loop
Expand Down Expand Up @@ -138,6 +162,7 @@ def _init_event_loop(self) -> None:
notify = self._on_activation
if notify is not None:
notify()
self._tm_started.set()

async def run_until_terminated(
self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None
Expand Down Expand Up @@ -218,9 +243,6 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
AsyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]]


# TODO: Improve code sharing between AsyncWebsocketHandler and
# the async-native AsyncLMStudioWebsocket implementation
# (likely by migrating the websocket over to using the handler)
class AsyncWebsocketHandler:
"""Async task handler for a single websocket connection."""

Expand All @@ -243,7 +265,7 @@ def __init__(
self._ws_disconnected = asyncio.Event()
self._rx_task: asyncio.Task[None] | None = None
self._logger = logger = new_logger(type(self).__name__)
logger.update_context(log_context)
logger.update_context(log_context, ws_url=ws_url)
self._mux = MultiplexingManager(logger)

async def connect(self) -> bool:
Expand Down Expand Up @@ -275,7 +297,7 @@ def disconnect_threadsafe(self) -> None:
task_manager.run_coroutine_threadsafe(self.disconnect()).result()

async def _logged_ws_handler(self) -> None:
self._logger.info("Websocket handling task started")
self._logger.debug("Websocket handling task started")
try:
await self._handle_ws()
except (asyncio.CancelledError, GeneratorExit):
Expand All @@ -287,7 +309,7 @@ async def _logged_ws_handler(self) -> None:
# Ensure connections attempt are unblocked even if the
# background async task errors out completely
self._connection_attempted.set()
self._logger.info("Websocket task terminated")
self._logger.debug("Websocket task terminated")

async def _handle_ws(self) -> None:
assert self._task_manager.check_running_in_task_loop()
Expand All @@ -311,12 +333,12 @@ def _clear_task_state() -> None:
if not await self._authenticate():
return
self._connection_attempted.set()
self._logger.info(f"Websocket session established ({self._ws_url})")
self._logger.info("Websocket session established")
# Task will run until message reception fails or is cancelled
try:
await self._receive_messages()
finally:
self._logger.info("Websocket demultiplexing task terminated.")
self._logger.debug("Websocket demultiplexing task terminated.")
# Notify foreground thread of background thread termination
# (this covers termination due to link failure)
await self.notify_client_termination()
Expand Down
Loading