diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 9dfe4df8ee..b0fb4b856f 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -187,6 +187,11 @@ def run(config_path: str, dlc: bool = False): elif config.mode == "bench": bench(config) + if dlc: + from trinity.utils.dlc_utils import stop_ray_cluster + + stop_ray_cluster() + def studio(port: int = 8501): from streamlit.web import cli as stcli diff --git a/trinity/utils/dlc_utils.py b/trinity/utils/dlc_utils.py index 8d7a7d3a06..3edcc9539f 100644 --- a/trinity/utils/dlc_utils.py +++ b/trinity/utils/dlc_utils.py @@ -9,6 +9,20 @@ logger = get_logger(__name__) +CLUSTER_ACTOR_NAME = "cluster_status" + + +@ray.remote +class ClusterStatus: + def __init__(self): + self.finished = False + + def finish(self) -> None: + self.finished = True + + def running(self) -> bool: + return not self.finished + def get_dlc_env_vars() -> dict: envs = { @@ -71,16 +85,36 @@ def setup_ray_cluster(namespace: str): logger.error(f"ret.stdout: {ret.stdout!r}") logger.error(f"ret.stderr: {ret.stderr!r}") sys.exit(1) + + wait_for_ray_setup() + ray.init( + address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", + namespace=namespace, + ignore_reinit_error=True, + ) if is_master: - wait_for_ray_setup() - ray.init( - address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", - namespace=namespace, - ignore_reinit_error=True, - ) # master wait for worker nodes to join wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"]) + else: + # worker waits on the cluster status actor + cluster_status = ClusterStatus.options( + name=CLUSTER_ACTOR_NAME, + get_if_exists=True, + ).remote() + while True: + if ray.get(cluster_status.running.remote()): + time.sleep(5) + else: + break + sys.exit(0) + - if not is_master: - # woker just exit - sys.exit(0) +def stop_ray_cluster(): + """ + Stop the ray cluster by sending a signal to the cluster status actor. 
+ """ + cluster_status = ClusterStatus.options( + name=CLUSTER_ACTOR_NAME, + get_if_exists=True, + ).remote() + ray.get(cluster_status.finish.remote())