Skip to content

Commit eaae47a

Browse files
committed
allow to fail when loosing a diloco workers
1 parent 3e1fa45 commit eaae47a

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

open_diloco/train_fsdp.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class HvConfig(BaseConfig):
102102
skip_load_from_peers: bool = False
103103
world_rank: int
104104
galaxy_size: int
105+
fail_rank_drop: bool = False # fail if we lose a diloco worker
105106

106107
@model_validator(mode="before")
107108
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -369,6 +370,9 @@ def scheduler_fn(opt):
369370

370371
loss_batch = 0
371372

373+
if world_messenger_hv:
374+
max_num_peers = 0
375+
372376
for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
373377
real_step = (step + 1) // gradient_accumulation_steps
374378
is_accumulating = bool((step + 1) % gradient_accumulation_steps)
@@ -448,6 +452,9 @@ def scheduler_fn(opt):
448452
if world_messenger_hv:
449453
outer_lr = [group["lr"] for group in optimizer.state_averager.optimizer.param_groups][0]
450454
num_peers = optimizer.tracker.global_progress.num_peers
455+
456+
max_num_peers = max(max_num_peers, num_peers)
457+
451458
if num_peers == 0:
452459
num_peers = 1
453460

@@ -457,6 +464,13 @@ def scheduler_fn(opt):
457464
if logging_activations_steps:
458465
metrics.update(activation_monitor.log_activations)
459466

467+
if world_messenger_hv and num_peers < max_num_peers:
468+
log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
469+
if config.hv.fail_rank_drop:
470+
raise ValueError(
471+
f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
472+
)
473+
460474
current_time = time.time()
461475

462476
wandb.log(metrics)

0 commit comments

Comments
 (0)