@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
91
91
world_rank : int
92
92
galaxy_size : int
93
93
fail_rank_drop : bool = False # fail if we lose a diloco worker
94
+ warmup_outerstep : int = 10
94
95
95
96
@model_validator (mode = "before" )
96
97
def cast_str_to_list (cls , values : dict [str , Any ]) -> dict [str , Any ]:
@@ -174,8 +175,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
174
175
175
176
176
177
def _get_cosine_schedule_with_warmup_lr_lambda (
177
- current_step : int , * , num_warmup_steps : int , num_training_steps : int , num_cycles : float , min_lr_rate : float = 0.0
178
+ current_step : int ,
179
+ * ,
180
+ num_warmup_steps : int ,
181
+ num_training_steps : int ,
182
+ num_inner_steps : int ,
183
+ warmup_outerstep : int | None ,
184
+ num_cycles : float ,
185
+ min_lr_rate : float = 0.0 ,
178
186
):
187
+ if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep :
188
+ return 0
189
+
179
190
if current_step < num_warmup_steps :
180
191
return float (current_step ) / float (max (1 , num_warmup_steps ))
181
192
progress = float (current_step - num_warmup_steps ) / float (max (1 , num_training_steps - num_warmup_steps ))
@@ -184,11 +195,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
184
195
return max (0 , factor )
185
196
186
197
187
def get_cosine_schedule_with_warmup(optimizer, config: Config):
    """Build a ``LambdaLR`` scheduler driven by the cosine-with-warmup lambda.

    The schedule parameters are read from ``config``: ``warmup_steps`` and
    ``total_steps`` from the top level, ``local_steps`` and
    ``warmup_outerstep`` from the optional ``config.hv`` section.

    Args:
        optimizer: Optimizer whose learning rate is scaled by the schedule.
        config: Training configuration. ``config.hv`` may be ``None``; in
            that case the zero-LR window that follows each outer step is
            disabled.

    Returns:
        A ``LambdaLR`` wrapping ``_get_cosine_schedule_with_warmup_lr_lambda``.
    """
    hv_config = config.hv
    if hv_config is None:
        # No hv section: passing warmup_outerstep=None makes the lambda skip
        # its modulo check entirely, so num_inner_steps is only a placeholder.
        num_inner_steps = 1
        warmup_outerstep = None
    else:
        # NOTE(review): the lambda returns 0 whenever
        # current_step % local_steps < warmup_outerstep, so a configuration
        # with local_steps <= warmup_outerstep keeps the LR at 0 forever.
        num_inner_steps = hv_config.local_steps
        warmup_outerstep = hv_config.warmup_outerstep

    lambda_lr = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_inner_steps=num_inner_steps,
        warmup_outerstep=warmup_outerstep,
        num_cycles=0.5,
    )
    return LambdaLR(optimizer, lambda_lr, -1)
@@ -281,9 +294,7 @@ def train(config: Config):
281
294
def scheduler_fn(opt):
    """Create the cosine-with-warmup scheduler for ``opt`` using the enclosing ``config``."""
    return get_cosine_schedule_with_warmup(opt, config=config)
288
299
289
300
if config .hv is not None :
0 commit comments