Skip to content

Commit 3bd9e5b

Browse files
committed
Migrate async API to AsyncWebsocketHandler
1 parent c1796ff commit 3bd9e5b

File tree

5 files changed

+92
-172
lines changed

5 files changed

+92
-172
lines changed

src/lmstudio/_ws_impl.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Coroutine,
2121
Callable,
2222
Generator,
23+
Self,
2324
TypeAlias,
2425
TypeVar,
2526
)
@@ -45,13 +46,19 @@
4546

4647

4748
class AsyncTaskManager:
48-
def __init__(self, *, on_activation: Callable[[], Any] | None) -> None:
49+
def __init__(self, *, on_activation: Callable[[], Any] | None = None) -> None:
4950
self._activated = False
5051
self._event_loop: asyncio.AbstractEventLoop | None = None
5152
self._on_activation = on_activation
5253
self._task_queue: asyncio.Queue[Callable[[], Awaitable[Any]]] = asyncio.Queue()
5354
self._terminate = asyncio.Event()
5455
self._terminated = asyncio.Event()
56+
# For the case where the task manager is run via its context manager
57+
self._tm_started = asyncio.Event()
58+
self._tm_task: asyncio.Task[Any] | None = None
59+
60+
ACTIVATION_TIMEOUT = 5 # Just starts an async task, should be fast
61+
TERMINATION_TIMEOUT = 20 # May have to shut down TCP links
5562

5663
@property
5764
def activated(self) -> bool:
@@ -65,6 +72,20 @@ def active(self) -> bool:
6572
and not self._terminated.is_set()
6673
)
6774

75+
async def __aenter__(self) -> Self:
76+
# Handle reentrancy the same way files do:
77+
# allow nested use as a CM, but close on the first exit
78+
if self._tm_task is None:
79+
self._tm_task = asyncio.create_task(self.run_until_terminated())
80+
with move_on_after(self.ACTIVATION_TIMEOUT):
81+
await self._tm_started.wait()
82+
return self
83+
84+
async def __aexit__(self, *args: Any) -> None:
85+
await self.request_termination()
86+
with move_on_after(self.TERMINATION_TIMEOUT):
87+
await self._terminated.wait()
88+
6889
def check_running_in_task_loop(self, *, allow_inactive: bool = False) -> bool:
6990
"""Returns if running in this manager's event loop, raises RuntimeError otherwise."""
7091
this_loop = self._event_loop
@@ -138,6 +159,7 @@ def _init_event_loop(self) -> None:
138159
notify = self._on_activation
139160
if notify is not None:
140161
notify()
162+
self._tm_started.set()
141163

142164
async def run_until_terminated(
143165
self, func: Callable[[], Coroutine[Any, Any, Any]] | None = None
@@ -218,9 +240,6 @@ def call_soon_threadsafe(self, func: Callable[[], Any]) -> asyncio.Handle:
218240
AsyncRemoteCallInfo: TypeAlias = tuple[int, Callable[[], Awaitable[Any]]]
219241

220242

221-
# TODO: Improve code sharing between AsyncWebsocketHandler and
222-
# the async-native AsyncLMStudioWebsocket implementation
223-
# (likely by migrating the websocket over to using the handler)
224243
class AsyncWebsocketHandler:
225244
"""Async task handler for a single websocket connection."""
226245

0 commit comments

Comments
 (0)