diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2753dff..6e8ea34 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,8 +41,10 @@ jobs: max-parallel: 8 matrix: python-version: ["3.10", "3.11", "3.12", "3.13"] - # There's no platform specific SDK code, but explicitly check Windows - # to ensure there aren't any inadvertent POSIX-only assumptions + # While the main SDK is platform independent, the subprocess execution + # in the plugin runner and tests requires some Windows-specific code + # Note: a green tick in CI is currently misleading due to + # https://github.com/lmstudio-ai/lmstudio-python/issues/140 os: [ubuntu-22.04, windows-2022] # Check https://github.com/actions/action-versions/tree/main/config/actions diff --git a/src/lmstudio/plugin/_dev_runner.py b/src/lmstudio/plugin/_dev_runner.py index 4e1b75d..cd3a952 100644 --- a/src/lmstudio/plugin/_dev_runner.py +++ b/src/lmstudio/plugin/_dev_runner.py @@ -1,14 +1,17 @@ """Plugin dev client implementation.""" import asyncio +import io import os +import signal import subprocess import sys from contextlib import asynccontextmanager -from pathlib import Path from functools import partial -from typing import AsyncGenerator, Iterable, TypeAlias +from pathlib import Path +from threading import Event as SyncEvent +from typing import Any, AsyncGenerator, Iterable, TypeAlias from typing_extensions import ( # Native in 3.11+ @@ -115,6 +118,7 @@ async def register_dev_plugin(self) -> AsyncGenerator[tuple[str, str], None]: async def _run_plugin_task( self, result_queue: asyncio.Queue[int], debug: bool = False ) -> None: + notify_subprocess_thread = SyncEvent() async with self.register_dev_plugin() as (client_id, client_key): wait_for_subprocess = asyncio.ensure_future( asyncio.to_thread( @@ -123,7 +127,8 @@ async def _run_plugin_task( self._plugin_path, client_id, client_key, - debug, + notify_subprocess_thread, + debug=debug, ) ) ) @@ -131,10 +136,11 @@ async def _run_plugin_task( result = await wait_for_subprocess except asyncio.CancelledError: # Likely a Ctrl-C press, which is the expected termination process + notify_subprocess_thread.set() result_queue.put_nowait(0) raise # Subprocess terminated, pass along its return code in the parent process - await result_queue.put(result.returncode) + await result_queue.put(result) async def run_plugin( self, *, allow_local_imports: bool = True, debug: bool = False @@ -149,24 +155,82 @@ async def run_plugin( return await result_queue.get() +def _get_creation_flags() -> int: + if sys.platform == "win32": + return subprocess.CREATE_NEW_PROCESS_GROUP + return 0 + + +def _start_child_process( + command: list[str], *, text: bool | None = True, **kwds: Any +) -> subprocess.Popen[str]: + creationflags = kwds.pop("creationflags", 0) + creationflags |= _get_creation_flags() + return subprocess.Popen(command, text=text, creationflags=creationflags, **kwds) + + +def _get_interrupt_signal() -> signal.Signals: + if sys.platform == "win32": + return signal.CTRL_C_EVENT + return signal.SIGINT + + +_PLUGIN_INTERRUPT_SIGNAL = _get_interrupt_signal() +_PLUGIN_STATUS_POLL_INTERVAL = 1 +_PLUGIN_STOP_TIMEOUT = 2 + + +def _interrupt_child_process(process: subprocess.Popen[Any], timeout: float) -> int: + process.send_signal(_PLUGIN_INTERRUPT_SIGNAL) + try: + return process.wait(timeout) + except TimeoutError: + process.kill() + raise + + # TODO: support the same source code change monitoring features as `lms dev` def _run_plugin_in_child_process( - plugin_path: Path, client_id: str, client_key: str, debug: bool = False -) -> subprocess.CompletedProcess[str]: + plugin_path: Path, + client_id: str, + client_key: str, + abort_event: SyncEvent, + *, + debug: bool = False, +) -> int: env = os.environ.copy() env[ENV_CLIENT_ID] = client_id env[ENV_CLIENT_KEY] = client_key package_name = __spec__.parent assert package_name is not None debug_option = ("--debug",) if debug else () + # If stdout is unbuffered, specify the same in the child process + stdout = sys.__stdout__ + unbuffered_arg: tuple[str, ...] + if stdout is None or not isinstance(stdout.buffer, io.BufferedWriter): + unbuffered_arg = ("-u",) + else: + unbuffered_arg = () + command: list[str] = [ sys.executable, + *unbuffered_arg, "-m", package_name, *debug_option, os.fspath(plugin_path), ] - return subprocess.run(command, text=True, env=env) + process = _start_child_process(command, env=env) + while True: + result = process.poll() + if result is not None: + print("Child process terminated unexpectedly") + break + if abort_event.wait(_PLUGIN_STATUS_POLL_INTERVAL): + print("Gracefully terminating child process...") + result = _interrupt_child_process(process, _PLUGIN_STOP_TIMEOUT) + break + return result async def run_plugin_async( diff --git a/src/lmstudio/plugin/cli.py b/src/lmstudio/plugin/cli.py index 5a7a9fe..f666b51 100644 --- a/src/lmstudio/plugin/cli.py +++ b/src/lmstudio/plugin/cli.py @@ -41,16 +41,29 @@ def main(argv: Sequence[str] | None = None) -> int: parser.print_usage() print(f"ERROR: Failed to find plugin folder at {plugin_path!r}") return 1 - warnings.filterwarnings( - "ignore", ".*the plugin API is not yet stable", FutureWarning - ) log_level = logging.DEBUG if args.debug else logging.INFO logging.basicConfig(level=log_level) + if sys.platform == "win32": + # Accept Ctrl-C events even in non-default process groups + # (allows for graceful termination when Ctrl-C is received + # from a controlling process rather than from a console) + # Based on https://github.com/python/cpython/blob/3.14/Lib/test/win_console_handler.py + # and https://stackoverflow.com/questions/35772001/how-can-i-handle-a-signal-sigint-on-a-windows-os-machine/35792192#35792192 + from ctypes import c_void_p, windll, wintypes + + SetConsoleCtrlHandler = windll.kernel32.SetConsoleCtrlHandler + SetConsoleCtrlHandler.argtypes = (c_void_p, wintypes.BOOL) + SetConsoleCtrlHandler.restype = wintypes.BOOL + if not SetConsoleCtrlHandler(None, 0): + print("Failed to enable Ctrl-C events, termination may be abrupt") if not args.dev: + warnings.filterwarnings( + "ignore", ".*the plugin API is not yet stable", FutureWarning + ) try: runner.run_plugin(plugin_path, allow_local_imports=True) except KeyboardInterrupt: - print("Plugin execution terminated with Ctrl-C") + print("Plugin execution terminated by console interrupt", flush=True) else: # Retrieve args from API host, spawn plugin in subprocess try: diff --git a/src/lmstudio/plugin/runner.py b/src/lmstudio/plugin/runner.py index 8511a65..129a74e 100644 --- a/src/lmstudio/plugin/runner.py +++ b/src/lmstudio/plugin/runner.py @@ -228,7 +228,9 @@ async def run_plugin(self, *, allow_local_imports: bool = False) -> int: await asyncio.gather(*(e.wait() for e in hook_ready_events)) await self.plugins.remote_call("pluginInitCompleted") # Indicate that prompt processing is ready - print(f"Plugin {plugin!r} running, press Ctrl-C to terminate...") + print( + f"Plugin {plugin!r} running, press Ctrl-C to terminate...", flush=True + ) # Task group will wait for the plugins to run return 0 diff --git a/tests/test_plugin_examples.py b/tests/test_plugin_examples.py new file mode 100644 index 0000000..44f7dab --- /dev/null +++ b/tests/test_plugin_examples.py @@ -0,0 +1,130 @@ +"""Test plugin examples can run as dev plugins.""" + +import subprocess +import sys +import time + + +from pathlib import Path +from queue import Empty, Queue +from threading import Thread +from typing import Iterable, TextIO + +import pytest + +from lmstudio.plugin._dev_runner import ( + _interrupt_child_process, + _start_child_process, + _PLUGIN_STOP_TIMEOUT, +) +from lmstudio.plugin.runner import _PLUGIN_API_STABILITY_WARNING + + +_THIS_DIR = Path(__file__).parent.resolve() +_PLUGIN_EXAMPLES_DIR = (_THIS_DIR / "../examples/plugins").resolve() + + +def _get_plugin_paths() -> list[Path]: + return [p for p in _PLUGIN_EXAMPLES_DIR.iterdir() if p.is_dir()] + + +def _monitor_stream(stream: TextIO, queue: Queue[str], *, debug: bool = False) -> None: + for line in stream: + if debug: + print(line) + queue.put(line) + + +def _drain_queue(queue: Queue[str]) -> Iterable[str]: + while True: + try: + yield queue.get(block=False) + except Empty: + break + + +def _exec_plugin(plugin_path: Path) -> subprocess.Popen[str]: + # Run plugin in dev mode with IO pipes line buffered + # (as the test process is monitoring for specific output) + cmd = [ + sys.executable, + "-u", + "-m", + "lmstudio.plugin", + "--dev", + str(plugin_path), + ] + return _start_child_process(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + +_PLUGIN_START_TIMEOUT = 5 + + +def _exec_and_interrupt(plugin_path: Path) -> tuple[list[str], list[str], list[str]]: + # Start the plugin in a child process + process = _exec_plugin(plugin_path) + # Ensure pipes don't fill up and block subprocess execution + stdout_q: Queue[str] = Queue() + stdout_thread = Thread(target=_monitor_stream, args=[process.stdout, stdout_q]) + stdout_thread.start() + stderr_q: Queue[str] = Queue() + stderr_thread = Thread(target=_monitor_stream, args=[process.stderr, stderr_q]) + stderr_thread.start() + startup_lines: list[str] = [] + # Wait for plugin to start + start_deadline = time.monotonic() + _PLUGIN_START_TIMEOUT + try: + print(f"Monitoring {stdout_q!r} for plugin started message") + while True: + remaining_time = start_deadline - time.monotonic() + print(f"Waiting {remaining_time} seconds for plugin to start") + try: + line = stdout_q.get(timeout=remaining_time) + except Empty: + assert False, "Plugin subprocess failed to start" + print(line) + startup_lines.append(line) + if "Ctrl-C to terminate" in line: + break + finally: + # Instruct the process to terminate + print("Sending termination request to plugin subprocess") + stop_deadline = time.monotonic() + _PLUGIN_STOP_TIMEOUT + _interrupt_child_process(process, (stop_deadline - time.monotonic())) + # Give threads a chance to halt their file reads + # (process terminating will close the pipes) + stdout_thread.join(timeout=(stop_deadline - time.monotonic())) + stderr_thread.join(timeout=(stop_deadline - time.monotonic())) + with process: + # Closes open pipes + pass + # Collect remainder of subprocess output + shutdown_lines = [*_drain_queue(stdout_q)] + stderr_lines = [*_drain_queue(stderr_q)] + return startup_lines, shutdown_lines, stderr_lines + + +def _plugin_case_id(plugin_path: Path) -> str: + return plugin_path.name + + +@pytest.mark.lmstudio +@pytest.mark.parametrize("plugin_path", _get_plugin_paths(), ids=_plugin_case_id) +def test_plugin_execution(plugin_path: Path) -> None: + startup_lines, shutdown_lines, stderr_lines = _exec_and_interrupt(plugin_path) + # Stderr should start with the API stability warning... + warning_lines = [ + *_PLUGIN_API_STABILITY_WARNING.splitlines(keepends=True), + "\n", + "warnings.warn(_PLUGIN_API_STABILITY_WARNING, FutureWarning)\n", + ] + for warning_line in warning_lines: + stderr_line = stderr_lines.pop(0) + assert stderr_line.endswith(warning_line) + # ... and then consist solely of logged information messages + for log_line in stderr_lines: + assert log_line.startswith("INFO:") + # Startup should finish with the notification of how to terminate the dev plugin + assert startup_lines[-1].endswith("Ctrl-C to terminate...\n") + # Shutdown should finish with a graceful shutdown notice from the plugin runner + assert shutdown_lines[-1] == "Plugin execution terminated by console interrupt\n"