@@ -105,6 +105,7 @@ class HvConfig(BaseConfig):
105
105
skip_load_from_peers : bool = False
106
106
world_rank : int
107
107
galaxy_size : int
108
+ warmup_outerstep : int = 10
108
109
109
110
@model_validator (mode = "before" )
110
111
def cast_str_to_list (cls , values : dict [str , Any ]) -> dict [str , Any ]:
@@ -190,8 +191,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
190
191
191
192
192
193
def _get_cosine_schedule_with_warmup_lr_lambda (
193
- current_step : int , * , num_warmup_steps : int , num_training_steps : int , num_cycles : float , min_lr_rate : float = 0.0
194
+ current_step : int ,
195
+ * ,
196
+ num_warmup_steps : int ,
197
+ num_training_steps : int ,
198
+ num_inner_steps : int ,
199
+ warmup_outerstep : int | None ,
200
+ num_cycles : float ,
201
+ min_lr_rate : float = 0.0 ,
194
202
):
203
+ if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep :
204
+ return 0
205
+
195
206
if current_step < num_warmup_steps :
196
207
return float (current_step ) / float (max (1 , num_warmup_steps ))
197
208
progress = float (current_step - num_warmup_steps ) / float (max (1 , num_training_steps - num_warmup_steps ))
@@ -200,11 +211,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
200
211
return max (0 , factor )
201
212
202
213
203
- def get_cosine_schedule_with_warmup (optimizer , num_warmup_steps , num_training_steps , num_inner_steps ):
214
+ def get_cosine_schedule_with_warmup (optimizer , config : Config ):
204
215
lambda_lr = partial (
205
216
_get_cosine_schedule_with_warmup_lr_lambda ,
206
- num_warmup_steps = num_warmup_steps ,
207
- num_training_steps = num_training_steps ,
217
+ num_warmup_steps = config .warmup_steps ,
218
+ num_training_steps = config .total_steps ,
219
+ num_inner_steps = config .hv .local_steps ,
220
+ warmup_outerstep = config .hv .warmup_outerstep ,
208
221
num_cycles = 0.5 ,
209
222
)
210
223
return LambdaLR (optimizer , lambda_lr , - 1 )
@@ -301,9 +314,7 @@ def train(config: Config):
301
314
def scheduler_fn (opt ):
302
315
return get_cosine_schedule_with_warmup (
303
316
opt ,
304
- num_warmup_steps = config .warmup_steps ,
305
- num_training_steps = config .total_steps ,
306
- num_inner_steps = config .hv .local_steps ,
317
+ config = config ,
307
318
)
308
319
309
320
if config .hv is not None :
0 commit comments