@@ -102,6 +102,7 @@ class HvConfig(BaseConfig):
102
102
skip_load_from_peers : bool = False
103
103
world_rank : int
104
104
galaxy_size : int
105
+ fail_rank_drop : bool = False # fail if we lose a diloco worker
105
106
106
107
@model_validator (mode = "before" )
107
108
def cast_str_to_list (cls , values : dict [str , Any ]) -> dict [str , Any ]:
@@ -369,6 +370,9 @@ def scheduler_fn(opt):
369
370
370
371
loss_batch = 0
371
372
373
+ if world_messenger_hv :
374
+ max_num_peers = 0
375
+
372
376
for step , batch in enumerate (iterable = train_dataloader , start = start_step * gradient_accumulation_steps ):
373
377
real_step = (step + 1 ) // gradient_accumulation_steps
374
378
is_accumulating = bool ((step + 1 ) % gradient_accumulation_steps )
@@ -448,6 +452,9 @@ def scheduler_fn(opt):
448
452
if world_messenger_hv :
449
453
outer_lr = [group ["lr" ] for group in optimizer .state_averager .optimizer .param_groups ][0 ]
450
454
num_peers = optimizer .tracker .global_progress .num_peers
455
+
456
+ max_num_peers = max (max_num_peers , num_peers )
457
+
451
458
if num_peers == 0 :
452
459
num_peers = 1
453
460
@@ -457,6 +464,13 @@ def scheduler_fn(opt):
457
464
if logging_activations_steps :
458
465
metrics .update (activation_monitor .log_activations )
459
466
467
+ if world_messenger_hv and num_peers < max_num_peers :
468
+ log (message = f"Lost a diloco worker, num_peers: { num_peers } , galaxy_size: { config .hv .galaxy_size } " )
469
+ if config .hv .fail_rank_drop :
470
+ raise ValueError (
471
+ f"Lost a diloco worker, num_peers: { num_peers } , galaxy_size: { config .hv .galaxy_size } "
472
+ )
473
+
460
474
current_time = time .time ()
461
475
462
476
wandb .log (metrics )
0 commit comments