diff --git a/src/lmstudio/plugin/_dev_runner.py b/src/lmstudio/plugin/_dev_runner.py index f6c279b..4e1b75d 100644 --- a/src/lmstudio/plugin/_dev_runner.py +++ b/src/lmstudio/plugin/_dev_runner.py @@ -15,7 +15,6 @@ assert_never, ) - from .runner import ( ENV_CLIENT_ID, ENV_CLIENT_KEY, @@ -113,22 +112,41 @@ async def register_dev_plugin(self) -> AsyncGenerator[tuple[str, str], None]: message: DevPluginRegistrationEndDict = {"type": "end"} await channel.send_message(message) + async def _run_plugin_task( + self, result_queue: asyncio.Queue[int], debug: bool = False + ) -> None: + async with self.register_dev_plugin() as (client_id, client_key): + wait_for_subprocess = asyncio.ensure_future( + asyncio.to_thread( + partial( + _run_plugin_in_child_process, + self._plugin_path, + client_id, + client_key, + debug, + ) + ) + ) + try: + result = await wait_for_subprocess + except asyncio.CancelledError: + # Likely a Ctrl-C press, which is the expected termination process + result_queue.put_nowait(0) + raise + # Subprocess terminated, pass along its return code in the parent process + await result_queue.put(result.returncode) + async def run_plugin( self, *, allow_local_imports: bool = True, debug: bool = False ) -> int: if not allow_local_imports: raise ValueError("Local imports are always permitted for dev plugins") - async with self.register_dev_plugin() as (client_id, client_key): - result = await asyncio.to_thread( - partial( - _run_plugin_in_child_process, - self._plugin_path, - client_id, - client_key, - debug, - ) - ) - return result.returncode + result_queue: asyncio.Queue[int] = asyncio.Queue() + # Run in the task manager, so this gets cleaned up before the websocket handler + await self._task_manager.schedule_task( + partial(self._run_plugin_task, result_queue, debug) + ) + return await result_queue.get() # TODO: support the same source code change monitoring features as `lms dev`