@@ -92,6 +92,8 @@ class HvConfig(BaseConfig):
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
     warmup_outerstep: int = 10
+    outer_lr_min: float = 0.3
+    outer_scheduler: bool = False

     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
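The two new fields give the outer (DiLoCo) optimizer its own schedule knobs: `outer_scheduler` toggles the cosine schedule added further down, and `outer_lr_min` is presumably the floor that schedule should decay to (it maps naturally onto `min_lr_rate`, although the diff's `get_lr_outer` leaves that argument at its 0.0 default). A minimal sketch of the defaults, using a plain pydantic model as a stand-in for `BaseConfig`; the stand-in class and the example values are assumptions, not repo code:

from pydantic import BaseModel

class HvConfigSketch(BaseModel):
    # subset of HvConfig fields from the diff, with the two new flags
    galaxy_size: int
    warmup_outerstep: int = 10
    outer_lr_min: float = 0.3
    outer_scheduler: bool = False

cfg = HvConfigSketch(galaxy_size=4, outer_scheduler=True)
print(cfg.outer_scheduler, cfg.outer_lr_min)  # True 0.3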
@@ -179,17 +181,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    num_inner_steps: int,
-    warmup_outerstep: int | None,
     num_cycles: float,
     min_lr_rate: float = 0.0,
 ):
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))

-    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
-        return 0
-
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
     factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
     factor = factor * (1 - min_lr_rate) + min_lr_rate
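With the per-outer-step warmup gate removed, the inner-LR lambda is a plain warmup-then-cosine factor. A quick standalone check of its shape, restating the body shown in the hunk; the step counts (100 warmup, 1000 total) are illustrative numbers, not repo defaults:

import math
from functools import partial

def lr_lambda(current_step, *, num_warmup_steps, num_training_steps, num_cycles, min_lr_rate=0.0):
    # restatement of _get_cosine_schedule_with_warmup_lr_lambda after this change
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    return factor * (1 - min_lr_rate) + min_lr_rate

f = partial(lr_lambda, num_warmup_steps=100, num_training_steps=1000, num_cycles=0.5)
print(f(0), f(50), f(100), f(550), f(1000))  # 0.0 0.5 1.0 0.5 0.0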
@@ -201,13 +198,36 @@ def get_cosine_schedule_with_warmup(optimizer, config: Config):
         _get_cosine_schedule_with_warmup_lr_lambda,
         num_warmup_steps=config.warmup_steps,
         num_training_steps=config.total_steps,
-        num_inner_steps=config.hv.local_steps,
-        warmup_outerstep=config.hv.warmup_outerstep,
         num_cycles=0.5,
     )
     return LambdaLR(optimizer, lambda_lr, -1)


+def _get_lr_outer(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return 1
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_lr_outer(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_lr_outer,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
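The new outer schedule holds the factor at 1 through warmup and then follows a cosine in 2π·progress. Since there is no `num_cycles` argument, that is a full cosine cycle: the factor reaches `min_lr_rate` at the midpoint of training and climbs back toward 1 by the final step. A standalone sketch with illustrative numbers; passing `min_lr_rate=0.3` mirrors the new `outer_lr_min` field, which the diff's `get_lr_outer` does not yet forward, so treat that wiring as an assumption:

import math

def outer_factor(step, *, num_warmup_steps=100, num_training_steps=1000, min_lr_rate=0.3):
    # mirrors _get_lr_outer from the diff, with min_lr_rate set to the assumed outer_lr_min
    if step < num_warmup_steps:
        return 1  # outer LR held at full strength while the inner LR warms up
    progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress))
    factor = factor * (1 - min_lr_rate) + min_lr_rate
    return max(0, factor)

# warmup, midpoint dip to the floor, rise back to 1 at the end of the full cycle
print(outer_factor(0), outer_factor(550), outer_factor(1000))  # 1 0.3 1.0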
@@ -298,6 +318,9 @@ def scheduler_fn(opt):
         config=config,
     )

+    def outer_scheduler_fn(opt):
+        return get_lr_outer(opt, config=config)
+
     if config.hv is not None:
         if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -323,6 +346,7 @@ def scheduler_fn(opt):
             outer_optimizer=outer_optimizer,
             inner_optimizer=inner_optimizer,
             scheduler=None,
+            outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
             params=model.parameters(),
             delay_optimizer_step=False,
             delay_grad_averaging=False,
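`DiLoCoOptimizer` is handed the factory (`outer_scheduler_fn`) rather than a ready-made scheduler, and only when `config.hv.outer_scheduler` is enabled, presumably so it can bind the schedule to the outer optimizer it manages internally. A minimal stand-alone sketch of that factory pattern; the SGD outer optimizer and all names here are illustrative, not the `DiLoCoOptimizer` internals:

import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(10))]
# DiLoCo-style outer optimizer: Nesterov SGD standing in for the real one
outer_optimizer = torch.optim.SGD(params, lr=0.7, momentum=0.9, nesterov=True)

def outer_scheduler_fn(opt):
    # stands in for get_lr_outer(opt, config=config); constant factor for brevity
    return LambdaLR(opt, lambda step: 1.0, -1)

outer_scheduler = outer_scheduler_fn(outer_optimizer)
for _ in range(3):              # pretend these are outer (synchronization) steps
    outer_optimizer.step()
    outer_scheduler.step()      # the outer LR follows the schedule across outer steps
print(outer_optimizer.param_groups[0]["lr"])  # 0.7, unchanged by the constant lambda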
@@ -438,6 +462,7 @@ def scheduler_fn(opt):
         scaler.update()

         scheduler.step()
+
         optimizer.zero_grad()

         if logging_activations_steps: