diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py
index 482224918d..5175615b82 100644
--- a/trinity/cli/launcher.py
+++ b/trinity/cli/launcher.py
@@ -148,7 +148,7 @@ def activate_data_module(data_workflow_url: str, config_path: str):
         return
 
 
-def run(config_path: str):
+def run(config_path: str, dlc: bool = False):
     config = load_config(config_path)
     config.check_and_update()
     # try to activate data module
@@ -157,8 +157,13 @@ def run(config_path: str):
         data_processor_config.dj_config_path or data_processor_config.dj_process_desc
     ):
         activate_data_module(data_processor_config.data_workflow_url, config_path)
-    if not ray.is_initialized():
-        ray.init(namespace=f"{config.monitor.project}-{config.monitor.name}")
+    ray_namespace = f"{config.monitor.project}-{config.monitor.name}"
+    if dlc:
+        from trinity.utils.dlc_utils import setup_ray_cluster
+
+        setup_ray_cluster(namespace=ray_namespace)
+    else:
+        ray.init(namespace=ray_namespace, ignore_reinit_error=True)
     if config.mode == "explore":
         explore(config)
     elif config.mode == "train":
@@ -191,18 +196,23 @@ def main() -> None:
 
     # run command
     run_parser = subparsers.add_parser("run", help="Run RFT process.")
-    run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
+    run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
+    run_parser.add_argument(
+        "--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
+    )
 
     # studio command
     studio_parser = subparsers.add_parser("studio", help="Run studio.")
-    studio_parser.add_argument("--port", type=int, default=8501, help="studio port.")
+    studio_parser.add_argument(
+        "--port", type=int, default=8501, help="The port for Trinity-Studio."
+    )
 
     # TODO: add more commands like `monitor`, `label`
 
     args = parser.parse_args()
     if args.command == "run":
         # TODO: support parse all args from command line
-        run(args.config)
+        run(args.config, args.dlc)
     elif args.command == "studio":
         studio(args.port)
 
diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py
new file mode 100644
index 0000000000..8d7a7d3a06
--- /dev/null
+++ b/trinity/utils/dlc_utils.py
@@ -0,0 +1,86 @@
+import os
+import subprocess
+import sys
+import time
+
+import ray
+
+from trinity.utils.log import get_logger
+
+logger = get_logger(__name__)
+
+
+def get_dlc_env_vars() -> dict:
+    envs = {
+        "RANK": int(os.environ.get("RANK", -1)),  # type: ignore
+        "WORLD_SIZE": int(os.environ.get("WORLD_SIZE", -1)),  # type: ignore
+        "MASTER_ADDR": os.environ.get("MASTER_ADDR", None),
+        "MASTER_PORT": os.environ.get("MASTER_PORT", None),
+    }
+    for key, value in envs.items():
+        if value is None or value == -1:
+            logger.error(f"DLC env var `{key}` is not set.")
+            raise ValueError(f"DLC env var `{key}` is not set.")
+    return envs
+
+
+def is_running() -> bool:
+    """Check if ray cluster is running."""
+    ret = subprocess.run("ray status", shell=True, capture_output=True)
+    return ret.returncode == 0
+
+
+def wait_for_ray_setup() -> None:
+    while True:
+        if is_running():
+            break
+        else:
+            logger.info("Waiting for ray cluster to be ready...")
+            time.sleep(1)
+
+
+def wait_for_ray_worker_nodes(world_size: int) -> None:
+    while True:
+        alive_nodes = [node for node in ray.nodes() if node["Alive"]]
+        if len(alive_nodes) >= world_size:
+            break
+        else:
+            logger.info(
+                f"{len(alive_nodes)} nodes have joined so far, waiting for {world_size - len(alive_nodes)} nodes..."
+            )
+            time.sleep(1)
+
+
+def setup_ray_cluster(namespace: str):
+    env_vars = get_dlc_env_vars()
+    is_master = env_vars["RANK"] == 0
+
+    if is_running():
+        # reuse the existing ray cluster on this node
+        if is_master:
+            ray.init(namespace=namespace, ignore_reinit_error=True)
+    else:
+        if is_master:
+            cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
+        else:
+            cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
+        ret = subprocess.run(cmd, shell=True, capture_output=True)
+        logger.info(f"Starting ray cluster: {cmd}")
+        if ret.returncode != 0:
+            logger.error(f"Failed to start ray cluster: {cmd}")
+            logger.error(f"ret.stdout: {ret.stdout!r}")
+            logger.error(f"ret.stderr: {ret.stderr!r}")
+            sys.exit(1)
+        if is_master:
+            wait_for_ray_setup()
+            ray.init(
+                address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
+                namespace=namespace,
+                ignore_reinit_error=True,
+            )
+            # master waits for worker nodes to join
+            wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
+
+    if not is_master:
+        # worker nodes just exit
+        sys.exit(0)