Commit 72a8b1e

Update config from ray cluster (#324)
1 parent: 2fd62a4

5 files changed (+62, -10 lines)


tests/common/config_test.py (9 additions, 0 deletions)

```diff
@@ -82,6 +82,15 @@ def test_config_flatten(self):
             self.assertIsInstance(key, str)
             self.assertNotIsInstance(value, dict)
 
+    def test_update_config_from_ray_cluster(self):
+        config = get_template_config()
+        config.cluster.node_num = None
+        config.cluster.gpu_per_node = None
+
+        config._update_config_from_ray_cluster()
+        self.assertEqual(config.cluster.node_num, 2)
+        self.assertEqual(config.cluster.gpu_per_node, 2)
+
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
             shutil.rmtree(CHECKPOINT_ROOT_DIR)
```
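The assertions above imply the test environment exposes exactly two alive Ray nodes with two GPUs each. A minimal sketch of how such a fixture could be simulated in-process with `ray.cluster_utils.Cluster` (Ray's resources are logical, so no physical GPUs are needed); this harness is an assumption for illustration, not the repo's actual CI setup:

```python
# Hypothetical fixture: simulate the 2-node x 2-GPU cluster the test expects.
import ray
from ray.cluster_utils import Cluster

cluster = Cluster(
    initialize_head=True,
    head_node_args={"num_cpus": 2, "num_gpus": 2},  # node 1
)
cluster.add_node(num_cpus=2, num_gpus=2)  # node 2
ray.init(address=cluster.address)

alive = [n for n in ray.nodes() if n["Alive"]]
assert len(alive) == 2  # matches the node_num the test asserts

ray.shutdown()
cluster.shutdown()
```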

tests/explorer/explorer_test.py (1 addition, 1 deletion)

```diff
@@ -131,7 +131,7 @@ def test_explorer(self):
 
 def run_serve(config):
     config.check_and_update()
-    run_stage(config, "auto")
+    run_stage(config)
 
 
 def run_agent(base_url, model_path: str):
```

trinity/cli/launcher.py (5 additions, 7 deletions)

```diff
@@ -142,9 +142,9 @@ def both(config: Config) -> None:
     }
 
 
-def run_stage(config: Config, ray_address: str) -> None:
+def run_stage(config: Config) -> None:
     ray.init(
-        address=ray_address,
+        address=config.cluster.ray_address,
         ignore_reinit_error=True,
         namespace=config.ray_namespace,
         runtime_env={"env_vars": config.get_envs()},
@@ -168,11 +168,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
     load_plugins()
     config = load_config(config_path)
 
-    ray_address = "auto"
-
     if dlc:
         cluster_namespace = f"{config.project}-{config.name}"
-        ray_address = setup_ray_cluster(namespace=cluster_namespace)
+        config.cluster.ray_address = setup_ray_cluster(namespace=cluster_namespace)
 
     if not is_running():
         raise RuntimeError("Ray is not running, please start it by `ray start --head`.")
@@ -203,7 +201,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
                 if prev_stage_checkpoint is not None:
                     stage_config.model.model_path = prev_stage_checkpoint
                 stage_config.check_and_update()
-                run_stage(stage_config, ray_address=ray_address)
+                run_stage(stage_config)
                 logger.info(
                     "===========================================================\n"
                     f"> Stage {i + 1}/{len(config.stages)} finished.\n"
@@ -212,7 +210,7 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None):
                 prev_stage_checkpoint = get_latest_hf_checkpoint_path(stage_config)
         else:
             config.check_and_update()
-            run_stage(config, ray_address=ray_address)
+            run_stage(config)
 
     finally:
         if dlc:
```
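The net effect is that `run_stage` now has a single source of truth for the Ray address. A sketch of the resulting call path, assuming `load_config` and `run_stage` as shown in this diff (the YAML path and the Ray client address are hypothetical):

```python
# Hypothetical usage of the simplified entry point after this change.
config = load_config("my_experiment.yaml")  # hypothetical config file

# `cluster.ray_address` defaults to "auto", so a local `ray start --head`
# cluster needs no extra setting; remote runs can override it, and in DLC
# mode the launcher overwrites it with setup_ray_cluster()'s address.
config.cluster.ray_address = "ray://10.0.0.1:10001"  # hypothetical override

config.check_and_update()  # may auto-detect node_num / gpu_per_node (see below)
run_stage(config)          # ray.init(address=config.cluster.ray_address, ...)
```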

trinity/common/config.py (45 additions, 2 deletions)

```diff
@@ -9,6 +9,7 @@
 from enum import Enum
 from typing import Any, Dict, List, Optional
 
+import ray
 from omegaconf import OmegaConf
 
 from trinity.common.constants import (
@@ -360,8 +361,9 @@ class AlgorithmConfig:
 class ClusterConfig:
     """Config for the cluster."""
 
-    node_num: int = 1
-    gpu_per_node: int = 8
+    ray_address: str = "auto"
+    node_num: Optional[int] = None
+    gpu_per_node: Optional[int] = None
 
 
 @Experimental
@@ -611,6 +613,44 @@ def _check_deprecated(self) -> None:
             "`explorer.runner_num` is deprecated, please use `explorer.runner_per_model` instead."
         )
 
+    def _update_config_from_ray_cluster(self) -> None:
+        """Update config if `node_num` or `gpu_per_node` is not set."""
+        if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None:
+            return
+
+        # init ray cluster to detect node_num and gpu_per_node
+        was_initialized = ray.is_initialized()
+        if not was_initialized:
+            ray.init(
+                address=self.cluster.ray_address,
+                ignore_reinit_error=True,
+                namespace=self.ray_namespace,
+            )
+
+        alive_nodes = [n for n in ray.nodes() if n["alive"]]
+        if not alive_nodes:
+            raise RuntimeError("Could not find any alive nodes in the Ray cluster.")
+
+        # set node_num
+        if self.cluster.node_num is None:
+            self.cluster.node_num = len(alive_nodes)
+            logger.info(f"Auto-detected and set node_num: {self.cluster.node_num}")
+
+        # set gpu_per_node
+        if self.cluster.gpu_per_node is None:
+            gpu_per_node = 0
+            for node in alive_nodes:
+                node_gpus = node.get("Resources", {}).get("GPU")
+                if node_gpus and node_gpus > 0:
+                    gpu_per_node = int(node_gpus)
+                    break
+
+            self.cluster.gpu_per_node = gpu_per_node
+            logger.info(f"Auto-detected and set gpu_per_node: {self.cluster.gpu_per_node}")
+
+        if not was_initialized:
+            ray.shutdown()
+
     def _check_interval(self) -> None:
         assert self.synchronizer.sync_interval > 0
 
@@ -901,6 +941,9 @@ def check_and_update(self) -> Config:  # noqa: C901
         if self.ray_namespace is None or len(self.ray_namespace) == 0:
             self.ray_namespace = f"{self.project}/{self.name}"
 
+        # check cluster information
+        self._update_config_from_ray_cluster()
+
         # check algorithm
         self._check_algorithm()
```
trinity/common/verl_config.py (2 additions, 0 deletions)

```diff
@@ -357,6 +357,8 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
         else:
             rollout_gpu_num = 0
 
+        assert config.cluster.node_num is not None
+        assert config.cluster.gpu_per_node is not None
         if config.cluster.node_num == 1:
             # for single node scenarios, rollout and training are on the same node
             self.trainer.nnodes = config.cluster.node_num
```
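These asserts exist because `node_num` and `gpu_per_node` became `Optional[int]`: after `check_and_update()` they are always populated, but a type checker cannot see that across call boundaries. A minimal illustration of the narrowing pattern (not repo code):

```python
from typing import Optional

def total_gpus(node_num: Optional[int], gpu_per_node: Optional[int]) -> int:
    # The asserts narrow Optional[int] -> int; without them, mypy rejects
    # the multiplication below (unsupported operand type for Optional[int]),
    # and at runtime they fail fast if detection never ran.
    assert node_num is not None
    assert gpu_per_node is not None
    return node_num * gpu_per_node
```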
