Skip to content

Commit fbb2cdb

Browse files
committed
fix: bias_correction
1 parent ed97a2f commit fbb2cdb

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ class DAdaptAdam(Optimizer, BaseOptimizer):
251251
:param weight_decay: float. weight decay (L2 penalty).
252252
:param weight_decouple: bool. use AdamW style weight decay.
253253
:param fixed_decay: bool. fix weight decay.
254-
:param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
254+
:param bias_correction: bool. Turn on Adam's bias correction.
255255
:param eps: float. term added to the denominator to improve numerical stability.
256256
"""
257257

@@ -265,7 +265,7 @@ def __init__(
265265
weight_decay: float = 0.0,
266266
weight_decouple: bool = False,
267267
fixed_decay: bool = False,
268-
adam_debias: bool = False,
268+
bias_correction: bool = False,
269269
eps: float = 0.0,
270270
):
271271
self.validate_learning_rate(lr)
@@ -281,7 +281,7 @@ def __init__(
281281
'weight_decay': weight_decay,
282282
'weight_decouple': weight_decouple,
283283
'fixed_decay': fixed_decay,
284-
'adam_debias': adam_debias,
284+
'bias_correction': bias_correction,
285285
'step': 0,
286286
'eps': eps,
287287
}
@@ -321,8 +321,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
321321
d: float = group['d']
322322
lr: float = group['lr']
323323

324-
bias_correction: float = 1.0 - pow(beta1, group['step'] + 1)
325-
d_lr: float = self.apply_adam_debias(group['adam_debias'], step_size=d * lr, bias_correction1=bias_correction)
324+
bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
325+
bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
326+
bias_correction: float = bias_correction1 / bias_correction2_sq
327+
328+
# it's not Adam Debias
329+
d_lr: float = self.apply_adam_debias(
330+
group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
331+
)
326332

327333
sk_l1 = torch.tensor([0.0], device=device)
328334
numerator_acc = torch.tensor([0.0], device=device)

0 commit comments

Comments (0)