Skip to content

Commit a96cfec

Browse files
authored
Migrate async API to AsyncWebsocketHandler (#126)
1 parent c1796ff commit a96cfec

File tree

5 files changed

+101
-178
lines changed

5 files changed

+101
-178
lines changed

src/lmstudio/_ws_impl.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
TypeAlias,
2424
TypeVar,
2525
)
26+
from typing_extensions import (
27+
# Native in 3.11+
28+
Self,
29+
)
2630

2731
from anyio import create_task_group, move_on_after
2832
from httpx_ws import aconnect_ws, AsyncWebSocketSession, HTTPXWSException
@@ -45,13 +49,19 @@
4549

4650

4751
class AsyncTaskManager:
48-
def __init__(self, *, on_activation: Callable[[], Any] | None) -> None:
52+
def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None:
4953
self._activated = False
5054
self._event_loop: asyncio.AbstractEventLoop | None = None
5155
self._on_activation = on_activation
5256
self._task_queue: asyncio.Queue[Callable[[], Awaitable[Any]]] = asyncio.Queue()
5357
self._terminate = asyncio.Event()
5458
self._terminated = asyncio.Event()
59+
# For the case where the task manager is run via its context manager
60+
self._tm_started = asyncio.Event()
61+
self._tm_task: asyncio.Task[Any] | None = None
62+
63+
ACTIVATION_TIMEOUT = 5 # Just starts an async task, should be fast
64+
TERMINATION_TIMEOUT = 20 # May have to shut down TCP links
5565

5666
@property
5767
def activated(self) -> bool:
@@ -65,6 +75,20 @@ def active(self) -> bool:
6575
and not self._terminated.is_set()
6676
)
6777

78+
async def __aenter__(self) -> Self:
79+
# Handle reentrancy the same way files do:
80+
# allow nested use as a CM, but close on the first exit
81+
if self._tm_task is None:
82+
self._tm_task = asyncio.create_task(self.run_until_terminated())
83+
with move_on_after(self.ACTIVATION_TIMEOUT):
84+
await self._tm_started.wait()
85+
return self
86+
87+
async def __aexit__(self, *args: Any) -> None:
88+
await self.request_termination()
89+
with move_on_after(self.TERMINATION_TIMEOUT):
90+
await self._terminated.wait()
91+
6892
def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
6993
"""Returns if running in this manager's event loop, raises RuntimeError otherwise."""
7094
this_loop = self._event_loop
@@ -138,6 +162,7 @@ def _init_event_loop(self) -> None:
138162
notify = self._on_activation
139163
if notify is not None:
140164
notify()
165+
self._tm_started.set()
141166

142167
async def run_until_terminated(
143168
self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None
@@ -218,9 +243,6 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
218243
AsyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]]
219244

220245

221-
# TODO: Improve code sharing between AsyncWebsocketHandler and
222-
# the async-native AsyncLMStudioWebsocket implementation
223-
# (likely by migrating the websocket over to using the handler)
224246
class AsyncWebsocketHandler:
225247
"""Async task handler for a single websocket connection."""
226248

@@ -243,7 +265,7 @@ def __init__(
243265
self._ws_disconnected = asyncio.Event()
244266
self._rx_task: asyncio.Task[None] | None = None
245267
self._logger = logger = new_logger(type(self).__name__)
246-
logger.update_context(log_context)
268+
logger.update_context(log_context, ws_url=ws_url)
247269
self._mux = MultiplexingManager(logger)
248270

249271
async def connect(self) -> bool:
@@ -275,7 +297,7 @@ def disconnect_threadsafe(self) -> None:
275297
task_manager.run_coroutine_threadsafe(self.disconnect()).result()
276298

277299
async def _logged_ws_handler(self) -> None:
278-
self._logger.info("Websocket handling task started")
300+
self._logger.debug("Websocket handling task started")
279301
try:
280302
await self._handle_ws()
281303
except (asyncio.CancelledError, GeneratorExit):
@@ -287,7 +309,7 @@ async def _logged_ws_handler(self) -> None:
287309
# Ensure connections attempt are unblocked even if the
288310
# background async task errors out completely
289311
self._connection_attempted.set()
290-
self._logger.info("Websocket task terminated")
312+
self._logger.debug("Websocket task terminated")
291313

292314
async def _handle_ws(self) -> None:
293315
assert self._task_manager.check_running_in_task_loop()
@@ -311,12 +333,12 @@ def _clear_task_state() -> None:
311333
if not await self._authenticate():
312334
return
313335
self._connection_attempted.set()
314-
self._logger.info(f"Websocket session established ({self._ws_url})")
336+
self._logger.info("Websocket session established")
315337
# Task will run until message reception fails or is cancelled
316338
try:
317339
await self._receive_messages()
318340
finally:
319-
self._logger.info("Websocket demultiplexing task terminated.")
341+
self._logger.debug("Websocket demultiplexing task terminated.")
320342
# Notify foreground thread of background thread termination
321343
# (this covers termination due to link failure)
322344
await self.notify_client_termination()

0 commit comments

Comments
 (0)