 from pytorch_optimizer.optimizer.gc import centralize_gradient
 from pytorch_optimizer.optimizer.utils import normalize_gradient, unit_norm
 
-__AUTHORS__ = [
-    '@lessw2020',
-    '@NestorDemeure',
-    # with contributions from:
-    '@BrianPugh',
-    '@Kayuksel',
-    '@TheZothen',
-]
-
 
 class Ranger21(Optimizer, BaseOptimizer):
     """
@@ -38,7 +29,7 @@ class Ranger21(Optimizer, BaseOptimizer):
             optimizer.step()
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=R0913
         self,
         params: PARAMETERS,
         num_iterations: int,
@@ -58,6 +49,7 @@ def __init__(
         lookahead_blending_alpha: float = 0.5,
         weight_decay: float = 1e-4,
         norm_loss_factor: float = 1e-4,
+        adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
         """Ranger21 optimizer
@@ -76,6 +68,7 @@ def __init__(
         :param lookahead_blending_alpha: float. blending alpha
         :param weight_decay: float. weight decay (L2 penalty)
         :param norm_loss_factor: float. norm loss factor
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
@@ -91,6 +84,7 @@ def __init__(
         self.lookahead_blending_alpha = lookahead_blending_alpha
         self.weight_decay = weight_decay
         self.norm_loss_factor = norm_loss_factor
+        self.adamd_debias_term = adamd_debias_term
         self.eps = eps
 
         self.validate_parameters()
@@ -108,6 +102,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            adamd_debias_term=adamd_debias_term,
         )
         super().__init__(params, defaults)
 
@@ -240,6 +235,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 variance_ma_sum += (variance_ma / bias_correction2).sum()
 
         # stable weight decay
+        if param_size == 0:
+            raise ValueError('[-] size of parameter is 0')
+
         variance_normalized = math.sqrt(variance_ma_sum / param_size)
         if math.isnan(variance_normalized):
             raise RuntimeError('hit nan for variance_normalized')
@@ -299,7 +297,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 noise_norm: float = math.sqrt((1.0 + beta2) ** 2 + beta2 ** 2)
 
-                step_size: float = lr / bias_correction1
+                step_size: float = lr
+                if not group['adamd_debias_term']:
+                    step_size /= bias_correction1
 
                 if self.use_softplus:
                     de_nom = F.softplus(de_nom, beta=self.beta_softplus)
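
For reference, the new `adamd_debias_term` flag applies AdamD-style debiasing: only the second-moment (denominator) bias correction is kept, and the step size is no longer divided by `bias_correction1`, which would otherwise inflate updates during the first iterations. Below is a minimal sketch of the effect on the effective step size; the helper name and the numbers are illustrative only and not part of the library.

def effective_step_size(lr: float, beta1: float, step: int, adamd_debias_term: bool) -> float:
    # standard Adam first-moment bias correction
    bias_correction1 = 1.0 - beta1 ** step
    if adamd_debias_term:
        # AdamD-style: leave the numerator uncorrected; only the denominator
        # (second moment) is debiased elsewhere in step()
        return lr
    return lr / bias_correction1

print(effective_step_size(1e-3, 0.9, step=1, adamd_debias_term=False))  # 0.01  -> 10x the base lr
print(effective_step_size(1e-3, 0.9, step=1, adamd_debias_term=True))   # 0.001 -> unchanged

Since the keyword argument is exposed in `__init__` and stored in `defaults`, it can be enabled per parameter group, e.g. `Ranger21(model.parameters(), num_iterations=num_iterations, adamd_debias_term=True)`.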