Skip to content

Commit fd69ba4

Browse files
authored
Add utils for PAI DLC (#38)
1 parent 2086bec commit fd69ba4

File tree

2 files changed

+102
-6
lines changed

2 files changed

+102
-6
lines changed

trinity/cli/launcher.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def activate_data_module(data_workflow_url: str, config_path: str):
148148
return
149149

150150

151-
def run(config_path: str):
151+
def run(config_path: str, dlc: bool = False):
152152
config = load_config(config_path)
153153
config.check_and_update()
154154
# try to activate data module
@@ -157,8 +157,13 @@ def run(config_path: str):
157157
data_processor_config.dj_config_path or data_processor_config.dj_process_desc
158158
):
159159
activate_data_module(data_processor_config.data_workflow_url, config_path)
160-
if not ray.is_initialized():
161-
ray.init(namespace=f"{config.monitor.project}-{config.monitor.name}")
160+
ray_namespace = f"{config.monitor.project}-{config.monitor.name}"
161+
if dlc:
162+
from trinity.utils.dlc_utils import setup_ray_cluster
163+
164+
setup_ray_cluster(namespace=ray_namespace)
165+
else:
166+
ray.init(namespace=ray_namespace, ignore_reinit_error=True)
162167
if config.mode == "explore":
163168
explore(config)
164169
elif config.mode == "train":
@@ -191,18 +196,23 @@ def main() -> None:
191196

192197
# run command
193198
run_parser = subparsers.add_parser("run", help="Run RFT process.")
194-
run_parser.add_argument("--config", type=str, required=True, help="config file path.")
199+
run_parser.add_argument("--config", type=str, required=True, help="Path to the config file.")
200+
run_parser.add_argument(
201+
"--dlc", action="store_true", help="Specify when running in Aliyun PAI DLC."
202+
)
195203

196204
# studio command
197205
studio_parser = subparsers.add_parser("studio", help="Run studio.")
198-
studio_parser.add_argument("--port", type=int, default=8501, help="studio port.")
206+
studio_parser.add_argument(
207+
"--port", type=int, default=8501, help="The port for Trinity-Studio."
208+
)
199209

200210
# TODO: add more commands like `monitor`, `label`
201211

202212
args = parser.parse_args()
203213
if args.command == "run":
204214
# TODO: support parse all args from command line
205-
run(args.config)
215+
run(args.config, args.dlc)
206216
elif args.command == "studio":
207217
studio(args.port)
208218

trinity/utils/dlc_utils.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import os
2+
import subprocess
3+
import sys
4+
import time
5+
6+
import ray
7+
8+
from trinity.utils.log import get_logger
9+
10+
logger = get_logger(__name__)
11+
12+
13+
def get_dlc_env_vars() -> dict:
14+
envs = {
15+
"RANK": int(os.environ.get("RANK", -1)), # type: ignore
16+
"WORLD_SIZE": int(os.environ.get("WORLD_SIZE", -1)), # type: ignore
17+
"MASTER_ADDR": os.environ.get("MASTER_ADDR", None),
18+
"MASTER_PORT": os.environ.get("MASTER_PORT", None),
19+
}
20+
for key, value in envs.items():
21+
if value is None or value == -1:
22+
logger.error(f"DLC env var `{key}` is not set.")
23+
raise ValueError(f"DLC env var `{key}` is not set.")
24+
return envs
25+
26+
27+
def is_running() -> bool:
28+
"""Check if ray cluster is running."""
29+
ret = subprocess.run("ray status", shell=True, capture_output=True)
30+
return ret.returncode == 0
31+
32+
33+
def wait_for_ray_setup() -> None:
34+
while True:
35+
if is_running():
36+
break
37+
else:
38+
logger.info("Waiting for ray cluster to be ready...")
39+
time.sleep(1)
40+
41+
42+
def wait_for_ray_worker_nodes(world_size: int) -> None:
43+
while True:
44+
alive_nodes = [node for node in ray.nodes() if node["Alive"]]
45+
if len(alive_nodes) >= world_size:
46+
break
47+
else:
48+
logger.info(
49+
f"{len(alive_nodes)} nodes have joined so far, waiting for {world_size - len(alive_nodes)} nodes..."
50+
)
51+
time.sleep(1)
52+
53+
54+
def setup_ray_cluster(namespace: str):
55+
env_vars = get_dlc_env_vars()
56+
is_master = env_vars["RANK"] == 0
57+
58+
if is_running():
59+
# reuse existing ray cluster
60+
if is_master:
61+
ray.init(namespace=namespace, ignore_reinit_error=True)
62+
else:
63+
if is_master:
64+
cmd = f"ray start --head --port={env_vars['MASTER_PORT']}"
65+
else:
66+
cmd = f"ray start --address={env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}"
67+
ret = subprocess.run(cmd, shell=True, capture_output=True)
68+
logger.info(f"Starting ray cluster: {cmd}")
69+
if ret.returncode != 0:
70+
logger.error(f"Failed to start ray cluster: {cmd}")
71+
logger.error(f"ret.stdout: {ret.stdout!r}")
72+
logger.error(f"ret.stderr: {ret.stderr!r}")
73+
sys.exit(1)
74+
if is_master:
75+
wait_for_ray_setup()
76+
ray.init(
77+
address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}",
78+
namespace=namespace,
79+
ignore_reinit_error=True,
80+
)
81+
# master wait for worker nodes to join
82+
wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"])
83+
84+
if not is_master:
85+
# woker just exit
86+
sys.exit(0)

0 commit comments

Comments
 (0)