Skip to content

Commit e2066fe

Browse files
authored
Allow invocation from atexit hooks (#123)
1 parent 4e996ad commit e2066fe

File tree

4 files changed

+136
-51
lines changed

4 files changed

+136
-51
lines changed

misc/atexit-example.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#!/usr/bin/env python
2+
"""Check accessing the default client from an atexit hook."""
3+
import atexit
4+
5+
# Intentionally named with a hyphen to ensure this can't be imported
6+
7+
import lmstudio as lms
8+
9+
# TODO: Turn this into a CI test case (perhaps via subprocess invocation?)
10+
11+
# Prior to lmstudio-python 1.5.0, the atexit hook below would hang on shutdown.
12+
# By the time atexit hooks run, asyncio.to_thread no longer works due to all
13+
# concurrent.futures managed thread pools (including those used by asyncio)
14+
# being shut down before the interpreter waits for non-daemon threads to terminate.
15+
# Since the synchronous client relied on asyncio.to_thread to deliver
16+
# messages from the async background comms thread to the blocking
17+
# foreground thread, the sync API didn't work in this scenario.
18+
# In 1.5.0, the sync message reception was reworked to queue messages entirely
19+
# in the async background thread with blocking async queue reads, eliminating
20+
# the blocking queue write operations, and allowing the client to continue running
21+
# in atexit hooks.
22+
@atexit.register
def access_default_client():
    """List the loaded models through the default client during shutdown.

    Registered with ``atexit`` so that it runs while the interpreter is
    exiting, which is the scenario this script exists to exercise.
    """
    # The label is written and flushed on its own before the SDK call is
    # made, so it appears even if the call below never returns.
    print("During shutdown:", end=" ", flush=True)
    loaded_models = lms.list_loaded_models()
    print(loaded_models)


# Same query issued during normal execution, for comparison with the hook output
print("Prior to shutdown:", lms.list_loaded_models())

src/lmstudio/_ws_impl.py

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
import weakref
1515

1616
from concurrent.futures import Future as SyncFuture
17-
from contextlib import (
18-
AsyncExitStack,
19-
)
17+
from contextlib import AsyncExitStack
18+
from functools import partial
2019
from typing import (
2120
Any,
2221
Awaitable,
2322
Coroutine,
2423
Callable,
24+
Iterable,
2525
TypeVar,
2626
)
2727

@@ -221,6 +221,8 @@ def __init__(
221221
# Annoyingly, we have to mark the background thread as a daemon thread to
222222
# prevent hanging at shutdown. Even checking `sys.is_finalizing()` is inadequate
223223
# https://discuss.python.org/t/should-sys-is-finalizing-report-interpreter-finalization-instead-of-runtime-finalization/76695
224+
# TODO: skip thread daemonization when running in a subinterpreter
225+
# (and also disable the convenience API in subinterpreters to avoid hanging on shutdown)
224226
super().__init__(name=name, daemon=True)
225227
weakref.finalize(self, self.terminate)
226228

@@ -278,6 +280,25 @@ def call_in_background(self, func: Callable[[], Any]) -> None:
278280
self._task_manager.call_soon_threadsafe(func)
279281

280282

283+
# By default, the weakref finalization atexit hook is registered lazily.
284+
# This can lead to shutdown sequencing issues if SDK users attempt to access
285+
# client instances (such as the default sync client) from atexit hooks
286+
# registered at import time (so they may end up running after the weakref
287+
# finalization hook has already terminated background threads).
288+
# Creating this finalizer here ensures the weakref finalization hook is
289+
# registered at import time, and hence runs *after* any such hooks
290+
# (assuming the lmstudio SDK is imported before the hooks are registered)
291+
def _register_weakref_atexit_hook() -> None:
    """Create one throwaway ``weakref.finalize`` instance at import time.

    The finalizer is created purely for its side effect on the ``weakref``
    module's atexit registration; neither the placeholder object nor the
    finalizer is kept.
    """

    class _Placeholder:
        """Minimal weakref-able type (plain ``object()`` cannot be weakly referenced)."""

    # ``int`` is just a harmless zero-argument callable to use as the callback
    weakref.finalize(_Placeholder(), int)


# Run once on import, then remove the helper from the module namespace
_register_weakref_atexit_hook()
del _register_weakref_atexit_hook
300+
301+
281302
class AsyncWebsocketThread(BackgroundThread):
282303
def __init__(self, log_context: LogEventContext | None = None) -> None:
283304
super().__init__(task_target=self._log_thread_execution)
@@ -439,6 +460,20 @@ def send_json_threadsafe(self, message: DictObject) -> None:
439460
future = self._task_manager.run_coroutine_threadsafe(self.send_json(message))
440461
future.result() # Block until the message is sent
441462

463+
def run_background_coroutine(self, coro: Coroutine[Any, Any, T]) -> T:
    """Submit ``coro`` through the task manager and block for its result.

    The coroutine is passed to the task manager's
    ``run_coroutine_threadsafe`` and the calling thread then waits on the
    returned future.
    """
    pending = self._task_manager.run_coroutine_threadsafe(coro)
    return pending.result()
466+
467+
def rx_queue_get_threadsafe(self, rx_queue: asyncio.Queue[Any]) -> Any:
    """Block the calling thread until a message is read from ``rx_queue``.

    The ``rx_queue.get()`` coroutine is scheduled via the task manager's
    ``run_coroutine_threadsafe`` rather than being awaited by the caller.
    """
    return self._task_manager.run_coroutine_threadsafe(rx_queue.get()).result()
470+
471+
def rx_queue_put_threadsafe(
    self, rx_queue: asyncio.Queue[Any], message: Any
) -> None:
    """Block the calling thread until ``message`` is placed on ``rx_queue``.

    The ``rx_queue.put(message)`` coroutine is scheduled via the task
    manager's ``run_coroutine_threadsafe`` rather than being awaited by
    the caller.
    """
    future = self._task_manager.run_coroutine_threadsafe(rx_queue.put(message))
    # Wait only for completion: this method is declared to return None,
    # so the future's result is deliberately not returned.
    future.result()
476+
442477
async def _receive_json(self) -> Any:
443478
# This is only called if the websocket has been created
444479
assert self._task_manager.check_running_in_task_loop()
@@ -502,15 +537,21 @@ def __init__(
502537
ws_thread: AsyncWebsocketThread,
503538
ws_url: str,
504539
auth_details: DictObject,
505-
enqueue_message: Callable[[DictObject | None], bool],
540+
get_queue: Callable[[DictObject | None], asyncio.Queue[Any] | None],
541+
iter_queues: Callable[[], Iterable[asyncio.Queue[Any]]],
506542
log_context: LogEventContext,
507543
) -> None:
508-
async def enqueue_async(message: DictObject | None) -> bool:
509-
return await asyncio.to_thread(enqueue_message, message)
510-
544+
self._get_queue = get_queue
545+
self._iter_queues = iter_queues
511546
self._ws_handler = AsyncWebsocketHandler(
512-
ws_thread.task_manager, ws_url, auth_details, enqueue_async, log_context
547+
ws_thread.task_manager,
548+
ws_url,
549+
auth_details,
550+
self._enqueue_message,
551+
log_context,
513552
)
553+
self._logger = logger = new_logger(type(self).__name__)
554+
logger.update_context(log_context)
514555

515556
def connect(self) -> bool:
516557
return self._ws_handler.connect_threadsafe()
@@ -521,6 +562,37 @@ def disconnect(self) -> None:
521562
def send_json(self, message: DictObject) -> None:
522563
self._ws_handler.send_json_threadsafe(message)
523564

565+
def new_rx_queue(self) -> tuple[asyncio.Queue[Any], Callable[[], Any]]:
    """Create a receive queue together with a blocking reader for it.

    Returns:
        A ``(queue, getter)`` pair: a new ``asyncio.Queue`` and a
        zero-argument callable that reads the next message from that
        queue via the websocket handler's ``rx_queue_get_threadsafe``.
    """
    fresh_queue: asyncio.Queue[Any] = asyncio.Queue()
    blocking_get = partial(self._ws_handler.rx_queue_get_threadsafe, fresh_queue)
    return fresh_queue, blocking_get
568+
569+
async def _enqueue_message(self, message: Any) -> bool:
    """Route one received message to the receive queue selected for it.

    A ``None`` message results in every open receive queue being notified
    via ``notify_client_termination``.

    Returns:
        For ``None``: whether at least one client was notified.
        Otherwise: ``True`` if the message was queued, ``False`` if the
        lookup callback returned no queue for it.
    """
    # The lookup callback is invoked unconditionally, before the None
    # check, so it also gets called for the None message.
    target_queue = self._get_queue(message)
    if message is None:
        notified = await self.notify_client_termination()
        return notified > 0
    if target_queue is not None:
        await target_queue.put(message)
        return True
    return False
577+
578+
async def notify_client_termination(self) -> int:
    """Put ``None`` on every open receive queue and report how many there were.

    This is the coroutine form, meant to be awaited on the background
    event loop; it logs the number of notified clients at debug level.

    Returns:
        The number of receive queues that were sent ``None``.
    """
    num_clients = 0
    # enumerate() keeps the running count; it stays 0 when there are no queues
    for num_clients, open_queue in enumerate(self._iter_queues(), start=1):
        await open_queue.put(None)
    self._logger.debug(
        f"Notified {num_clients} clients of websocket termination",
        num_clients=num_clients,
    )
    return num_clients
589+
590+
def notify_client_termination_threadsafe(self) -> int:
    """Blocking wrapper around ``notify_client_termination`` for the foreground thread.

    Creates the notification coroutine and hands it to the websocket
    handler's ``run_background_coroutine``, returning whatever that returns
    (the number of notified clients).
    """
    termination_notice = self.notify_client_termination()
    return self._ws_handler.run_background_coroutine(termination_notice)
595+
524596
# These attributes are currently accessed directly...
525597
@property
526598
def _ws(self) -> AsyncWebSocketSession | None:

src/lmstudio/json_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,8 @@ def _redact_json(data: DictObject | None) -> DictObject | None:
526526
return redacted
527527

528528

529+
# TODO: Now that even the sync API uses asyncio.Queue,
530+
# change the multiplexing manager to no longer be generic
529531
TQueue = TypeVar("TQueue")
530532

531533

src/lmstudio/sync_api.py

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Sync I/O protocol implementation for the LM Studio remote access API."""
22

3+
import asyncio
34
import itertools
45
import time
5-
import queue
66
import weakref
77

88
from abc import abstractmethod
@@ -158,14 +158,14 @@ class SyncChannel(Generic[T]):
158158
def __init__(
159159
self,
160160
channel_id: int,
161-
rx_queue: queue.Queue[Any],
161+
get_message: Callable[[], Any],
162162
endpoint: ChannelEndpoint[T, Any, Any],
163163
send_json: Callable[[DictObject], None],
164164
log_context: LogEventContext,
165165
) -> None:
166166
"""Initialize synchronous websocket streaming channel."""
167167
self._is_finished = False
168-
self._rx_queue = rx_queue
168+
self._get_message = get_message
169169
self._api_channel = ChannelHandler(channel_id, endpoint, log_context)
170170
self._send_json = send_json
171171

@@ -193,7 +193,7 @@ def rx_stream(
193193
with sdk_public_api():
194194
# Avoid emitting tracebacks that delve into supporting libraries
195195
# (we can't easily suppress the SDK's own frames for iterators)
196-
message = self._rx_queue.get()
196+
message = self._get_message()
197197
contents = self._api_channel.handle_rx_message(message)
198198
if contents is None:
199199
self._is_finished = True
@@ -216,12 +216,12 @@ class SyncRemoteCall:
216216
def __init__(
217217
self,
218218
call_id: int,
219-
rx_queue: queue.Queue[Any],
219+
get_message: Callable[[], Any],
220220
log_context: LogEventContext,
221221
notice_prefix: str = "RPC",
222222
) -> None:
223223
"""Initialize synchronous remote procedure call."""
224-
self._rx_queue = rx_queue
224+
self._get_message = get_message
225225
self._rpc = RemoteCallHandler(call_id, log_context, notice_prefix)
226226
self._logger = logger = new_logger(type(self).__name__)
227227
logger.update_context(log_context, call_id=call_id)
@@ -234,12 +234,12 @@ def get_rpc_message(
234234

235235
def receive_result(self) -> Any:
    """Fetch the next message for this call and return the handled response.

    The message is obtained from the ``get_message`` callable supplied at
    construction and passed to the RPC handler's ``handle_rx_message``.
    """
    return self._rpc.handle_rx_message(self._get_message())
239239

240240

241241
class SyncLMStudioWebsocket(
242-
LMStudioWebsocket[SyncToAsyncWebsocketBridge, queue.Queue[Any]]
242+
LMStudioWebsocket[SyncToAsyncWebsocketBridge, asyncio.Queue[Any]]
243243
):
244244
"""Synchronous websocket client that handles demultiplexing of reply messages."""
245245

@@ -279,7 +279,8 @@ def connect(self) -> Self:
279279
self._ws_thread,
280280
self._ws_url,
281281
self._auth_details,
282-
self._enqueue_message,
282+
self._get_rx_queue,
283+
self._mux.all_queues,
283284
self._logger.event_context,
284285
)
285286
if not ws.connect():
@@ -298,42 +299,26 @@ def disconnect(self) -> None:
298299
self._ws = None
299300
if ws is not None:
300301
self._logger.debug(f"Disconnecting websocket session ({self._ws_url})")
301-
self._notify_client_termination()
302+
ws.notify_client_termination_threadsafe()
302303
ws.disconnect()
303304
self._logger.info(f"Websocket session disconnected ({self._ws_url})")
304305

305306
close = disconnect
306307

307-
def _enqueue_message(self, message: Any) -> bool:
308-
if message is None:
309-
self._logger.info(f"Websocket session failed ({self._ws_url})")
310-
self._ws = None
311-
return self._notify_client_termination() > 0
312-
rx_queue = self._mux.map_rx_message(message)
313-
if rx_queue is None:
314-
return False
315-
rx_queue.put(message)
316-
return True
317-
318-
def _notify_client_termination(self) -> int:
319-
"""Send None to all clients with open receive queues."""
320-
num_clients = 0
321-
for rx_queue in self._mux.all_queues():
322-
rx_queue.put(None)
323-
num_clients += 1
324-
self._logger.debug(
325-
f"Notified {num_clients} clients of websocket termination",
326-
num_clients=num_clients,
327-
)
328-
return num_clients
329-
330308
def _send_json(self, message: DictObject) -> None:
    """Send ``message`` over the currently connected websocket.

    Callers are expected to call `_ensure_connected` before this method,
    so a missing websocket is treated as a programming error (assertion).
    """
    active_ws = self._ws
    assert active_ws is not None
    # Background thread handles the exception conversion
    active_ws.send_json(message)
336314

315+
def _get_rx_queue(self, message: Any) -> asyncio.Queue[Any] | None:
316+
if message is None:
317+
self._logger.info(f"Websocket session failed ({self._ws_url})")
318+
self._ws = None
319+
return None
320+
return self._mux.map_rx_message(message)
321+
337322
def _connect_to_endpoint(self, channel: SyncChannel[Any]) -> None:
338323
"""Connect channel to specified endpoint."""
339324
self._ensure_connected("open channel endpoints")
@@ -347,19 +332,18 @@ def open_channel(
347332
endpoint: ChannelEndpoint[T, Any, Any],
348333
) -> Generator[SyncChannel[T], None, None]:
349334
"""Open a streaming channel over the websocket."""
350-
rx_queue: queue.Queue[Any] = queue.Queue()
335+
ws = self._ws
336+
assert ws is not None
337+
rx_queue, getter = ws.new_rx_queue()
351338
with self._mux.assign_channel_id(rx_queue) as channel_id:
352339
channel = SyncChannel(
353340
channel_id,
354-
rx_queue,
341+
getter,
355342
endpoint,
356343
self._send_json,
357344
self._logger.event_context,
358345
)
359346
self._connect_to_endpoint(channel)
360-
if self._ws is None:
361-
# Link has been terminated, ensure client gets a response
362-
rx_queue.put(None)
363347
yield channel
364348

365349
def _send_call(
@@ -388,15 +372,14 @@ def remote_call(
388372
notice_prefix: str = "RPC",
389373
) -> Any:
390374
"""Make a remote procedure call over the websocket."""
391-
rx_queue: queue.Queue[Any] = queue.Queue()
375+
ws = self._ws
376+
assert ws is not None
377+
rx_queue, getter = ws.new_rx_queue()
392378
with self._mux.assign_call_id(rx_queue) as call_id:
393379
rpc = SyncRemoteCall(
394-
call_id, rx_queue, self._logger.event_context, notice_prefix
380+
call_id, getter, self._logger.event_context, notice_prefix
395381
)
396382
self._send_call(rpc, endpoint, params)
397-
if self._ws is None:
398-
# Link has been terminated, ensure client gets a response
399-
rx_queue.put(None)
400383
return rpc.receive_result()
401384

402385

0 commit comments

Comments
 (0)