|
9 | 9 | from enum import Enum |
10 | 10 | from typing import Any, Dict, List, Optional |
11 | 11 |
|
| 12 | +import ray |
12 | 13 | from omegaconf import OmegaConf |
13 | 14 |
|
14 | 15 | from trinity.common.constants import ( |
@@ -360,8 +361,9 @@ class AlgorithmConfig: |
class ClusterConfig:
    """Config for the cluster."""

    # Ray cluster address passed to `ray.init`; "auto" connects to an
    # already-running cluster (Ray's default address resolution).
    ray_address: str = "auto"
    # Number of nodes in the cluster. None means "auto-detect from the
    # live Ray cluster" (see `_update_config_from_ray_cluster`).
    node_num: Optional[int] = None
    # GPUs available per node. None means "auto-detect from the live
    # Ray cluster" (see `_update_config_from_ray_cluster`).
    gpu_per_node: Optional[int] = None
|
366 | 368 |
|
367 | 369 | @Experimental |
@@ -611,6 +613,44 @@ def _check_deprecated(self) -> None: |
611 | 613 | "`explorer.runner_num` is deprecated, please use `explorer.runner_per_model` instead." |
612 | 614 | ) |
613 | 615 |
|
def _update_config_from_ray_cluster(self) -> None:
    """Fill in `cluster.node_num` / `cluster.gpu_per_node` from the live Ray cluster.

    Ray is only queried when at least one of the two fields is unset
    (``None``). If this method has to initialize Ray itself, it is
    guaranteed to shut Ray down again before returning — even when
    detection fails (the previous version leaked an initialized Ray
    session whenever an exception was raised before ``ray.shutdown()``).

    Raises:
        RuntimeError: if the Ray cluster reports no alive nodes.
    """
    if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None:
        # Both values were explicitly configured; nothing to detect.
        return

    # Initialize Ray only if the caller has not already done so; remember
    # which case we are in so cleanup mirrors it.
    was_initialized = ray.is_initialized()
    if not was_initialized:
        ray.init(
            address=self.cluster.ray_address,
            ignore_reinit_error=True,
            namespace=self.ray_namespace,
        )
    try:
        alive_nodes = [n for n in ray.nodes() if n["alive"]]
        if not alive_nodes:
            raise RuntimeError("Could not find any alive nodes in the Ray cluster.")

        # set node_num
        if self.cluster.node_num is None:
            self.cluster.node_num = len(alive_nodes)
            logger.info(f"Auto-detected and set node_num: {self.cluster.node_num}")

        # set gpu_per_node
        if self.cluster.gpu_per_node is None:
            # Take the GPU count of the first node that reports any GPUs.
            # NOTE(review): this assumes a homogeneous cluster — nodes with
            # differing GPU counts are not distinguished here. Defaults to 0
            # when no node reports GPUs (CPU-only cluster).
            gpu_per_node = 0
            for node in alive_nodes:
                node_gpus = node.get("Resources", {}).get("GPU")
                if node_gpus and node_gpus > 0:
                    gpu_per_node = int(node_gpus)
                    break

            self.cluster.gpu_per_node = gpu_per_node
            logger.info(f"Auto-detected and set gpu_per_node: {self.cluster.gpu_per_node}")
    finally:
        # Tear down Ray only if we were the ones who brought it up.
        if not was_initialized:
            ray.shutdown()
614 | 654 | def _check_interval(self) -> None: |
615 | 655 | assert self.synchronizer.sync_interval > 0 |
616 | 656 |
|
@@ -901,6 +941,9 @@ def check_and_update(self) -> Config: # noqa: C901 |
901 | 941 | if self.ray_namespace is None or len(self.ray_namespace) == 0: |
902 | 942 | self.ray_namespace = f"{self.project}/{self.name}" |
903 | 943 |
|
| 944 | + # check cluster information |
| 945 | + self._update_config_from_ray_cluster() |
| 946 | + |
904 | 947 | # check algorithm |
905 | 948 | self._check_algorithm() |
906 | 949 |
|
|
0 commit comments