@@ -21,7 +21,9 @@ class AdaBelief(Optimizer, BaseOptimizer):
     :param rectify: bool. perform the rectified update similar to RAdam.
     :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high.
     :param amsgrad: bool. whether to use the AMSBound variant.
-    :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
+    :param r: float. EMA factor. between 0.9 ~ 0.99 is preferred.
+    :param adanorm: bool. whether to use the AdaNorm variant.
+    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

@@ -37,31 +39,35 @@ def __init__(
         rectify: bool = True,
         degenerated_to_sgd: bool = True,
         amsgrad: bool = False,
-        adamd_debias_term: bool = False,
+        r: float = 0.95,
+        adanorm: bool = False,
+        adam_debias: bool = False,
         eps: float = 1e-16,
     ):
         self.lr = lr
         self.betas = betas
         self.weight_decay = weight_decay
         self.n_sma_threshold = n_sma_threshold
-        self.weight_decouple = weight_decouple
-        self.fixed_decay = fixed_decay
-        self.rectify = rectify
         self.degenerated_to_sgd = degenerated_to_sgd
-        self.adamd_debias_term = adamd_debias_term
         self.eps = eps

         self.validate_parameters()

         defaults: DEFAULTS = {
             'lr': lr,
             'betas': betas,
-            'eps': eps,
             'weight_decay': weight_decay,
+            'weight_decouple': weight_decouple,
+            'fixed_decay': fixed_decay,
+            'rectify': rectify,
             'amsgrad': amsgrad,
-            'adamd_debias_term': adamd_debias_term,
-            'buffer': [[None, None, None] for _ in range(10)],
+            'adanorm': adanorm,
+            'adam_debias': adam_debias,
+            'eps': eps,
         }
+        if adanorm:
+            defaults.update({'r': r})
+
         super().__init__(params, defaults)

     def validate_parameters(self):
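Note (explanatory, not part of the commit): a minimal usage sketch for the constructor options above, assuming this class is exposed as `pytorch_optimizer.AdaBelief`; the keyword names follow the signature in this hunk.

```python
import torch
from pytorch_optimizer import AdaBelief  # assumed import path

model = torch.nn.Linear(8, 2)
# adanorm=True enables the AdaNorm variant; 'r' is only added to the
# param-group defaults when adanorm is set (see the `if adanorm:` branch above).
optimizer = AdaBelief(model.parameters(), lr=1e-3, adanorm=True, r=0.95, adam_debias=False)

x, y = torch.randn(4, 8), torch.randn(4, 2)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```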
@@ -76,12 +82,14 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]

-                state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p)
                 state['exp_avg_var'] = torch.zeros_like(p)
+                if group['adanorm']:
+                    state['exp_grad_norm'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                 if group['amsgrad']:
                     state['max_exp_avg_var'] = torch.zeros_like(p)

@@ -93,11 +101,21 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()

         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
             beta1, beta2 = group['betas']
-            weight_decay: float = group['weight_decay']
+            weight_decay = group['weight_decay']
+
+            bias_correction1 = 1.0 - beta1 ** group['step']
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])

-            if self.rectify:
+            if group['rectify']:
                 n_sma_max: float = 2.0 / (1.0 - beta2) - 1.0
+                beta2_t: float = beta2 ** group['step']
+                n_sma: float = n_sma_max - 2 * group['step'] * beta2_t / (1.0 - beta2_t)

             for p in group['params']:
                 if p.grad is None:
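Side note (explanatory, not part of the commit): the quantities hoisted out of the parameter loop here are the RAdam rectification terms, ρ_∞ = 2/(1 − β₂) − 1 (`n_sma_max`) and ρ_t = ρ_∞ − 2·t·β₂^t/(1 − β₂^t) (`n_sma`). The rectified step size computed further down is sqrt((1 − β₂^t)·(ρ_t − 4)(ρ_t − 2)·ρ_∞ / ((ρ_∞ − 4)(ρ_∞ − 2)·ρ_t)), used only while ρ_t ≥ `n_sma_threshold`; otherwise the update falls back to an SGD-style step (when `degenerated_to_sgd`) or is skipped.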
@@ -109,70 +127,68 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_var'] = torch.zeros_like(p)
+                    if group['adanorm']:
+                        state['exp_grad_norm'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                     if group['amsgrad']:
                         state['max_exp_avg_var'] = torch.zeros_like(p)

-                if self.weight_decouple:
-                    p.mul_(1.0 - (group['lr'] * weight_decay if not self.fixed_decay else weight_decay))
+                if group['weight_decouple']:
+                    p.mul_(1.0 - group['weight_decay'] * (1.0 if group['fixed_decay'] else group['lr']))
                 elif weight_decay > 0.0:
                     grad.add_(p, alpha=weight_decay)

-                state['step'] += 1
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

-                bias_correction1 = 1.0 - beta1 ** state['step']
-                bias_correction2_sq = math.sqrt(1.0 - beta2 ** state['step'])
+                s_grad = grad
+                if group['adanorm']:
+                    grad_norm = torch.linalg.norm(grad)

-                exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
-                grad_residual = grad - exp_avg
-                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(group['eps'])
+                    exp_grad_norm = state['exp_grad_norm']
+                    exp_grad_norm.mul_(group['r']).add_(grad_norm, alpha=1.0 - group['r'])
+
+                    if exp_grad_norm > grad_norm:
+                        s_grad *= exp_grad_norm / grad_norm
+
+                exp_avg.mul_(beta1).add_(s_grad, alpha=1.0 - beta1)
+                grad_residual = s_grad - exp_avg
+                exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2).add_(self.eps)

                 if group['amsgrad']:
                     max_exp_avg_var = state['max_exp_avg_var']
                     torch.max(max_exp_avg_var, exp_avg_var, out=max_exp_avg_var)
-                    de_nom = max_exp_avg_var.sqrt()
+                    de_nom = max_exp_avg_var.add(self.eps).sqrt()
                 else:
-                    de_nom = exp_avg_var.sqrt()
-                de_nom.div_(bias_correction2_sq).add_(group['eps'])
+                    de_nom = exp_avg_var.add(self.eps).sqrt()
+
+                de_nom.div_(bias_correction2_sq).add_(self.eps)

-                if not self.rectify:
-                    step_size: float = group['lr'] if group['adamd_debias_term'] else group['lr'] / bias_correction1
+                if not group['rectify']:
+                    step_size: float = group['lr'] if group['adam_debias'] else group['lr'] / bias_correction1
                     p.addcdiv_(exp_avg, de_nom, value=-step_size)
                     continue

-                buffered = group['buffer'][state['step'] % 10]
-                if state['step'] == buffered[0]:
-                    n_sma, step_size = buffered[1], buffered[2]
+                if n_sma >= self.n_sma_threshold:
+                    step_size = math.sqrt(
+                        (1 - beta2_t)
+                        * (n_sma - 4)
+                        / (n_sma_max - 4)
+                        * (n_sma - 2)
+                        / n_sma
+                        * n_sma_max
+                        / (n_sma_max - 2)
+                    )
+                elif self.degenerated_to_sgd:
+                    step_size = 1.0
                 else:
-                    buffered[0] = state['step']
-                    beta2_t = beta2 ** state['step']
-                    n_sma = n_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
-                    buffered[1] = n_sma
-
-                    if n_sma >= self.n_sma_threshold:
-                        step_size = math.sqrt(
-                            (1 - beta2_t)
-                            * (n_sma - 4)
-                            / (n_sma_max - 4)
-                            * (n_sma - 2)
-                            / n_sma
-                            * n_sma_max
-                            / (n_sma_max - 2)
-                        )
-                        if not group['adamd_debias_term']:
-                            step_size /= bias_correction1
-                    elif self.degenerated_to_sgd:
-                        step_size = 1.0 / bias_correction1
-                    else:
-                        step_size = -1
-
-                    buffered[2] = step_size
+                    step_size = -1
+
+                if not group['adam_debias']:
+                    step_size /= bias_correction1

                 if n_sma >= self.n_sma_threshold:
-                    de_nom = exp_avg_var.sqrt().add_(group['eps'])
+                    de_nom = exp_avg_var.sqrt().add_(self.eps)
                     p.addcdiv_(exp_avg, de_nom, value=-step_size * group['lr'])
                 elif step_size > 0:
                     p.add_(exp_avg, alpha=-step_size * group['lr'])
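For reference (an illustrative sketch, not part of the library): the AdaNorm branch added above keeps an exponential moving average of gradient norms and scales the gradient up whenever its current norm falls below that average, so the moment estimates see a norm-stabilized gradient. A self-contained rendering of that logic, using the hypothetical helper name `adanorm_scale`:

```python
import torch


def adanorm_scale(grad: torch.Tensor, exp_grad_norm: torch.Tensor, r: float = 0.95) -> torch.Tensor:
    """Scale `grad` by EMA-norm / current-norm when the EMA is larger."""
    grad_norm = torch.linalg.norm(grad)
    # in-place EMA update, mirroring how the optimizer mutates `state['exp_grad_norm']`
    exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)
    if exp_grad_norm > grad_norm:
        return grad * (exp_grad_norm / grad_norm)
    return grad
```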