Skip to content

Commit 5ce3aa2

Browse files
authored
fix: cluster multiple ray instance support (RLinf#376)
Signed-off-by: Hao Lin <linhaomails@gmail.com>
1 parent 9a71046 commit 5ce3aa2

File tree

14 files changed: 70 additions and 31 deletions.

examples/coding_online_rl/main_coding_online_rl.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def main(cfg) -> None:
4040
cfg = validate_cfg(cfg)
4141
print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2))
4242

43-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
43+
cluster = Cluster(cluster_cfg=cfg.cluster)
4444
component_placement = ModelParallelComponentPlacement(cfg, cluster)
4545

4646
singleton_placement_strategy = PackedPlacementStrategy(

examples/coding_online_rl/main_coding_rl_llm_judge.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ def main(cfg) -> None:
4040
cfg = validate_cfg(cfg)
4141
print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2))
4242

43-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
43+
cluster = Cluster(cluster_cfg=cfg.cluster)
4444
component_placement = ModelParallelComponentPlacement(cfg, cluster)
4545

4646
rollout_worker_cls = get_rollout_backend_worker(cfg, component_placement)

examples/embodiment/eval_embodied_agent.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ def main(cfg) -> None:
3232
cfg = validate_cfg(cfg)
3333
cfg.runner.only_eval = True
3434

35-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
35+
cluster = Cluster(cluster_cfg=cfg.cluster)
3636
component_placement = HybridComponentPlacement(cfg, cluster)
3737

3838
# Create rollout worker group

examples/embodiment/train_embodied_agent.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def main(cfg) -> None:
3636
cfg = validate_cfg(cfg)
3737
print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2))
3838

39-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
39+
cluster = Cluster(cluster_cfg=cfg.cluster)
4040
component_placement = HybridComponentPlacement(cfg, cluster)
4141

4242
# Create actor worker group

examples/multiturn_demo/main_mcp_with_session.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def main(cfg) -> None:
4343
cfg = validate_cfg(cfg)
4444
print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2))
4545

46-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
46+
cluster = Cluster(cluster_cfg=cfg.cluster)
4747
component_placement = ModelParallelComponentPlacement(cfg, cluster)
4848

4949
# Generator group

examples/multiturn_demo/main_tool.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def main(cfg) -> None:
4343
cfg = validate_cfg(cfg)
4444
print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2))
4545

46-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
46+
cluster = Cluster(cluster_cfg=cfg.cluster)
4747
component_placement = ModelParallelComponentPlacement(cfg, cluster)
4848

4949
# Generator group

examples/reasoning/main_grpo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ def main(cfg) -> None:
4141
cfg = validate_cfg(cfg)
4242
print(json.dumps(OmegaConf.to_container(cfg, resolve=True), indent=2))
4343

44-
cluster = Cluster(num_nodes=cfg.cluster.num_nodes)
44+
cluster = Cluster(cluster_cfg=cfg.cluster)
4545
component_placement = ModelParallelComponentPlacement(cfg, cluster)
4646

4747
rollout_worker_cls = get_rollout_backend_worker(cfg, component_placement)

rlinf/scheduler/cluster/cluster.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,9 @@ class Cluster:
7777
ClusterEnvVar.COMM_NET_DEVICES: None,
7878
}
7979

