Commit 2e97e5f

Merge pull request #159 from kozistr/fix/bias-correction
[Fix] bias correction in D-Adaptation Adam v3 optimizer
2 parents ed97a2f + f5bf44e commit 2e97e5f

2 files changed: +16 −10 lines

pytorch_optimizer/optimizer/dadapt.py

Lines changed: 11 additions & 5 deletions
@@ -251,7 +251,7 @@ class DAdaptAdam(Optimizer, BaseOptimizer):
         :param weight_decay: float. weight decay (L2 penalty).
         :param weight_decouple: bool. use AdamW style weight decay.
         :param fixed_decay: bool. fix weight decay.
-        :param adam_debias: bool. Only correct the denominator to avoid inflating step sizes early in training.
+        :param bias_correction: bool. Turn on Adam's bias correction.
         :param eps: float. term added to the denominator to improve numerical stability.
     """
@@ -265,7 +265,7 @@ def __init__(
         weight_decay: float = 0.0,
         weight_decouple: bool = False,
         fixed_decay: bool = False,
-        adam_debias: bool = False,
+        bias_correction: bool = False,
         eps: float = 0.0,
     ):
         self.validate_learning_rate(lr)
@@ -281,7 +281,7 @@ def __init__(
             'weight_decay': weight_decay,
             'weight_decouple': weight_decouple,
             'fixed_decay': fixed_decay,
-            'adam_debias': adam_debias,
+            'bias_correction': bias_correction,
             'step': 0,
             'eps': eps,
         }
@@ -321,8 +321,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         d: float = group['d']
         lr: float = group['lr']

-        bias_correction: float = 1.0 - pow(beta1, group['step'] + 1)
-        d_lr: float = self.apply_adam_debias(group['adam_debias'], step_size=d * lr, bias_correction1=bias_correction)
+        bias_correction1: float = 1.0 - beta1 ** (group['step'] + 1)
+        bias_correction2_sq: float = math.sqrt(1.0 - beta2 ** (group['step'] + 1))
+        bias_correction: float = bias_correction1 / bias_correction2_sq
+
+        # it's not Adam Debias
+        d_lr: float = self.apply_adam_debias(
+            group['bias_correction'], step_size=d * lr, bias_correction1=bias_correction
+        )

         sk_l1 = torch.tensor([0.0], device=device)
         numerator_acc = torch.tensor([0.0], device=device)
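The fix replaces the first-moment-only factor 1 - beta1^t with the full Adam correction (1 - beta1^t) / sqrt(1 - beta2^t); the reused apply_adam_debias helper only scales the step size here, hence the "it's not Adam Debias" comment in the diff. A minimal standalone sketch of the corrected factor (the function name is hypothetical; the arithmetic mirrors the patched lines above):

    import math

    def adam_bias_correction(beta1: float, beta2: float, step: int) -> float:
        # Combined Adam bias-correction factor with t = step + 1:
        # (1 - beta1^t) / sqrt(1 - beta2^t).
        t = step + 1
        bias_correction1 = 1.0 - beta1 ** t
        bias_correction2_sq = math.sqrt(1.0 - beta2 ** t)
        return bias_correction1 / bias_correction2_sq

    # At the first step with common Adam defaults the factor is well above 1:
    print(adam_bias_correction(0.9, 0.999, step=0))  # 0.1 / sqrt(0.001) ≈ 3.162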

tests/constants.py

Lines changed: 5 additions & 5 deletions
@@ -310,12 +310,12 @@
     (Adan, {'lr': 5e-1, 'max_grad_norm': 1.0}, 5),
     (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 5),
     (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
-    (DAdaptAdaGrad, {'lr': 2e0, 'weight_decay': 1e-3}, 50),
-    (DAdaptAdaGrad, {'lr': 2e0, 'weight_decay': 1e-3, 'momentum': 0.1}, 50),
-    (DAdaptAdam, {'lr': 5e2, 'weight_decay': 1e-3}, 25),
+    (DAdaptAdaGrad, {'lr': 3e0, 'weight_decay': 1e-3}, 30),
+    (DAdaptAdaGrad, {'lr': 5e0, 'weight_decay': 1e-3, 'momentum': 0.1}, 20),
+    (DAdaptAdam, {'lr': 5e4, 'weight_decay': 1e-1}, 10),
     (DAdaptSGD, {'lr': 2e0, 'weight_decay': 1e-3}, 25),
-    (DAdaptAdan, {'lr': 1e0, 'weight_decay': 1e-2}, 25),
-    (DAdaptAdan, {'lr': 1e0, 'weight_decay': 1e-2, 'weight_decouple': True}, 50),
+    (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3}, 20),
+    (DAdaptAdan, {'lr': 2e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 20),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3}, 10),
     (AdamS, {'lr': 1e0, 'weight_decay': 1e-3, 'ams_bound': True}, 20),
     (AdaFactor, {'lr': 7.5e-1, 'weight_decay': 1e-3, 'scale_parameter': False}, 100),
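Since the constructor keyword changed from adam_debias to bias_correction, call sites need updating. A minimal usage sketch (the linear model and hyperparameters are placeholders, not values from this PR):

    import torch
    from pytorch_optimizer import DAdaptAdam

    model = torch.nn.Linear(4, 1)  # placeholder model
    # `bias_correction` replaces the old `adam_debias` keyword for this optimizer
    optimizer = DAdaptAdam(model.parameters(), lr=1.0, weight_decay=1e-3, bias_correction=True)

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()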
