-import math
 from typing import List
 
 import torch
@@ -15,8 +14,6 @@ class ScheduleFreeSGD(BaseOptimizer):
     :param lr: float. learning rate.
     :param momentum: float. momentum factor, must be between 0 and 1 exclusive.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
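# Note: explanatory sketch, not part of the patch above. A minimal illustration of the
# polynomial averaging described by `r` and `weight_lr_power`, following the formula used in
# step(): weight = step**r * lr_max**weight_lr_power, checkpoint = weight / weight_sum.
# The values below are made up for demonstration only.
def averaging_weight(step: int, r: float, lr_max: float, weight_lr_power: float) -> float:
    return (step ** r) * (lr_max ** weight_lr_power)

weight_sum = 0.0
for step in range(1, 5):
    weight = averaging_weight(step, r=0.0, lr_max=1.0, weight_lr_power=2.0)
    weight_sum += weight
    print(step, weight / weight_sum)  # with r=0 the checkpoint is 1/step, i.e. a uniform running average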
@@ -30,8 +27,6 @@ def __init__(
         lr: float = 1.0,
         momentum: float = 0.9,
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
         warmup_steps: int = 0,
@@ -47,8 +42,6 @@ def __init__(
             'lr': lr,
             'momentum': momentum,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'warmup_steps': warmup_steps,
@@ -114,7 +107,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             lr: float = group['lr'] * schedule
             lr_max = group['lr_max'] = max(lr, group['lr_max'])
 
-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight
 
             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
@@ -137,8 +130,8 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     grad=grad,
                     lr=lr,
                     weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
                 )
 
                 z = state['z']
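# Note: explanatory sketch, not part of the patch above. With `weight_decouple=False` and
# `fixed_decay=False` now hard-coded, apply_weight_decay is assumed to reduce to plain coupled
# L2 regularization, i.e. the penalty is folded into the gradient before the schedule-free
# update. `coupled_weight_decay` is a hypothetical stand-in, not the library's API.
import torch

def coupled_weight_decay(p: torch.Tensor, grad: torch.Tensor, weight_decay: float) -> torch.Tensor:
    # assumed equivalent of apply_weight_decay(..., weight_decouple=False, fixed_decay=False)
    return grad.add_(p, alpha=weight_decay) if weight_decay > 0.0 else grad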
@@ -158,8 +151,6 @@ class ScheduleFreeAdamW(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
@@ -174,8 +165,6 @@ def __init__(
         lr: float = 2.5e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
         warmup_steps: int = 0,
@@ -192,8 +181,6 @@ def __init__(
             'lr': lr,
             'betas': betas,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'warmup_steps': warmup_steps,
@@ -259,22 +246,16 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             beta1, beta2 = group['betas']
 
-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction2: float = self.debias(beta2, group['step'])
 
-            lr: float = group['lr'] * schedule * bias_correction2_sq
+            lr: float = group['lr'] * schedule
             lr_max = group['lr_max'] = max(lr, group['lr_max'])
 
-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight
 
             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
 
-            if group['use_palm']:
-                beta2: float = 1.0 - group['step'] ** -0.8
-                debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
-            else:
-                debias: float = beta2
-
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -289,27 +270,27 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['z'] = p.clone()
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
-                self.apply_weight_decay(
-                    p=p,
-                    grad=grad,
-                    lr=lr,
-                    weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
-                )
-
                 z, exp_avg_sq = state['z'], state['exp_avg_sq']
-                exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 de_nom = self.apply_ams_bound(
                     ams_bound=group['ams_bound'],
-                    exp_avg_sq=exp_avg_sq,
+                    exp_avg_sq=exp_avg_sq.div(bias_correction2),
                     max_exp_avg_sq=state.get('max_exp_avg_sq', None),
                     eps=group['eps'],
                 )
 
                 grad.div_(de_nom)
 
+                self.apply_weight_decay(
+                    p=p,
+                    grad=grad,
+                    lr=lr,
+                    weight_decay=group['weight_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
+                )
+
                 p.lerp_(z, weight=checkpoint)
                 p.add_(grad, alpha=lr * (beta1 * (1.0 - checkpoint) - 1))
 
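# Note: explanatory sketch, not part of the patch above. Assuming self.debias(beta, step)
# returns 1.0 - beta**step (an assumption about BaseOptimizer), dividing exp_avg_sq by the
# correction inside apply_ams_bound yields, up to where eps enters, the same step as the old
# code, which folded sqrt(1 - beta2**step) into lr. The benefit is that lr, lr_max, and hence
# the checkpoint weights no longer carry the correction factor. Values below are made up.
import math
import torch

beta2, step, base_lr, eps = 0.999, 10, 2.5e-3, 1e-8
grad = torch.tensor([0.5, -1.0, 2.0, -0.25])
exp_avg_sq = grad.square() * (1.0 - beta2)      # one EMA update starting from zeros
bias_correction2 = 1.0 - beta2 ** step          # assumed value of self.debias(beta2, step)

old = base_lr * math.sqrt(bias_correction2) * grad / (exp_avg_sq.sqrt() + eps)  # correction folded into lr
new = base_lr * grad / (exp_avg_sq.div(bias_correction2).sqrt() + eps)          # correction folded into de_nom
print(torch.allclose(old, new))  # True; the two forms differ only in how eps is scaled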
@@ -325,12 +306,13 @@ class ScheduleFreeRAdam(BaseOptimizer):
     :param lr: float. learning rate.
     :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
     :param weight_decay: float. weight decay (L2 penalty).
-    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
-    :param fixed_decay: bool. fix weight decay.
-    :param degenerated_to_sgd: float. degenerated to SGD.
     :param r: float. use polynomial weighting in the average with power r.
     :param weight_lr_power: float. during warmup, the weights in the average will be equal to lr raised to this power.
         set to 0 for no weighting.
+    :param silent_sgd_phase: bool. the optimizer will not use the first SGD phase of RAdam. This means that the
+        optimizer will not update model parameters during the early training steps (e.g., < 5 when β_2 = 0.999), but
+        just update the momentum values of the optimizer. This helps stabilize training by ensuring smoother warmup
+        behavior and more reliable calculation of the moving average coefficient (`ckp1`). Recommended to set to True.
     :param eps: float. term added to the denominator to improve numerical stability.
     """
 
@@ -340,11 +322,9 @@ def __init__(
         lr: float = 2.5e-3,
         betas: BETAS = (0.9, 0.999),
         weight_decay: float = 0.0,
-        weight_decouple: bool = True,
-        fixed_decay: bool = False,
-        degenerated_to_sgd: bool = False,
         r: float = 0.0,
         weight_lr_power: float = 2.0,
+        silent_sgd_phase: bool = True,
         eps: float = 1e-8,
         **kwargs,
     ):
@@ -357,9 +337,7 @@ def __init__(
             'lr': lr,
             'betas': betas,
             'weight_decay': weight_decay,
-            'weight_decouple': weight_decouple,
-            'fixed_decay': fixed_decay,
-            'degenerated_to_sgd': degenerated_to_sgd,
+            'silent_sgd_phase': silent_sgd_phase,
             'r': r,
             'weight_lr_power': weight_lr_power,
             'eps': eps,
@@ -418,32 +396,28 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
             beta1, beta2 = group['betas']
 
-            bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** group['step'])
+            bias_correction2: float = self.debias_beta(beta2, group['step'])
 
             lr, n_sma = self.get_rectify_step_size(
                 is_rectify=True,
                 step=group['step'],
                 lr=group['lr'],
                 beta2=beta2,
                 n_sma_threshold=4,
-                degenerated_to_sgd=group['degenerated_to_sgd'],
+                degenerated_to_sgd=False,
             )
+            if lr < 0.0:
+                lr = float(not group['silent_sgd_phase'])
 
             lr_max = group['lr_max'] = max(lr, group['lr_max'])
 
-            weight = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
+            weight: float = (group['step'] ** group['r']) * (lr_max ** group['weight_lr_power'])
             weight_sum = group['weight_sum'] = group['weight_sum'] + weight
 
             checkpoint: float = weight / weight_sum if weight_sum != 0.0 else 0.0
 
             adaptive_y_lr: float = lr * (beta1 * (1.0 - checkpoint) - 1.0)
 
-            if group['use_palm']:
-                beta2: float = 1.0 - group['step'] ** -0.8
-                debias: float = (1.0 - beta2) / (1.0 - beta2 ** group['step'])
-            else:
-                debias: float = beta2
-
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -459,19 +433,19 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
                 z, exp_avg_sq = state['z'], state['exp_avg_sq']
-                exp_avg_sq.mul_(debias).addcmul_(grad, grad, value=1.0 - debias)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
                 if n_sma > 4.0:
-                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2_sq).add_(group['eps'])
+                    de_nom = exp_avg_sq.sqrt().div_(bias_correction2).add_(group['eps'])
                     grad.div_(de_nom)
 
                 self.apply_weight_decay(
                     p=p,
                     grad=grad,
                     lr=lr,
                     weight_decay=group['weight_decay'],
-                    weight_decouple=group['weight_decouple'],
-                    fixed_decay=group['fixed_decay'],
+                    weight_decouple=False,
+                    fixed_decay=False,
                 )
 
                 p.lerp_(z, weight=checkpoint)
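# Note: explanatory sketch, not part of the patch above. What the new `if lr < 0.0` branch is
# assumed to do: with degenerated_to_sgd=False, get_rectify_step_size is expected to return a
# negative lr while the rectification term is still unavailable (early steps).
# silent_sgd_phase=True maps that to lr=0.0, so parameters are left untouched and only the
# optimizer statistics advance; False maps it to 1.0, i.e. RAdam's usual un-scaled SGD
# fallback. `resolve_lr` is a hypothetical helper for illustration only.
def resolve_lr(rectified_lr: float, silent_sgd_phase: bool) -> float:
    if rectified_lr < 0.0:
        return float(not silent_sgd_phase)
    return rectified_lr

print(resolve_lr(-1.0, silent_sgd_phase=True))    # 0.0 -> "silent" phase, no parameter update
print(resolve_lr(-1.0, silent_sgd_phase=False))   # 1.0 -> plain SGD-style step
print(resolve_lr(2.5e-3, silent_sgd_phase=True))  # 2.5e-3 -> rectified step passes through unchanged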