@@ -92,6 +92,8 @@ class HvConfig(BaseConfig):
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
     warmup_outerstep: int = 10
+    outer_lr_min: float = 0.3
+    outer_scheduler: bool = False
 
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
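
For reference, a minimal sketch of the two new knobs in isolation, using a plain pydantic model as a stand-in for `HvConfig` (the real class derives from `BaseConfig` and carries many more required fields); reading `outer_lr_min` as the floor of the outer schedule is an assumption, since this hunk only declares the field:

```python
from pydantic import BaseModel


# Illustrative stand-in for the patched HvConfig fields, not the real class.
class OuterKnobs(BaseModel):
    warmup_outerstep: int = 10
    outer_lr_min: float = 0.3      # assumed: floor of the outer LR schedule
    outer_scheduler: bool = False  # opt-in: also schedule the outer optimizer


knobs = OuterKnobs(outer_scheduler=True)
print(knobs.outer_scheduler, knobs.outer_lr_min)  # True 0.3
```
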
@@ -180,17 +182,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    num_inner_steps: int,
-    warmup_outerstep: int | None,
     num_cycles: float,
     min_lr_rate: float = 0.0,
 ):
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))
 
-    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
-        return 0
-
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
     factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
     factor = factor * (1 - min_lr_rate) + min_lr_rate
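
The deleted check held the inner LR at zero for the first `warmup_outerstep` steps after every outer synchronization (once past the global warmup). A short standalone sketch of which steps that gating zeroed, with illustrative values:

```python
# Illustrative values only; config.hv.local_steps / warmup_outerstep in the real run.
local_steps, warmup_outerstep = 500, 10

# Steps whose LR factor the old schedule forced to 0 (first few of each outer step).
held_at_zero = [s for s in range(2 * local_steps) if s % local_steps < warmup_outerstep]
print(held_at_zero[:3], held_at_zero[-3:])  # [0, 1, 2] [507, 508, 509]
```
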
@@ -202,13 +199,36 @@ def get_cosine_schedule_with_warmup(optimizer, config: Config):
         _get_cosine_schedule_with_warmup_lr_lambda,
         num_warmup_steps=config.warmup_steps,
         num_training_steps=config.total_steps,
-        num_inner_steps=config.hv.local_steps,
-        warmup_outerstep=config.hv.warmup_outerstep,
         num_cycles=0.5,
     )
     return LambdaLR(optimizer, lambda_lr, -1)
 
 
+def _get_lr_outer(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return 1
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_lr_outer(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_lr_outer,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
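
To make the shape of the new outer schedule concrete, here is a self-contained restatement of `_get_lr_outer` evaluated at a few steps. The `warmup=10, total=100` values are illustrative, and passing the new `outer_lr_min` default (0.3) as `min_lr_rate` is an assumption about how that field is meant to be wired; in the patch itself `get_lr_outer` leaves `min_lr_rate` at 0.0:

```python
import math


def outer_lr_factor(step: int, *, warmup: int, total: int, min_lr_rate: float = 0.0) -> float:
    # Hold the outer LR at its full value while the inner schedule warms up,
    # then cosine-decay it from 1.0 down to min_lr_rate over the remaining steps.
    if step < warmup:
        return 1.0
    progress = (step - warmup) / max(1, total - warmup)
    factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    return max(0.0, factor * (1 - min_lr_rate) + min_lr_rate)


# With the outer_lr_min default of 0.3 as the assumed floor:
print([round(outer_lr_factor(s, warmup=10, total=100, min_lr_rate=0.3), 3) for s in (0, 10, 55, 100)])
# -> [1.0, 1.0, 0.65, 0.3]
```
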
@@ -291,6 +311,9 @@ def scheduler_fn(opt):
             config=config,
         )
 
+    def outer_scheduler_fn(opt):
+        return get_lr_outer(opt, config=config)
+
     if config.hv is not None:
         if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -316,6 +339,7 @@ def scheduler_fn(opt):
             outer_optimizer=outer_optimizer,
             inner_optimizer=inner_optimizer,
             scheduler=None,
+            outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
             params=model.parameters(),
             delay_optimizer_step=False,
             delay_grad_averaging=False,
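
The new `outer_scheduler` argument follows the same factory pattern as `scheduler_fn`: the wrapper is handed a callable (optimizer in, scheduler out) and presumably builds the scheduler around the outer optimizer itself. A toy stand-in with plain SGD and `LambdaLR`, not the real `DiLoCoOptimizer`, whose internals are not shown in this diff:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

# Toy outer optimizer; in the training script this is the DiLoCo outer optimizer.
params = [torch.nn.Parameter(torch.zeros(1))]
outer_opt = torch.optim.SGD(params, lr=0.7)


def outer_scheduler_fn(opt):
    # Same role as the closure added in train(): optimizer in, scheduler out.
    return LambdaLR(opt, lambda step: max(0.3, 1.0 - step / 100), -1)


scheduler = outer_scheduler_fn(outer_opt)  # what the wrapper would do internally
outer_opt.step()
scheduler.step()
print(outer_opt.param_groups[0]["lr"])  # ~0.693 (= 0.7 * 0.99)
```
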
@@ -435,6 +459,7 @@ def scheduler_fn(opt):
             scaler.update()
 
             scheduler.step()
+
             optimizer.zero_grad()
 
             if config.hv is not None: