|
import os
import subprocess
import sys
import time
from typing import Optional

import ray

from trinity.utils.log import get_logger
| 9 | + |
| 10 | +logger = get_logger(__name__) |
| 11 | + |
| 12 | + |
def get_dlc_env_vars() -> dict:
    """Read and validate the DLC distributed-launch environment variables.

    Returns:
        dict with keys ``RANK`` (int), ``WORLD_SIZE`` (int),
        ``MASTER_ADDR`` (str) and ``MASTER_PORT`` (str).

    Raises:
        ValueError: if any variable is unset (or ``RANK``/``WORLD_SIZE``
            is not a valid integer).
    """
    try:
        envs = {
            "RANK": int(os.environ.get("RANK", -1)),  # type: ignore
            "WORLD_SIZE": int(os.environ.get("WORLD_SIZE", -1)),  # type: ignore
            "MASTER_ADDR": os.environ.get("MASTER_ADDR", None),
            "MASTER_PORT": os.environ.get("MASTER_PORT", None),
        }
    except ValueError as err:
        # int() on a garbage value would otherwise surface as an opaque error.
        logger.error("DLC env vars `RANK`/`WORLD_SIZE` must be integers.")
        raise ValueError("DLC env vars `RANK`/`WORLD_SIZE` must be integers.") from err
    for key, value in envs.items():
        # -1 is the sentinel default for the integer vars; None for the strings.
        if value is None or value == -1:
            logger.error(f"DLC env var `{key}` is not set.")
            raise ValueError(f"DLC env var `{key}` is not set.")
    return envs
| 25 | + |
| 26 | + |
def is_running() -> bool:
    """Check if a ray cluster is reachable from this node.

    Probes by invoking the ``ray status`` CLI; a zero exit code means a
    cluster is up. Uses an argument list (``shell=False``) instead of a
    shell string, and treats a missing ``ray`` executable as "not running"
    rather than raising (the shell form returned a non-zero code in that
    case, so callers see the same result).
    """
    try:
        ret = subprocess.run(["ray", "status"], capture_output=True)
    except FileNotFoundError:
        # ray CLI not installed -> no cluster to talk to.
        return False
    return ret.returncode == 0
| 31 | + |
| 32 | + |
def wait_for_ray_setup(timeout: Optional[float] = None) -> None:
    """Block until the local ray cluster answers ``ray status``.

    Args:
        timeout: maximum number of seconds to wait. ``None`` (the default)
            waits forever, matching the previous behavior.

    Raises:
        TimeoutError: if the cluster is not up within ``timeout`` seconds.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while not is_running():
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError("Timed out waiting for ray cluster to start.")
        logger.info("Waiting for ray cluster to be ready...")
        time.sleep(1)
| 40 | + |
| 41 | + |
def wait_for_ray_worker_nodes(world_size: int, timeout: Optional[float] = None) -> None:
    """Block until at least ``world_size`` alive nodes have joined the cluster.

    Args:
        world_size: number of alive ray nodes to wait for.
        timeout: maximum number of seconds to wait. ``None`` (the default)
            waits forever, matching the previous behavior.

    Raises:
        TimeoutError: if fewer than ``world_size`` nodes joined in time.
    """
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        alive_nodes = [node for node in ray.nodes() if node["Alive"]]
        missing = world_size - len(alive_nodes)
        if missing <= 0:
            return
        if deadline is not None and time.monotonic() > deadline:
            raise TimeoutError(f"Timed out waiting for {missing} more ray nodes.")
        logger.info(
            f"{len(alive_nodes)} nodes have joined so far, waiting for {missing} nodes..."
        )
        time.sleep(1)
| 52 | + |
| 53 | + |
def setup_ray_cluster(namespace: str):
    """Start (or join) a ray cluster across DLC nodes and connect to it.

    The rank-0 node reuses a running cluster or starts a new head, connects
    a driver via ``ray.init``, and blocks until all ``WORLD_SIZE`` nodes
    have joined. Non-master nodes attach a worker to the head and then
    exit the current process.

    Args:
        namespace: ray namespace for the master's ``ray.init`` connection.
    """
    env_vars = get_dlc_env_vars()
    is_master = env_vars["RANK"] == 0

    if is_running():
        # Reuse the existing ray cluster; only the master connects a driver.
        if is_master:
            ray.init(namespace=namespace, ignore_reinit_error=True)
    else:
        if is_master:
            cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
        else:
            cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
        # Log before launching so the attempted command is visible even if
        # `ray start` hangs or the process dies (was previously logged after).
        logger.info(f"Starting ray cluster: {cmd}")
        ret = subprocess.run(cmd, shell=True, capture_output=True)
        if ret.returncode != 0:
            logger.error(f"Failed to start ray cluster: {cmd}")
            logger.error(f"ret.stdout: {ret.stdout!r}")
            logger.error(f"ret.stderr: {ret.stderr!r}")
            sys.exit(1)
        if is_master:
            wait_for_ray_setup()
            ray.init(
                address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
                namespace=namespace,
                ignore_reinit_error=True,
            )
            # master wait for worker nodes to join
            wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])

    if not is_master:
        # Worker nodes have nothing more to do in this process; exit so the
        # launcher can proceed.
        sys.exit(0)
0 commit comments