@@ -35,6 +35,7 @@ def __init__(
         rectify: bool = True,
         degenerated_to_sgd: bool = True,
         amsgrad: bool = False,
+        adamd_debias_term: bool = False,
         eps: float = 1e-16,
     ):
         """
@@ -48,6 +49,7 @@ def __init__(
         :param rectify: bool. perform the rectified update similar to RAdam
         :param degenerated_to_sgd: bool. perform the SGD update when the variance of the gradient is high
         :param amsgrad: bool. whether to use the AMSGrad variant
+        :param adamd_debias_term: bool. only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
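For context on the new flag: AdamD-style debiasing keeps the `sqrt(1 - beta2**t)` correction inside the denominator but drops the `1 / (1 - beta1**t)` factor from the step size, so the first updates are not inflated. A minimal, illustrative sketch of the difference (plain Python, hypothetical values):

```python
# Effective step-size multiplier early in training, with a typical beta1.
beta1, lr = 0.9, 1e-3

for t in (1, 5, 50):
    bias_correction1 = 1 - beta1 ** t
    standard = lr / bias_correction1  # default: roughly 10x lr at t = 1
    adamd = lr                        # adamd_debias_term=True: no inflation
    print(f't={t}: standard={standard:.2e}, adamd={adamd:.2e}')
```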
@@ -58,6 +60,7 @@ def __init__(
         self.fixed_decay = fixed_decay
         self.rectify = rectify
         self.degenerated_to_sgd = degenerated_to_sgd
+        self.adamd_debias_term = adamd_debias_term
         self.eps = eps

         buffer: BUFFER = [[None, None, None] for _ in range(10)]
@@ -73,6 +76,7 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             amsgrad=amsgrad,
+            adamd_debias_term=adamd_debias_term,
             buffer=buffer,
         )
         super().__init__(params, defaults)
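A minimal usage sketch of the new option, assuming the class is the `AdaBelief` optimizer exported by `pytorch_optimizer` (the import path is an assumption, not part of this diff):

```python
import torch
from pytorch_optimizer import AdaBelief  # assumed import path

model = torch.nn.Linear(4, 2)
optimizer = AdaBelief(model.parameters(), lr=1e-3, adamd_debias_term=True)

loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```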
@@ -81,17 +85,17 @@ def __setstate__(self, state: STATE):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
+            group.setdefault('adamd_debias_term', False)

     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
-                amsgrad = group['amsgrad']

                 state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p.data)
                 state['exp_avg_var'] = torch.zeros_like(p.data)
-                if amsgrad:
+                if group['amsgrad']:
                     state['max_exp_avg_var'] = torch.zeros_like(p.data)

     def step(self, closure: CLOSURE = None) -> LOSS:
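The `setdefault` added to `__setstate__` keeps old checkpoints loadable: `torch.optim.Optimizer.load_state_dict` routes through `__setstate__`, so param groups saved before this change pick up the default `False`. A sketch of that scenario, reusing the hypothetical `optimizer` from the example above:

```python
# Simulate a checkpoint written before adamd_debias_term existed.
state = optimizer.state_dict()
state['param_groups'][0].pop('adamd_debias_term', None)

optimizer.load_state_dict(state)  # __setstate__ runs on load...
assert optimizer.param_groups[0]['adamd_debias_term'] is False  # ...restoring the default
```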
@@ -114,14 +118,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise RuntimeError('AdaBelief does not support sparse gradients')

-                amsgrad = group['amsgrad']
-
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p.data)
                     state['exp_avg_var'] = torch.zeros_like(p.data)
-                    if amsgrad:
+                    if group['amsgrad']:
                         state['max_exp_avg_var'] = torch.zeros_like(p.data)

                 if self.weight_decouple:
@@ -145,7 +147,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad_residual = grad - exp_avg
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)

-                if amsgrad:
+                if group['amsgrad']:
                     max_exp_avg_var = state['max_exp_avg_var']

                     torch.max(
@@ -159,7 +161,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                 if not self.rectify:
-                    step_size = group['lr'] / bias_correction1
+                    if group['adamd_debias_term']:
+                        step_size = group['lr']
+                    else:
+                        step_size = group['lr'] / bias_correction1
+
                     p.data.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
                     buffered = group['buffer'][int(state['step'] % 10)]
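Taken together, the non-rectified branch now performs the update sketched below; this is a standalone rendering of the diff's logic with plain tensors (names mirror the diff, not the class's public API):

```python
import math
import torch

def adabelief_update(p, grad, exp_avg, exp_avg_var, step, lr=1e-3,
                     beta1=0.9, beta2=0.999, eps=1e-16,
                     adamd_debias_term=False):
    # First moment and the 'belief' (variance of the prediction error).
    exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
    grad_residual = grad - exp_avg
    exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)

    bias_correction1 = 1.0 - beta1 ** step
    bias_correction2 = 1.0 - beta2 ** step

    # The denominator keeps its bias correction in both modes.
    denom = (exp_avg_var.add(eps).sqrt() / math.sqrt(bias_correction2)).add_(eps)

    # adamd_debias_term only removes the numerator-side correction.
    step_size = lr if adamd_debias_term else lr / bias_correction1
    p.addcdiv_(exp_avg, denom, value=-step_size)
```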
@@ -173,17 +179,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         buffered[1] = n_sma

                         if n_sma >= self.n_sma_threshold:
-                            step_size = math.sqrt(
+                            rt = math.sqrt(
                                 (1 - beta2_t)
                                 * (n_sma - 4)
                                 / (n_sma_max - 4)
                                 * (n_sma - 2)
                                 / n_sma
                                 * n_sma_max
                                 / (n_sma_max - 2)
-                            ) / (1 - beta1 ** state['step'])
+                            )
+
+                            if group['adamd_debias_term']:
+                                step_size = rt
+                            else:
+                                step_size = rt / bias_correction1
                         elif self.degenerated_to_sgd:
-                            step_size = 1.0 / (1.0 - beta1 ** state['step'])
+                            step_size = 1.0 / bias_correction1
                         else:
                             step_size = -1