80+
class NamespaceConflictError(Exception):
81+
"""Raised when there is a namespace conflict in Ray initialization."""
82+
8083
@classmethod
8184
def find_free_port(cls):
8285
"""Find a free port on the node."""
@@ -109,16 +112,26 @@ def __init__(
109112
"""
110113
if self._has_initialized:
111114
return
115+
self._setup_logger()
112116
if num_nodes is not None or cluster_cfg is not None:
113117
self._ray_instance_count = 0
114-
self._init_and_launch_managers(num_nodes, cluster_cfg)
118+
while True:
119+
try:
120+
self._init_and_launch_managers(num_nodes, cluster_cfg)
121+
break
122+
except Cluster.NamespaceConflictError:
123+
# Switch the namespace when multiple ray instances are created in the same node
124+
self._ray_instance_count += 1
125+
self._logger.info(
126+
f"Ray namespace conflict detected. Retrying to initialize Cluster with a new namespace (attempt {self._ray_instance_count})."
127+
)
128+
Cluster.NAMESPACE = f"{Cluster.SYS_NAME}_{self._ray_instance_count}"
129+
continue
115130
else:
116131
self._init_from_existing_managers()
117132
self._has_initialized = True
118133

119-
def _init_and_launch_managers(
120-
self, num_nodes: int, cluster_cfg: Optional[DictConfig]
121-
):
134+
def _setup_logger(self):
122135
# Add logger
123136
self._logger = logging.getLogger(Cluster.SYS_NAME)
124137
self._logger.setLevel(Cluster.LOGGING_LEVEL)
@@ -133,6 +146,9 @@ def _init_and_launch_managers(
133146
handler.setFormatter(formatter)
134147
self._logger.addHandler(handler)
135148

149+
def _init_and_launch_managers(
150+
self, num_nodes: int, cluster_cfg: Optional[DictConfig]
151+
):
136152
if ray.is_initialized():
137153
if self._ray_instance_count > 0:
138154
# For reinit Ray to switch namespace
@@ -235,10 +251,7 @@ def _init_and_launch_managers(
235251
.remote()
236252
)
237253
except ValueError:
238-
# If the WorkerManager is already running, we need to switch the namespace
239-
self._ray_instance_count += 1
240-
Cluster.NAMESPACE = f"RLinf_{self._ray_instance_count}"
241-
return self._init_and_launch_managers(num_nodes)
254+
raise Cluster.NamespaceConflictError
242255

243256
def signal_handler(sig, frame):
244257
# Exit the main process if SIGUSR1 is received, which is sent by the worker group when an exception occurs.

rlinf/scheduler/cluster/config.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -273,10 +273,14 @@ def from_dict_cfg(cfg_dict: DictConfig) -> "ClusterConfig":
273273
Returns:
274274
ClusterConfig: The created ClusterConfig instance.
275275
"""
276-
dataclass_arg_check(
277-
ClusterConfig, cfg_dict, error_suffix="in cluster yaml config"
276+
_, _, valid_args = dataclass_arg_check(
277+
ClusterConfig,
278+
cfg_dict,
279+
no_check_unknown=True,
280+
error_suffix="in cluster yaml config",
278281
)
279-
return ClusterConfig(**cfg_dict)
282+
valid_cfg_dict = {key: cfg_dict[key] for key in valid_args if key in cfg_dict}
283+
return ClusterConfig(**valid_cfg_dict)
280284

281285
def get_node_labels_by_rank(self, node_rank: int) -> list[str]:
282286
"""Get the node group labels for a given node rank.

rlinf/scheduler/cluster/node.py

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -284,12 +284,15 @@ def __init__(self, cluster_num_nodes: int, cluster_cfg: Optional[ClusterConfig])
284284
num_nodes = len(node_infos)
285285
for node_info in node_infos:
286286
node_ray_id = node_info["NodeID"]
287-
probe = _RemoteNodeProbe.options(
288-
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
289-
node_id=node_ray_id, soft=False
290-
),
291-
name=f"NodeProbe_{node_ray_id}",
292-
).remote(node_info, num_nodes, cluster_cfg, sys.executable)
287+
try:
288+
probe = _RemoteNodeProbe.options(
289+
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
290+
node_id=node_ray_id, soft=False
291+
),
292+
name=f"NodeProbe_{node_ray_id}",
293+
).remote(node_info, num_nodes, cluster_cfg, sys.executable)
294+
except ValueError:
295+
raise Cluster.NamespaceConflictError
293296
self._probes.append(probe)
294297

295298
handles = []

0 commit comments

Comments
 (0)