Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,11 @@ def run(config_path: str, dlc: bool = False):
elif config.mode == "bench":
bench(config)

if dlc:
from trinity.utils.dlc_utils import stop_ray_cluster

stop_ray_cluster()


def studio(port: int = 8501):
from streamlit.web import cli as stcli
Expand Down
52 changes: 43 additions & 9 deletions trinity/utils/dlc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@

logger = get_logger(__name__)

CLUSTER_ACTOR_NAME = "cluster_status"


@ray.remote
class ClusterStatus:
def __init__(self):
self.finished = False

def finish(self) -> None:
self.finished = True

def running(self) -> bool:
return not self.finished


def get_dlc_env_vars() -> dict:
envs = {
Expand Down Expand Up @@ -71,16 +85,36 @@ def setup_ray_cluster(namespace: str):
logger.error(f"ret.stdout: {ret.stdout!r}")
logger.error(f"ret.stderr: {ret.stderr!r}")
sys.exit(1)

wait_for_ray_setup()
ray.init(
address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
namespace=namespace,
ignore_reinit_error=True,
)
if is_master:
wait_for_ray_setup()
ray.init(
address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
namespace=namespace,
ignore_reinit_error=True,
)
# master wait for worker nodes to join
wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
else:
# woker wait on the cluster status actor
cluster_status = ClusterStatus.options(
name=CLUSTER_ACTOR_NAME,
get_if_exists=True,
).remote()
while True:
if ray.get(cluster_status.running.remote()):
time.sleep(5)
else:
break
sys.exit(0)


if not is_master:
# woker just exit
sys.exit(0)
def stop_ray_cluster():
"""
Stop the ray cluster by sending a signal to the cluster status actor.
"""
cluster_status = ClusterStatus.options(
name=CLUSTER_ACTOR_NAME,
get_if_exists=True,
).remote()
ray.get(cluster_status.finish.remote())