|
9 | 9 |
|
10 | 10 | logger = get_logger(__name__) |
11 | 11 |
|
| 12 | +CLUSTER_ACTOR_NAME = "cluster_status" |
| 13 | + |
| 14 | + |
| 15 | +@ray.remote |
| 16 | +class ClusterStatus: |
| 17 | + def __init__(self): |
| 18 | + self.finished = False |
| 19 | + |
| 20 | + def finish(self) -> None: |
| 21 | + self.finished = True |
| 22 | + |
| 23 | + def running(self) -> bool: |
| 24 | + return not self.finished |
| 25 | + |
12 | 26 |
|
13 | 27 | def get_dlc_env_vars() -> dict: |
14 | 28 | envs = { |
@@ -71,16 +85,36 @@ def setup_ray_cluster(namespace: str): |
71 | 85 | logger.error(f"ret.stdout: {ret.stdout!r}") |
72 | 86 | logger.error(f"ret.stderr: {ret.stderr!r}") |
73 | 87 | sys.exit(1) |
| 88 | + |
| 89 | + wait_for_ray_setup() |
| 90 | + ray.init( |
| 91 | + address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", |
| 92 | + namespace=namespace, |
| 93 | + ignore_reinit_error=True, |
| 94 | + ) |
74 | 95 | if is_master: |
75 | | - wait_for_ray_setup() |
76 | | - ray.init( |
77 | | - address=f"{env_vars['MASTER_ADDR']}:{env_vars['MASTER_PORT']}", |
78 | | - namespace=namespace, |
79 | | - ignore_reinit_error=True, |
80 | | - ) |
81 | 96 | # master wait for worker nodes to join |
82 | 97 | wait_for_ray_worker_nodes(env_vars["WORLD_SIZE"]) |
| 98 | + else: |
| 99 | + # woker wait on the cluster status actor |
| 100 | + cluster_status = ClusterStatus.options( |
| 101 | + name=CLUSTER_ACTOR_NAME, |
| 102 | + get_if_exists=True, |
| 103 | + ).remote() |
| 104 | + while True: |
| 105 | + if ray.get(cluster_status.running.remote()): |
| 106 | + time.sleep(5) |
| 107 | + else: |
| 108 | + break |
| 109 | + sys.exit(0) |
| 110 | + |
83 | 111 |
|
84 | | - if not is_master: |
85 | | - # woker just exit |
86 | | - sys.exit(0) |
| 112 | +def stop_ray_cluster(): |
| 113 | + """ |
| 114 | + Stop the ray cluster by sending a signal to the cluster status actor. |
| 115 | + """ |
| 116 | + cluster_status = ClusterStatus.options( |
| 117 | + name=CLUSTER_ACTOR_NAME, |
| 118 | + get_if_exists=True, |
| 119 | + ).remote() |
| 120 | + ray.get(cluster_status.finish.remote()) |
0 commit comments