Skip to content

Commit 911abd4

Browse files
committed
updates
1 parent f04ac56 commit 911abd4

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

examples/monarch/train_distributed.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def start_lighthouse(self) -> str:
101101
from torchft.coordination import LighthouseServer
102102

103103
self.lighthouse = LighthouseServer(
104-
bind="[::]:0", min_replicas=1, join_timeout_ms=10000
104+
bind="[::]:0", min_replicas=1, join_timeout_ms=60000
105105
)
106106
return self.lighthouse.address()
107107

@@ -184,9 +184,7 @@ async def start_replica(self) -> None:
184184
self.spec.hosts_per_replica,
185185
self.spec.gpus_per_node,
186186
)
187-
await trainers_proc_mesh.logging_option(
188-
stream_to_client=True, aggregate_window_sec=None
189-
)
187+
await trainers_proc_mesh.logging_option(stream_to_client=True)
190188
await setup_env_for_distributed(trainers_proc_mesh)
191189

192190
training_actors = trainers_proc_mesh.spawn(
@@ -228,9 +226,9 @@ async def inject_failure(self, failure_type: Failure):
228226
# delay before re-creating proc mesh on existing job. change as needed.
229227
PROC_ATTEMPT_DELAY = 0
230228
# proc attempts before getting a new scheduler allocation. change as needed.
231-
PROC_ATTEMPTS = 3
229+
PROC_ATTEMPTS = 4
232230
# attempts before failing training on replica. change as needed.
233-
MAX_ATTEMPT = PROC_ATTEMPTS * 3
231+
MAX_ATTEMPT = PROC_ATTEMPTS * 4
234232

235233

236234
class OrchestrationManager:
@@ -274,9 +272,7 @@ async def start_lighthouse(self) -> None:
274272
else:
275273
self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})
276274

277-
await self.lighthouse_mesh.logging_option(
278-
stream_to_client=True, aggregate_window_sec=None
279-
)
275+
await self.lighthouse_mesh.logging_option(stream_to_client=True)
280276
self.lighthouse_actor = self.lighthouse_mesh.spawn(
281277
"lighthouse_actor", LighthouseActor
282278
)
@@ -337,8 +333,8 @@ async def _teardown(self, replica_id: int) -> None:
337333
try:
338334
replica = self.replicas[replica_id]
339335
await replica.proc_mesh.stop()
340-
del replica.proc_mesh
341336
del self.replicas[replica_id]
337+
del replica.proc_mesh
342338
except Exception as e:
343339
logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")
344340

@@ -418,12 +414,17 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
418414
"--fault_tolerance.enable",
419415
"--fault_tolerance.group_size",
420416
str(args.replica_count),
417+
"--fault_tolerance.process_group",
418+
"nccl",
419+
"--fault_tolerance.process_group_timeout_ms",
420+
"60000",
421+
421422
"--parallelism.data_parallel_shard_degree",
422423
str(data_parallel_shard_degree),
423424
"--activation_checkpoint.mode",
424425
"full",
425426
"--comm.train_timeout_seconds",
426-
"60",
427+
"300",
427428
"--training.steps",
428429
str(args.training_steps),
429430
"--training.dataset",

examples/monarch/utils/failure.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,12 @@ def kill_slurm(scheduler: "MonarchSlurm") -> None:
9191
scheduler.kill_job(selected)
9292

9393
@staticmethod
94-
async def execute_failures(replicas: Dict[int, "Replica"], scheduler: "MonarchSlurm"):
95-
startup_wait = 30
96-
rest_time = 60
94+
async def execute_failures(
95+
replicas: Dict[int, "Replica"],
96+
scheduler: "MonarchSlurm",
97+
startup_wait: int = 120,
98+
rest_time: int = 120
99+
):
97100
logger.info(f"[FailureController] Starting failure injection in {startup_wait} seconds")
98101
await asyncio.sleep(startup_wait) # allow startups.
99102

@@ -102,8 +105,8 @@ async def execute_failures(replicas: Dict[int, "Replica"], scheduler: "MonarchSl
102105
try:
103106
running_replicas = list(replicas.values())
104107
# allow deadlocked replicas more time to recover
105-
if last_failure == Failure.DEADLOCK and last_replica in running_replicas:
106-
running_replicas.remove(last_replica)
108+
if last_failure == Failure.DEADLOCK and last_replica:
109+
running_replicas = [r for r in running_replicas if r.rid != last_replica.rid]
107110

108111
last_replica = random.choice(running_replicas)
109112
last_failure = random.choice(list(Failure))

0 commit comments

Comments
 (0)