Skip to content

Commit 850ba9c

Browse files
committed
Forward parent process interrupts to child process
1 parent 6a49231 commit 850ba9c

File tree

4 files changed

+79
-36
lines changed

4 files changed

+79
-36
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ jobs:
4141
max-parallel: 8
4242
matrix:
4343
python-version: ["3.10", "3.11", "3.12", "3.13"]
44-
# There's no platform specific SDK code, but explicitly check Windows
45-
# to ensure there aren't any inadvertent POSIX-only assumptions
44+
# While the main SDK is platform independent, the subprocess execution
45+
# in the plugin runner and tests requires some Windows-specific code
46+
# Note: a green tick in CI is currently misleading due to
47+
# https://github.com/lmstudio-ai/lmstudio-python/issues/140
4648
os: [ubuntu-22.04, windows-2022]
4749

4850
# Check https://github.com/actions/action-versions/tree/main/config/actions

src/lmstudio/plugin/_dev_runner.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
import asyncio
44
import io
55
import os
6+
import signal
67
import subprocess
78
import sys
89

910
from contextlib import asynccontextmanager
10-
from pathlib import Path
1111
from functools import partial
12-
from typing import AsyncGenerator, Iterable, TypeAlias
12+
from pathlib import Path
13+
from threading import Event as SyncEvent
14+
from typing import Any, AsyncGenerator, Iterable, TypeAlias
1315

