@@ -263,50 +263,56 @@ def lr_scheduler(self, num_train_steps, override_lr=None):
263263 warmup_steps = _convert_frac_or_steps (self .warmup , cycle_steps )
264264 else :
265265 warmup_steps = _convert_frac_or_steps (self .rewarmup , cycle_steps )
266+ warmup_steps = min (warmup_steps , cycle_steps )
266267
267268 if warmup_steps != 0 :
268269 warmup = optax .linear_schedule (previous_end , learning_rate , warmup_steps )
269270 schedules .append (warmup )
270271 boundaries .append (start + warmup_steps )
271272
272- lr_decay_steps = (
273- _convert_frac_or_steps (self .decay , cycle_steps )
274- if self .decay is not None
275- else cycle_steps - warmup_steps
273+ max_decay_steps = max (cycle_steps - warmup_steps , 0 )
274+ requested_decay_steps = (
275+ _convert_frac_or_steps (self .decay , cycle_steps ) if self .decay is not None else max_decay_steps
276276 )
277+ lr_decay_steps = min (max (requested_decay_steps , 0 ), max_decay_steps )
277278 stable_steps = cycle_steps - warmup_steps - lr_decay_steps
278279
279- if stable_steps != 0 :
280+ if stable_steps > 0 :
280281 stable = optax .constant_schedule (learning_rate )
281282 schedules .append (stable )
282283 boundaries .append (start + warmup_steps + stable_steps )
283284
284- if isinstance (self .lr_schedule , str ):
285- match self .lr_schedule :
286- case "constant" :
287- schedule = optax .constant_schedule (learning_rate )
288- case "cosine" :
289- schedule = optax .cosine_decay_schedule (learning_rate , lr_decay_steps , self .min_lr_ratio )
290- case "linear" :
291- schedule = optax .linear_schedule (learning_rate , min_lr , lr_decay_steps )
292- case "inv_sqrt" :
293- schedule = _inv_sqrt_decay_schedule (learning_rate , min_lr , warmup_steps , 10000 )
294- case "inv" :
295- schedule = _inv_decay_schedule (learning_rate , min_lr , lr_decay_steps )
296- case _:
297- raise ValueError (f"Unknown lr_schedule: { self .lr_schedule } " )
298- elif isinstance (self .lr_schedule , LrSchedule ):
299- schedule = self .lr_schedule .build (
300- LrScheduleContext (
301- warmup_steps = warmup_steps ,
302- decay_steps = lr_decay_steps ,
303- learning_rate = learning_rate ,
304- min_lr_ratio = self .min_lr_ratio ,
305- min_lr = min_lr ,
285+ if lr_decay_steps > 0 :
286+ if isinstance (self .lr_schedule , str ):
287+ match self .lr_schedule :
288+ case "constant" :
289+ schedule = optax .constant_schedule (learning_rate )
290+ case "cosine" :
291+ schedule = optax .cosine_decay_schedule (learning_rate , lr_decay_steps , self .min_lr_ratio )
292+ case "linear" :
293+ schedule = optax .linear_schedule (learning_rate , min_lr , lr_decay_steps )
294+ case "inv_sqrt" :
295+ schedule = _inv_sqrt_decay_schedule (learning_rate , min_lr , warmup_steps , 10000 )
296+ case "inv" :
297+ schedule = _inv_decay_schedule (learning_rate , min_lr , lr_decay_steps )
298+ case _:
299+ raise ValueError (f"Unknown lr_schedule: { self .lr_schedule } " )
300+ elif isinstance (self .lr_schedule , LrSchedule ):
301+ schedule = self .lr_schedule .build (
302+ LrScheduleContext (
303+ warmup_steps = warmup_steps ,
304+ decay_steps = lr_decay_steps ,
305+ learning_rate = learning_rate ,
306+ min_lr_ratio = self .min_lr_ratio ,
307+ min_lr = min_lr ,
308+ )
309+ )
310+ else :
311+ raise ValueError (
312+ f"lr_schedule must be a string or an instance of LrSchedule, got { self .lr_schedule } "
306313 )
307- )
308314 else :
309- raise ValueError ( f"lr_schedule must be a string or an instance of LrSchedule, got { self . lr_schedule } " )
315+ schedule = optax . constant_schedule ( learning_rate )
310316
311317 previous_end = schedule (lr_decay_steps )
312318
0 commit comments