Commit 452e79b

amirafzali authored and facebook-github-bot committed
move monarch to HostMeshV1
Summary:
1. We can build on SlurmJob to simplify the allocation logic.
2. Small modifications to ensure HostMeshV1 support.

Differential Revision: D84853200
Parent: b3be7ad · Commit: 452e79b
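At a glance, the commit swaps the hand-rolled TorchX/RemoteAllocator path for Monarch's SlurmJob and allocates processes through HostMeshV1. A minimal sketch of the new flow, assembled from the diff below (the mesh name, node count, and state-file path are illustrative; the SlurmJob, apply, state, and spawn_procs calls are used exactly as they appear in the new code):

from monarch.job import SlurmJob

# Describe and submit (or reuse) a Slurm job backing one named host mesh.
job = SlurmJob(
    meshes={"trainers": 2},  # mesh name -> number of nodes (illustrative)
    gpus_per_node=8,
    job_name="monarch-torchft-trainers",
)
job.apply()

# HostMeshV1: the job state exposes each requested mesh as an attribute,
# and processes are spawned directly on the HostMesh.
host_mesh = job.state(cached_path=".trainers/job_state.pkl").trainers
procs = host_mesh.spawn_procs({"gpus": 8})  # one process per GPU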

File tree: 2 files changed, +57 −80 lines


examples/monarch/train_distributed.py

Lines changed: 55 additions & 78 deletions
@@ -14,78 +14,52 @@
 from typing import Dict
 
 import torch
-from monarch._rust_bindings.monarch_hyperactor.alloc import AllocConstraints, AllocSpec
-from monarch._src.actor.allocator import RemoteAllocator, TorchXRemoteAllocInitializer
-from monarch.actor import Actor, current_rank, endpoint, ProcMesh, this_host
-from monarch.tools import commands
-from monarch.tools.components import hyperactor
-from monarch.tools.config import Config
+from monarch.actor import Actor, current_rank, endpoint, HostMesh, ProcMesh, this_host
+from monarch.job import SlurmJob
 from monarch.utils import setup_env_for_distributed
 from torchtitan.config import ConfigManager, JobConfig
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.train import Trainer
 from utils.failure import Failure, FailureActor, FailureController
 
 
-# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
+# ==== Allocation boilerplate ====
 class MonarchSlurm:
-    # Cluster Configuration - update these values for your specific cluster
-    machine: str = "gpu.xlarge"
-    machine_memory: int = 2062607
     job_name_prefix: str = "monarch-torchft"
 
     def __init__(self):
-        self.job_handles: Dict[str, str] = {}
+        self.job_handles: Dict[str, SlurmJob] = {}
         atexit.register(self.kill_jobs)
 
-    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
-        mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        # to enable relative import of utils on actors
-        current_dir = os.path.dirname(os.path.abspath(__file__))
-        env = {"PYTHONPATH": current_dir}
-
-        appdef = hyperactor.host_mesh(meshes=mesh, env=env)
-
-        for role in appdef.roles:
-            role.resource.memMB = MonarchSlurm.machine_memory
-
-        return Config(scheduler="slurm", appdef=appdef)
-
-    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = self.get_config(mesh_name, nodes_per_mesh)
-        job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
-        server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        self.job_handles[mesh_name] = server_spec.name
+    async def get_or_create_job(
+        self, mesh_name: str, nodes_per_mesh: int = 1, gpus_per_node: int = 8
+    ) -> None:
+        job = SlurmJob(
+            meshes={mesh_name: nodes_per_mesh},
+            gpus_per_node=gpus_per_node,
+            job_name=f"{self.job_name_prefix}-{mesh_name}",
+        )
+        job.apply()
+        self.job_handles[mesh_name] = job
 
     def kill_jobs(self):
         for mesh_name in self.job_handles.keys():
             self.kill_job(mesh_name)
 
     def kill_job(self, mesh_name: str):
         try:
-            job_handle = self.job_handles[mesh_name]
+            job = self.job_handles[mesh_name]
             logger.info(f"Destroying job for mesh {mesh_name}")
-            commands.kill(f"slurm:///{job_handle}")
+            job.kill()
         except Exception as e:
-            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
-
-    def proc_mesh(
-        self,
-        mesh_name: str,
-        num_hosts: int = 1,
-        num_gpus: int = 8,
-    ) -> ProcMesh:
-        allocator = RemoteAllocator(
-            world_id=MonarchSlurm.job_name_prefix,
-            initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{self.job_handles[mesh_name]}"
-            ),
-        )
-        alloc = allocator.allocate(
-            AllocSpec(AllocConstraints(), hosts=num_hosts, gpus=num_gpus)
-        )
+            logger.exception(f"Failed to destroy job for {mesh_name}: {e}")
 
-        return ProcMesh.from_alloc(alloc)
+    def proc_mesh(self, mesh_name: str, num_procs: int) -> ProcMesh:
+        job = self.job_handles[mesh_name]
+        cached_path = f".{mesh_name}/job_state.pkl"
+        mesh: HostMesh = getattr(job.state(cached_path=cached_path), mesh_name)
+        proc_mesh = mesh.spawn_procs({"gpus": num_procs})
+        return proc_mesh
 
 
 # ==== allocation boilerplate ====
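Taken together, the reworked MonarchSlurm is driven like this (hypothetical caller code, not part of the commit; method names and arguments match the definitions above, the mesh name and sizes are illustrative, and get_or_create_job must be awaited from async code):

scheduler = MonarchSlurm()

# Submit (or reuse) a Slurm job backing the "replica_0" host mesh.
await scheduler.get_or_create_job("replica_0", nodes_per_mesh=1, gpus_per_node=8)

# Resolve the cached job state to a HostMesh and spawn one process per GPU.
trainers_proc_mesh = scheduler.proc_mesh("replica_0", num_procs=8)

# Jobs left in job_handles are reaped at interpreter exit via the
# atexit.register(self.kill_jobs) hook in __init__.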
@@ -177,34 +151,31 @@ async def start_replica(self) -> None:
         init_logger()
         logger.info(f"{self.uid} Spawning trainers")
 
-        trainers_proc_mesh: ProcMesh | None = None
         try:
             trainers_proc_mesh = self.scheduler.proc_mesh(
                 f"replica_{self.replica_id}",
-                self.spec.hosts_per_replica,
-                self.spec.gpus_per_node,
-            )
-            await trainers_proc_mesh.logging_option(stream_to_client=True)
-            await setup_env_for_distributed(trainers_proc_mesh)
-
-            training_actors = trainers_proc_mesh.spawn(
-                "training_actors",
-                TrainingActor,
-                self.spec.job_config,
-                self.replica_id,
+                num_procs=self.spec.gpus_per_node,
             )
 
-            self.failure_actors = trainers_proc_mesh.spawn(
-                "failure_actors", FailureActor
-            )
+            async with trainers_proc_mesh:
+                await trainers_proc_mesh.logging_option(stream_to_client=True)
+                await setup_env_for_distributed(trainers_proc_mesh)
 
-            logger.info(f"{self.uid} Starting trainers")
-            await training_actors.start_training.call(self.spec.lighthouse_address)
-            await trainers_proc_mesh.stop()
-        except Exception as e:
-            if trainers_proc_mesh:
-                await trainers_proc_mesh.stop()
-            raise e
+                training_actors = trainers_proc_mesh.spawn(
+                    "training_actors",
+                    TrainingActor,
+                    self.spec.job_config,
+                    self.replica_id,
+                )
+
+                self.failure_actors = trainers_proc_mesh.spawn(
+                    "failure_actors", FailureActor
+                )
+
+                logger.info(f"{self.uid} Starting trainers")
+                await training_actors.start_training.call(self.spec.lighthouse_address)
+        except Exception:
+            raise
 
     @endpoint
     async def inject_failure(self, failure_type: Failure):
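The substantive change in start_replica is that the ProcMesh is now entered as an async context manager, so teardown happens in __aexit__ on both the success and error paths, replacing the old stop()-in-two-places bookkeeping. Schematically (a sketch assuming ProcMesh supports async with, which the new code relies on):

trainers_proc_mesh = self.scheduler.proc_mesh(
    f"replica_{self.replica_id}", num_procs=self.spec.gpus_per_node
)
async with trainers_proc_mesh:
    # configure logging and env, spawn actors, run training ...
    await training_actors.start_training.call(self.spec.lighthouse_address)
# the mesh is stopped here whether the body returned normally or raised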
@@ -216,8 +187,7 @@ async def inject_failure(self, failure_type: Failure):
 
             await self.failure_actors.fail.choose(failure_type)
         except Exception as e:
-            error_msg = f"{self.uid} Injected failure: {e}"
-            logger.error(error_msg)
+            logger.exception(f"{self.uid} Injected failure: {e}")
         else:
             error_msg = f"{self.uid} No failure actors available"
             logger.error(error_msg)
@@ -268,7 +238,7 @@ async def start_training(self) -> None:
     async def start_lighthouse(self) -> None:
         if self.spec.remote_lighthouse:
             await self.scheduler.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_procs=1)
         else:
             self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})
 
@@ -287,7 +257,7 @@ async def stop_lighthouse(self) -> None:
             await self.lighthouse_mesh.stop()
             logger.info("[Controller] Lighthouse stopped")
         except Exception as e:
-            logger.warning(f"[Controller] Failed to stop lighthouse: {e}")
+            logger.exception(f"[Controller] Failed to stop lighthouse: {e}")
 
     async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
         if attempt_number >= MAX_ATTEMPT:
@@ -300,7 +270,7 @@ async def _run_replica(self, replica_id: int, attempt_number: int) -> None:
             await self._teardown(replica_id)
         except Exception as e:
             await self._teardown(replica_id)
-            logger.info(f"[Controller] replica {replica_id} failed: {e}")
+            logger.exception(f"[Controller] replica {replica_id} failed: {e}")
             await self._run_replica(replica_id, attempt_number + 1)
 
     async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
@@ -332,11 +302,18 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
     async def _teardown(self, replica_id: int) -> None:
         try:
             replica = self.replicas[replica_id]
-            await replica.proc_mesh.stop()
+            try:
+                await replica.proc_mesh.stop()
+            except Exception as e:
+                logger.exception(
+                    f"[Controller] Failed to stop replica {replica_id}, it may already be stopped. {e}"
+                )
             del self.replicas[replica_id]
             del replica.proc_mesh
         except Exception as e:
-            logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")
+            logger.exception(
+                f"[Controller] Failed to teardown replica {replica_id}: {e}"
+            )
 
 
 # === CLI / CONFIG === #

examples/monarch/utils/failure.py

Lines changed: 2 additions & 2 deletions
@@ -126,11 +126,11 @@ async def execute_failures(
                     f"[FailureController] Failure injection ({last_failure}) sent to replica {last_replica.rid}"
                 )
             except Exception as e:
-                logger.info(
+                logger.exception(
                     f"[FailureController] Failed to inject failure in replica {last_replica.rid}: {e}"
                 )
             await asyncio.sleep(rest_time)
     except Exception as e:
-        logger.info(
+        logger.exception(
             f"[FailureController] Something went wrong while injecting failure: {e}"
         )
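A note on the recurring log-call change in this commit (logger.info / logger.warning / logger.error inside except blocks → logger.exception): logging.Logger.exception logs at ERROR level and automatically appends the active exception's traceback, so failures keep their stack traces; it is only meaningful inside an exception handler. A standalone illustration:

import logging

logger = logging.getLogger(__name__)

try:
    1 / 0
except Exception as e:
    # Emits an ERROR record plus the traceback of the active exception.
    logger.exception(f"Failure injection went wrong: {e}")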
