@@ -103,6 +103,7 @@ class HvConfig(BaseConfig):
103
103
skip_load_from_peers : bool = False
104
104
world_rank : int
105
105
galaxy_size : int
106
+ warmup_outerstep : int = 10
106
107
107
108
@model_validator (mode = "before" )
108
109
def cast_str_to_list (cls , values : dict [str , Any ]) -> dict [str , Any ]:
@@ -188,8 +189,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
188
189
189
190
190
191
def _get_cosine_schedule_with_warmup_lr_lambda (
191
- current_step : int , * , num_warmup_steps : int , num_training_steps : int , num_cycles : float , min_lr_rate : float = 0.0
192
+ current_step : int ,
193
+ * ,
194
+ num_warmup_steps : int ,
195
+ num_training_steps : int ,
196
+ num_inner_steps : int ,
197
+ warmup_outerstep : int | None ,
198
+ num_cycles : float ,
199
+ min_lr_rate : float = 0.0 ,
192
200
):
201
+ if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep :
202
+ return 0
203
+
193
204
if current_step < num_warmup_steps :
194
205
return float (current_step ) / float (max (1 , num_warmup_steps ))
195
206
progress = float (current_step - num_warmup_steps ) / float (max (1 , num_training_steps - num_warmup_steps ))
@@ -198,11 +209,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
198
209
return max (0 , factor )
199
210
200
211
201
def get_cosine_schedule_with_warmup(optimizer, config: Config):
    """Build a ``LambdaLR`` scheduler for ``optimizer`` from ``config``.

    The per-step multiplier is computed by
    ``_get_cosine_schedule_with_warmup_lr_lambda``, parameterised with
    ``config.warmup_steps`` and ``config.total_steps`` and a fixed
    ``num_cycles`` of 0.5.

    When ``config.hv`` is set, its ``local_steps`` and ``warmup_outerstep``
    are forwarded so the lambda can return 0 while
    ``current_step % local_steps < warmup_outerstep``.

    ``config.hv`` may be ``None`` (the training loop explicitly checks for
    that). In that case the gate is disabled by passing
    ``warmup_outerstep=None`` instead of dereferencing ``config.hv``, which
    would raise ``AttributeError``.

    Args:
        optimizer: Optimizer whose learning rate the scheduler drives.
        config: Training configuration providing ``warmup_steps``,
            ``total_steps`` and the optional ``hv`` section.

    Returns:
        A ``LambdaLR`` instance wrapping the configured multiplier.
    """
    hv = config.hv
    if hv is None:
        # No `hv` section: turn the modulo gate off. `num_inner_steps` is
        # only read inside that gate, so a non-zero placeholder is enough.
        num_inner_steps = 1
        warmup_outerstep = None
    else:
        num_inner_steps = hv.local_steps
        warmup_outerstep = hv.warmup_outerstep

    lambda_lr = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_inner_steps=num_inner_steps,
        warmup_outerstep=warmup_outerstep,
        num_cycles=0.5,
    )
    # -1 is LambdaLR's `last_epoch` argument, i.e. start from the beginning.
    return LambdaLR(optimizer, lambda_lr, -1)
@@ -299,9 +312,7 @@ def train(config: Config):
299
312
def scheduler_fn(opt):
    """Create the learning-rate scheduler for ``opt``.

    Closes over the ``config`` argument of the enclosing ``train`` function
    and hands it to ``get_cosine_schedule_with_warmup`` unchanged.
    """
    scheduler = get_cosine_schedule_with_warmup(opt, config=config)
    return scheduler
306
317
307
318
if config .hv is not None :
0 commit comments