@@ -251,7 +251,7 @@ class DAdaptAdam(Optimizer, BaseOptimizer):
     :param weight_decay: float. weight decay (L2 penalty).
     :param weight_decouple: bool. use AdamW style weight decay.
     :param fixed_decay: bool. fix weight decay.
-    :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
+    :param bias_correction: bool. Turn on Adam's bias correction.
     :param eps: float. term added to the denominator to improve numerical stability.
     """

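A minimal usage sketch of the renamed flag (the top-level import and the placeholder model are assumptions; the `bias_correction` parameter itself comes from the signature change below):

```python
import torch
from pytorch_optimizer import DAdaptAdam  # assumed top-level export

model = torch.nn.Linear(10, 2)  # placeholder model for illustration

# lr is usually left at 1.0 for D-Adaptation; the renamed flag enables Adam's bias correction.
optimizer = DAdaptAdam(model.parameters(), lr=1.0, bias_correction=True)
```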
@@ -265,7 +265,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
-        adam_debias: bool = False,
+        bias_correction: bool = False,
         eps: float = 0.0,
     ):
         self.validate_learning_rate(lr)
@@ -281,7 +281,7 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'adam_debias': adam_debias,
+            'bias_correction': bias_correction,
             'step': 0,
             'eps': eps,
         }
@@ -321,8 +321,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         d: float = group['d']
         lr: float = group['lr']

-        bias_correction: float = 1.0 - pow(beta1, group['step'] + 1)
-        d_lr: float = self.apply_adam_debias(group['adam_debias'], step_size=d * lr, bias_correction1=bias_correction)
+        bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
+        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
+        bias_correction: float = bias_correction1 / bias_correction2_sq
+
+        # it's not Adam debias: the flag toggles Adam's bias correction of the d-adapted step size
+        d_lr: float = self.apply_adam_debias(
+            group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
+        )

         sk_l1 = torch.tensor([0.0], device=device)
         numerator_acc = torch.tensor([0.0], device=device)
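For reference, a standalone sketch of the combined correction factor introduced above (the beta values are illustrative defaults; this mirrors only the arithmetic of the hunk, not the library's `apply_adam_debias` helper):

```python
import math


def combined_bias_correction(step: int, beta1: float = 0.9, beta2: float = 0.999) -> float:
    """(1 - beta1^t) / sqrt(1 - beta2^t) with t = step + 1, as computed in the hunk above."""
    t = step + 1
    bias_correction1 = 1.0 - beta1 ** t
    bias_correction2_sq = math.sqrt(1.0 - beta2 ** t)
    return bias_correction1 / bias_correction2_sq


# The factor approaches 1.0 as the step count grows, so its effect on the
# d-adapted step size d * lr is concentrated early in training.
for step in (0, 10, 100, 1000, 10000):
    print(step, round(combined_bias_correction(step), 4))
```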