1416
from typing_extensions import (
1517
# Native in 3.11+
@@ -116,6 +118,7 @@ async def register_dev_plugin(self) -> AsyncGenerator[tuple[str, str], None]:
116118
async def _run_plugin_task(
117119
self, result_queue: asyncio.Queue[int], debug: bool = False
118120
) -> None:
121+
notify_subprocess_thread = SyncEvent()
119122
async with self.register_dev_plugin() as (client_id, client_key):
120123
wait_for_subprocess = asyncio.ensure_future(
121124
asyncio.to_thread(
@@ -124,18 +127,20 @@ async def _run_plugin_task(
124127
self._plugin_path,
125128
client_id,
126129
client_key,
127-
debug,
130+
notify_subprocess_thread,
131+
debug=debug,
128132
)
129133
)
130134
)
131135
try:
132136
result = await wait_for_subprocess
133137
except asyncio.CancelledError:
134138
# Likely a Ctrl-C press, which is the expected termination process
139+
notify_subprocess_thread.set()
135140
result_queue.put_nowait(0)
136141
raise
137142
# Subprocess terminated, pass along its return code in the parent process
138-
await result_queue.put(result.returncode)
143+
await result_queue.put(result)
139144

140145
async def run_plugin(
141146
self, *, allow_local_imports: bool = True, debug: bool = False
@@ -150,10 +155,49 @@ async def run_plugin(
150155
return await result_queue.get()
151156

152157

158+
def _get_creation_flags() -> int:
159+
if sys.platform == "win32":
160+
return subprocess.CREATE_NEW_PROCESS_GROUP
161+
return 0
162+
163+
164+
def _start_child_process(
165+
command: list[str], *, text: bool | None = True, **kwds: Any
166+
) -> subprocess.Popen[str]:
167+
creationflags = kwds.pop("creationflags", 0)
168+
creationflags |= _get_creation_flags()
169+
return subprocess.Popen(command, text=text, creationflags=creationflags, **kwds)
170+
171+
172+
def _get_interrupt_signal() -> signal.Signals:
173+
if sys.platform == "win32":
174+
return signal.CTRL_C_EVENT
175+
return signal.SIGINT
176+
177+
178+
_PLUGIN_INTERRUPT_SIGNAL = _get_interrupt_signal()
179+
_PLUGIN_STATUS_POLL_INTERVAL = 1
180+
_PLUGIN_STOP_TIMEOUT = 2
181+
182+
183+
def _interrupt_child_process(process: subprocess.Popen[Any], timeout: float) -> int:
184+
process.send_signal(_PLUGIN_INTERRUPT_SIGNAL)
185+
try:
186+
return process.wait(timeout)
187+
except TimeoutError:
188+
process.kill()
189+
raise
190+
191+
153192
# TODO: support the same source code change monitoring features as `lms dev`
154193
def _run_plugin_in_child_process(
155-
plugin_path: Path, client_id: str, client_key: str, debug: bool = False
156-
) -> subprocess.CompletedProcess[str]:
194+
plugin_path: Path,
195+
client_id: str,
196+
client_key: str,
197+
abort_event: SyncEvent,
198+
*,
199+
debug: bool = False,
200+
) -> int:
157201
env = os.environ.copy()
158202
env[ENV_CLIENT_ID] = client_id
159203
env[ENV_CLIENT_KEY] = client_key
@@ -176,7 +220,17 @@ def _run_plugin_in_child_process(
176220
*debug_option,
177221
os.fspath(plugin_path),
178222
]
179-
return subprocess.run(command, text=True, env=env)
223+
process = _start_child_process(command, env=env)
224+
while True:
225+
result = process.poll()
226+
if result is not None:
227+
print("Child process terminated unexpectedly")
228+
break
229+
if abort_event.wait(_PLUGIN_STATUS_POLL_INTERVAL):
230+
print("Gracefully terminating child process...")
231+
result = _interrupt_child_process(process, _PLUGIN_STOP_TIMEOUT)
232+
break
233+
return result
180234

181235

182236
async def run_plugin_async(

src/lmstudio/plugin/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def main(argv: Sequence[str] | None = None) -> int:
5050
try:
5151
runner.run_plugin(plugin_path, allow_local_imports=True)
5252
except KeyboardInterrupt:
53-
print("Plugin execution terminated by user", flush=True)
53+
print("Plugin execution terminated by console interrupt", flush=True)
5454
else:
5555
# Retrieve args from API host, spawn plugin in subprocess
5656
try:

tests/test_plugin_examples.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Test plugin examples can run as dev plugins."""
22

3-
import signal
43
import subprocess
54
import sys
65
import time
@@ -13,6 +12,12 @@
1312

1413
import pytest
1514

15+
from lmstudio.plugin._dev_runner import (
16+
_interrupt_child_process,
17+
_start_child_process,
18+
_PLUGIN_STOP_TIMEOUT,
19+
)
20+
1621
_THIS_DIR = Path(__file__).parent.resolve()
1722
_PLUGIN_EXAMPLES_DIR = (_THIS_DIR / "../examples/plugins").resolve()
1823

@@ -47,30 +52,18 @@ def _exec_plugin(plugin_path: Path) -> subprocess.Popen[str]:
4752
"--dev",
4853
str(plugin_path),
4954
]
50-
return subprocess.Popen(
51-
cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=1
52-
)
55+
return _start_child_process(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
5356

5457

5558
_PLUGIN_START_TIMEOUT = 5
56-
_PLUGIN_STOP_TIMEOUT = 5
57-
58-
59-
def _get_interrupt_signal() -> signal.Signals:
60-
if sys.platform == "win32":
61-
return signal.CTRL_C_EVENT
62-
return signal.SIGINT
6359

6460

65-
_INTERRUPT_SIGNAL = _get_interrupt_signal()
66-
6761
def _exec_and_interrupt(plugin_path: Path) -> tuple[list[str], list[str], list[str]]:
62+
# Start the plugin in a child process
6863
process = _exec_plugin(plugin_path)
6964
# Ensure pipes don't fill up and block subprocess execution
7065
stdout_q: Queue[str] = Queue()
71-
stdout_thread = Thread(
72-
target=_monitor_stream, args=[process.stdout, stdout_q], kwargs={"debug": True}
73-
)
66+
stdout_thread = Thread(target=_monitor_stream, args=[process.stdout, stdout_q])
7467
stdout_thread.start()
7568
stderr_q: Queue[str] = Queue()
7669
stderr_thread = Thread(target=_monitor_stream, args=[process.stderr, stderr_q])
@@ -94,15 +87,14 @@ def _exec_and_interrupt(plugin_path: Path) -> tuple[list[str], list[str], list[s
9487
finally:
9588
# Instruct the process to terminate
9689
print("Sending termination request to plugin subprocess")
97-
process.send_signal(_INTERRUPT_SIGNAL)
90+
stop_deadline = time.monotonic() + _PLUGIN_STOP_TIMEOUT
91+
_interrupt_child_process(process, (stop_deadline - time.monotonic()))
9892
# Give threads a chance to halt their file reads
9993
# (process terminating will close the pipes)
100-
stop_deadline = time.monotonic() + _PLUGIN_STOP_TIMEOUT
10194
stdout_thread.join(timeout=(stop_deadline - time.monotonic()))
10295
stderr_thread.join(timeout=(stop_deadline - time.monotonic()))
103-
process.wait(timeout=(stop_deadline - time.monotonic()))
10496
with process:
105-
# Close pipes
97+
# Closes open pipes
10698
pass
10799
# Collect remainder of subprocess output
108100
shutdown_lines = [*_drain_queue(stdout_q)]
@@ -121,9 +113,4 @@ def test_plugin_execution(plugin_path: Path) -> None:
121113
for log_line in stderr_lines:
122114
assert log_line.startswith("INFO:")
123115
assert startup_lines[-1].endswith("Ctrl-C to terminate...\n")
124-
# Outside an actual terminal, pipe may be closed before the termination is reported
125-
# TODO: Consider migrating to using pexpect, so this check can be more robust
126-
assert (
127-
not shutdown_lines
128-
or shutdown_lines[-1] == "Plugin execution terminated by user\n"
129-
)
116+
assert shutdown_lines[-1] == "Plugin execution terminated by console interrupt\n"

0 commit comments

Comments
 (0)