@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
91
91
world_rank : int
92
92
galaxy_size : int
93
93
fail_rank_drop : bool = False # fail if we lose a diloco worker
94
+ warmup_outerstep : int = 10
94
95
95
96
@model_validator (mode = "before" )
96
97
def cast_str_to_list (cls , values : dict [str , Any ]) -> dict [str , Any ]:
@@ -175,8 +176,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
175
176
176
177
177
178
def _get_cosine_schedule_with_warmup_lr_lambda (
178
- current_step : int , * , num_warmup_steps : int , num_training_steps : int , num_cycles : float , min_lr_rate : float = 0.0
179
+ current_step : int ,
180
+ * ,
181
+ num_warmup_steps : int ,
182
+ num_training_steps : int ,
183
+ num_inner_steps : int ,
184
+ warmup_outerstep : int | None ,
185
+ num_cycles : float ,
186
+ min_lr_rate : float = 0.0 ,
179
187
):
188
+ if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep :
189
+ return 0
190
+
180
191
if current_step < num_warmup_steps :
181
192
return float (current_step ) / float (max (1 , num_warmup_steps ))
182
193
progress = float (current_step - num_warmup_steps ) / float (max (1 , num_training_steps - num_warmup_steps ))
@@ -185,11 +196,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
185
196
return max (0 , factor )
186
197
187
198
188
- def get_cosine_schedule_with_warmup (optimizer , num_warmup_steps , num_training_steps , num_inner_steps ):
199
+ def get_cosine_schedule_with_warmup (optimizer , config : Config ):
189
200
lambda_lr = partial (
190
201
_get_cosine_schedule_with_warmup_lr_lambda ,
191
- num_warmup_steps = num_warmup_steps ,
192
- num_training_steps = num_training_steps ,
202
+ num_warmup_steps = config .warmup_steps ,
203
+ num_training_steps = config .total_steps ,
204
+ num_inner_steps = config .hv .local_steps ,
205
+ warmup_outerstep = config .hv .warmup_outerstep ,
193
206
num_cycles = 0.5 ,
194
207
)
195
208
return LambdaLR (optimizer , lambda_lr , - 1 )
@@ -274,9 +287,7 @@ def train(config: Config):
274
287
def scheduler_fn (opt ):
275
288
return get_cosine_schedule_with_warmup (
276
289
opt ,
277
- num_warmup_steps = config .warmup_steps ,
278
- num_training_steps = config .total_steps ,
279
- num_inner_steps = config .hv .local_steps ,
290
+ config = config ,
280
291
)
281
292
282
293
if config .hv is not None :
0 commit comments