Skip to content

Commit 926be41

Browse files
authored
Improve shutdown in dev plugin runner (#129)
1 parent 3576258 commit 926be41

File tree

1 file changed

+30
-12
lines changed

1 file changed

+30
-12
lines changed

src/lmstudio/plugin/_dev_runner.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
assert_never,
1616
)
1717

18-
1918
from .runner import (
2019
ENV_CLIENT_ID,
2120
ENV_CLIENT_KEY,
@@ -113,22 +112,41 @@ async def register_dev_plugin(self) -> AsyncGenerator[tuple[str, str], None]:
113112
message: DevPluginRegistrationEndDict = {"type": "end"}
114113
await channel.send_message(message)
115114

115+
async def _run_plugin_task(
116+
self, result_queue: asyncio.Queue[int], debug: bool = False
117+
) -> None:
118+
async with self.register_dev_plugin() as (client_id, client_key):
119+
wait_for_subprocess = asyncio.ensure_future(
120+
asyncio.to_thread(
121+
partial(
122+
_run_plugin_in_child_process,
123+
self._plugin_path,
124+
client_id,
125+
client_key,
126+
debug,
127+
)
128+
)
129+
)
130+
try:
131+
result = await wait_for_subprocess
132+
except asyncio.CancelledError:
133+
# Likely a Ctrl-C press, which is the expected termination process
134+
result_queue.put_nowait(0)
135+
raise
136+
# Subprocess terminated, pass along its return code in the parent process
137+
await result_queue.put(result.returncode)
138+
116139
async def run_plugin(
117140
self, *, allow_local_imports: bool = True, debug: bool = False
118141
) -> int:
119142
if not allow_local_imports:
120143
raise ValueError("Local imports are always permitted for dev plugins")
121-
async with self.register_dev_plugin() as (client_id, client_key):
122-
result = await asyncio.to_thread(
123-
partial(
124-
_run_plugin_in_child_process,
125-
self._plugin_path,
126-
client_id,
127-
client_key,
128-
debug,
129-
)
130-
)
131-
return result.returncode
144+
result_queue: asyncio.Queue[int] = asyncio.Queue()
145+
# Run in the task manager, so this gets cleaned up before the websocket handler
146+
await self._task_manager.schedule_task(
147+
partial(self._run_plugin_task, result_queue, debug)
148+
)
149+
return await result_queue.get()
132150

133151

134152
# TODO: support the same source code change monitoring features as `lms dev`

0 commit comments

Comments
 (0)