|
10 | 10 |
|
11 | 11 | import asyncio
|
12 | 12 | import concurrent.futures
|
| 13 | +import os |
13 | 14 | import pickle
|
14 | 15 | from collections import defaultdict
|
15 | 16 | from dataclasses import dataclass
|
16 | 17 | from datetime import datetime, timezone
|
17 | 18 | from pathlib import Path
|
18 | 19 | from typing import Any, Callable, Optional, Union, cast
|
19 | 20 |
|
| 21 | +import psutil |
20 | 22 | from tqdm.auto import tqdm
|
21 | 23 |
|
22 | 24 | from rdagent.core.conf import RD_AGENT_SETTINGS
|
@@ -347,18 +349,28 @@ async def run(self, step_n: int | None = None, loop_n: int | None = None, all_du
|
347 | 349 | 0 # if we rerun the loop, we should revert the loop index to 0 to make sure every loop is correctly kicked
|
348 | 350 | )
|
349 | 351 |
|
| 352 | + tasks: list[asyncio.Task] = [] |
350 | 353 | while True:
|
351 | 354 | try:
|
352 | 355 | # run one kickoff_loop and execute_loop
|
353 |
| - await asyncio.gather( |
354 |
| - self.kickoff_loop(), *[self.execute_loop() for _ in range(RD_AGENT_SETTINGS.get_max_parallel())] |
355 |
| - ) |
| 356 | + tasks = [ |
| 357 | + asyncio.create_task(t) |
| 358 | + for t in [ |
| 359 | + self.kickoff_loop(), |
| 360 | + *[self.execute_loop() for _ in range(RD_AGENT_SETTINGS.get_max_parallel())], |
| 361 | + ] |
| 362 | + ] |
| 363 | + await asyncio.gather(*tasks) |
356 | 364 | break
|
357 | 365 | except self.LoopResumeError as e:
|
358 | 366 | logger.warning(f"Stop all the routines and resume loop: {e}")
|
359 | 367 | self.loop_idx = 0
|
| 368 | + # cancel all previous tasks before resuming all loops. |
| 369 | + for t in tasks: |
| 370 | + t.cancel() |
360 | 371 | except self.LoopTerminationError as e:
|
361 | 372 | logger.warning(f"Reach stop criterion and stop loop: {e}")
|
| 373 | + kill_subprocesses() # NOTE: coroutine-based workflow can't automatically stop subprocesses. |
362 | 374 | break
|
363 | 375 | finally:
|
364 | 376 | self.close_pbar()
|
@@ -488,3 +500,25 @@ def __setstate__(self, state: dict[str, Any]) -> None:
|
488 | 500 | self.__dict__.update(state)
|
489 | 501 | self.queue = asyncio.Queue()
|
490 | 502 | self.semaphores = {}
|
| 503 | + |
| 504 | + |
def kill_subprocesses() -> None:
    """
    Terminate every descendant process of the current process, killing the ones that do not exit in time.

    Due to the coroutine-based nature of the workflow, the event loop of the main process can't
    stop the subprocesses started by `curr_loop.run_in_executor`, so we need to stop them manually.
    Otherwise, the subprocesses will keep running in the background and the main process keeps waiting.

    Each descendant is first asked to exit via `terminate()`; any that are still alive after a
    3-second grace period are force-stopped via `kill()`. A failure on one process is printed and
    skipped so that it never prevents the cleanup of the remaining processes.
    """
    current_proc = psutil.Process(os.getpid())
    # Snapshot the descendants exactly once. Re-querying `children(recursive=True)` after the
    # terminate pass would miss grandchildren whose parent has already exited (they are no longer
    # reachable from the current process), so they would never reach the `kill()` fallback below.
    descendants = current_proc.children(recursive=True)
    for child in descendants:
        try:
            print(f"Terminating subprocess PID {child.pid} ({child.name()})")
            child.terminate()
        except Exception as ex:  # best-effort cleanup: report and continue with the next process
            print(f"Could not terminate subprocess {child.pid}: {ex}")
    # Wait on the same snapshot; processes that already exited are simply reported as gone.
    _, alive = psutil.wait_procs(descendants, timeout=3)
    for p in alive:
        try:
            print(f"Killing still alive subprocess PID {p.pid} ({p.name()})")
            p.kill()
        except Exception as ex:  # best-effort cleanup: report and continue with the next process
            print(f"Could not kill subprocess {p.pid}: {ex}")
0 commit comments