Skip to content
This repository was archived by the owner on Oct 19, 2024. It is now read-only.

Commit 4617a01

Browse files
author
Jun Gong
authored
Allow users to specify a different cluster address when initializing Ray (#915)
* Allow users to specify a different cluster address when initializing Ray.
1 parent 1ddb2dc commit 4617a01

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

alpa/api.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323

2424

2525
def init(cluster: str = "ray",
26+
cluster_address: Optional[str] = None,
2627
num_nodes: Optional[int] = None,
2728
num_devices_per_node: Optional[int] = None,
2829
namespace: Optional[str] = "alpa_default_space"):
@@ -40,6 +41,12 @@ def init(cluster: str = "ray",
4041
Possible choices: {"local", "ray"}.
4142
"local" means using all local devices on a single node.
4243
"ray" means using all devices in a ray cluster.
44+
cluster_address: Address of the distributed cluster.
45+
If cluster is "ray", this parameter can be used to specify a different
46+
address that will be used to initialize the ray cluster.
47+
E.g., "ray://123.45.67.89:10001". If not specified, "auto" will be
48+
used instead.
49+
Ignored if cluster is "local".
4350
num_nodes: The number of nodes.
4451
num_devices_per_node: The number of devices per node.
4552
"""
@@ -49,7 +56,8 @@ def init(cluster: str = "ray",
4956
return
5057
is_initialized = True
5158

52-
init_global_cluster(cluster, num_nodes, num_devices_per_node, namespace)
59+
init_global_cluster(cluster, cluster_address, num_nodes,
60+
num_devices_per_node, namespace)
5361

5462

5563
def shutdown():

alpa/device_mesh.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2304,6 +2304,7 @@ def profile_all(self, *args, **kwargs):
23042304

23052305

23062306
def init_global_cluster(cluster: str,
2307+
cluster_address: Optional[str] = None,
23072308
num_nodes: Optional[int] = None,
23082309
num_devices_per_node: Optional[int] = None,
23092310
namespace: Optional[str] = None):
@@ -2313,7 +2314,8 @@ def init_global_cluster(cluster: str,
23132314
global_physical_mesh = LocalPhysicalDeviceMesh()
23142315
elif cluster == "ray":
23152316
if not ray.is_initialized():
2316-
ray.init(address="auto",
2317+
ray_addr = cluster_address if cluster_address else "auto"
2318+
ray.init(address=ray_addr,
23172319
ignore_reinit_error=True,
23182320
namespace=namespace)
23192321
update_jax_platform("cpu")

0 commit comments

Comments
 (0)