|
15 | 15 | assert_never, |
16 | 16 | ) |
17 | 17 |
|
18 | | - |
19 | 18 | from .runner import ( |
20 | 19 | ENV_CLIENT_ID, |
21 | 20 | ENV_CLIENT_KEY, |
@@ -113,22 +112,41 @@ async def register_dev_plugin(self) -> AsyncGenerator[tuple[str, str], None]: |
113 | 112 | message: DevPluginRegistrationEndDict = {"type": "end"} |
114 | 113 | await channel.send_message(message) |
115 | 114 |
|
| 115 | + async def _run_plugin_task( |
| 116 | + self, result_queue: asyncio.Queue[int], debug: bool = False |
| 117 | + ) -> None: |
| 118 | + async with self.register_dev_plugin() as (client_id, client_key): |
| 119 | + wait_for_subprocess = asyncio.ensure_future( |
| 120 | + asyncio.to_thread( |
| 121 | + partial( |
| 122 | + _run_plugin_in_child_process, |
| 123 | + self._plugin_path, |
| 124 | + client_id, |
| 125 | + client_key, |
| 126 | + debug, |
| 127 | + ) |
| 128 | + ) |
| 129 | + ) |
| 130 | + try: |
| 131 | + result = await wait_for_subprocess |
| 132 | + except asyncio.CancelledError: |
| 133 | + # Likely a Ctrl-C press, which is the expected termination process |
| 134 | + result_queue.put_nowait(0) |
| 135 | + raise |
| 136 | + # Subprocess terminated, pass along its return code in the parent process |
| 137 | + await result_queue.put(result.returncode) |
| 138 | + |
116 | 139 | async def run_plugin( |
117 | 140 | self, *, allow_local_imports: bool = True, debug: bool = False |
118 | 141 | ) -> int: |
119 | 142 | if not allow_local_imports: |
120 | 143 | raise ValueError("Local imports are always permitted for dev plugins") |
121 | | - async with self.register_dev_plugin() as (client_id, client_key): |
122 | | - result = await asyncio.to_thread( |
123 | | - partial( |
124 | | - _run_plugin_in_child_process, |
125 | | - self._plugin_path, |
126 | | - client_id, |
127 | | - client_key, |
128 | | - debug, |
129 | | - ) |
130 | | - ) |
131 | | - return result.returncode |
| 144 | + result_queue: asyncio.Queue[int] = asyncio.Queue() |
| 145 | + # Run in the task manager, so this gets cleaned up before the websocket handler |
| 146 | + await self._task_manager.schedule_task( |
| 147 | + partial(self._run_plugin_task, result_queue, debug) |
| 148 | + ) |
| 149 | + return await result_queue.get() |
132 | 150 |
|
133 | 151 |
|
134 | 152 | # TODO: support the same source code change monitoring features as `lms dev` |
|
0 commit comments