
Commit 09a8b58

Merge pull request #38 from kozistr/feature/adamd-optimizer
[Feature] Support AdamD optimizer
2 parents 8b68f23 + 2c12ae8 commit 09a8b58

10 files changed: +112 −33 lines changed


README.rst

Lines changed: 13 additions & 0 deletions

@@ -58,6 +58,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | `github <https://github.com/amirgholami/adahessian>`__ | `https://arxiv.org/abs/2006.00719 <https://arxiv.org/abs/2006.00719>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| AdamD | *Improved bias-correction in Adam* | | `https://arxiv.org/abs/2110.10828 <https://arxiv.org/abs/2110.10828>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | `github <https://github.com/clovaai/AdamP>`__ | `https://arxiv.org/abs/2006.08217 <https://arxiv.org/abs/2006.08217>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | diffGrad | *An Optimization Method for Convolutional Neural Networks* | `github <https://github.com/shivram1987/diffGrad>`__ | `https://arxiv.org/abs/1909.11015v3 <https://arxiv.org/abs/1909.11015v3>`__ |
@@ -452,6 +454,17 @@ Gradient Surgery for Multi-Task Learning
         year={2020}
     }

+AdamD: Improved bias-correction in Adam
+
+::
+
+    @article{john2021adamd,
+        title={AdamD: Improved bias-correction in Adam},
+        author={John, John St},
+        journal={arXiv preprint arXiv:2110.10828},
+        year={2021}
+    }
+
 Author
 ------
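
For reference, a minimal usage sketch of the new flag. This assumes, as the package layout below suggests, that the optimizers are importable from the pytorch_optimizer top-level package; the model is only a placeholder:

    import torch

    from pytorch_optimizer import AdaBelief  # assumed top-level export

    model = torch.nn.Linear(10, 1)

    # adamd_debias_term=True applies the AdamD rule: keep only the second-moment
    # (denominator) bias correction, so the effective step size is not inflated
    # during the earliest updates.
    optimizer = AdaBelief(model.parameters(), lr=1e-3, adamd_debias_term=True)

    loss = model(torch.randn(8, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()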

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -20,4 +20,4 @@
 from pytorch_optimizer.sgdp import SGDP
 from pytorch_optimizer.utils import clip_grad_norm, get_optimizer_parameters, normalize_gradient, unit_norm

-__VERSION__ = '0.2.0'
+__VERSION__ = '0.2.1'

pytorch_optimizer/adabelief.py

Lines changed: 21 additions & 10 deletions

@@ -35,6 +35,7 @@ def __init__(
         rectify: bool = True,
         degenerated_to_sgd: bool = True,
         amsgrad: bool = False,
+        adamd_debias_term: bool = False,
         eps: float = 1e-16,
     ):
         """
@@ -48,6 +49,7 @@ def __init__(
         :param rectify: bool. perform the rectified update similar to RAdam
         :param degenerated_to_sgd: bool. perform SGD update when variance of gradient is high
         :param amsgrad: bool. whether to use the AMSBound variant
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
@@ -58,6 +60,7 @@ def __init__(
         self.fixed_decay = fixed_decay
         self.rectify = rectify
         self.degenerated_to_sgd = degenerated_to_sgd
+        self.adamd_debias_term = adamd_debias_term
         self.eps = eps

         buffer: BUFFER = [[None, None, None] for _ in range(10)]
@@ -73,6 +76,7 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             amsgrad=amsgrad,
+            adamd_debias_term=adamd_debias_term,
             buffer=buffer,
         )
         super().__init__(params, defaults)
@@ -81,17 +85,17 @@ def __setstate__(self, state: STATE):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
+            group.setdefault('adamd_debias_term', False)

     def reset(self):
         for group in self.param_groups:
             for p in group['params']:
                 state = self.state[p]
-                amsgrad = group['amsgrad']

                 state['step'] = 0
                 state['exp_avg'] = torch.zeros_like(p.data)
                 state['exp_avg_var'] = torch.zeros_like(p.data)
-                if amsgrad:
+                if group['amsgrad']:
                     state['max_exp_avg_var'] = torch.zeros_like(p.data)

     def step(self, closure: CLOSURE = None) -> LOSS:
@@ -114,14 +118,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise RuntimeError('AdaBelief does not support sparse gradients')

-                amsgrad = group['amsgrad']
-
                 state = self.state[p]
                 if len(state) == 0:
                     state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p.data)
                     state['exp_avg_var'] = torch.zeros_like(p.data)
-                    if amsgrad:
+                    if group['amsgrad']:
                         state['max_exp_avg_var'] = torch.zeros_like(p.data)

                 if self.weight_decouple:
@@ -145,7 +147,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad_residual = grad - exp_avg
                 exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1.0 - beta2)

-                if amsgrad:
+                if group['amsgrad']:
                     max_exp_avg_var = state['max_exp_avg_var']

                     torch.max(
@@ -159,7 +161,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])

                 if not self.rectify:
-                    step_size = group['lr'] / bias_correction1
+                    if group['adamd_debias_term']:
+                        step_size = group['lr']
+                    else:
+                        step_size = group['lr'] / bias_correction1
+
                     p.data.addcdiv_(exp_avg, denom, value=-step_size)
                 else:
                     buffered = group['buffer'][int(state['step'] % 10)]
@@ -173,17 +179,22 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         buffered[1] = n_sma

                         if n_sma >= self.n_sma_threshold:
-                            step_size = math.sqrt(
+                            rt = math.sqrt(
                                 (1 - beta2_t)
                                 * (n_sma - 4)
                                 / (n_sma_max - 4)
                                 * (n_sma - 2)
                                 / n_sma
                                 * n_sma_max
                                 / (n_sma_max - 2)
-                            ) / (1 - beta1 ** state['step'])
+                            )
+
+                            if group['adamd_debias_term']:
+                                step_size = rt
+                            else:
+                                step_size = rt / bias_correction1
                         elif self.degenerated_to_sgd:
-                            step_size = 1.0 / (1.0 - beta1 ** state['step'])
+                            step_size = 1.0 / bias_correction1
                         else:
                             step_size = -1

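The non-rectified branch above boils down to a single switch on the numerator bias correction; in the rectified branch the same switch is applied to the RAdam term rt instead of the raw learning rate. A standalone numeric sketch of that rule (hypothetical helper, not part of the library):

    import math


    def adam_style_step_size(lr: float, beta1: float, step: int, adamd_debias_term: bool) -> float:
        # Plain Adam divides the learning rate by (1 - beta1**t); for beta1 = 0.9
        # this multiplies the very first step by 10x.
        bias_correction1 = 1 - beta1 ** step
        if adamd_debias_term:
            # AdamD: drop the numerator correction; only the denominator
            # (second-moment) correction, applied to `denom` above, remains.
            return lr
        return lr / bias_correction1


    print(adam_style_step_size(1e-3, 0.9, 1, False))    # 0.01   (10x the base lr)
    print(adam_style_step_size(1e-3, 0.9, 1, True))     # 0.001
    print(adam_style_step_size(1e-3, 0.9, 100, False))  # ~0.001 (the two rules converge)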

pytorch_optimizer/adabound.py

Lines changed: 12 additions & 6 deletions

@@ -34,6 +34,7 @@ def __init__(
         weight_decouple: bool = True,
         fixed_decay: bool = False,
         amsbound: bool = False,
+        adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
         """
@@ -46,6 +47,7 @@ def __init__(
         :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW
         :param fixed_decay: bool.
         :param amsbound: bool. whether to use the AMSBound variant
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
@@ -62,6 +64,7 @@ def __init__(
             gamma=gamma,
             weight_decay=weight_decay,
             amsbound=amsbound,
+            adamd_debias_term=adamd_debias_term,
             eps=eps,
         )
         super().__init__(params, defaults)
@@ -84,6 +87,7 @@ def __setstate__(self, state: STATE):
         super().__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsbound', False)
+            group.setdefault('adamd_debias_term', False)

     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
@@ -99,19 +103,17 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if grad.is_sparse:
                     raise RuntimeError('AdaBound does not support sparse gradients')

-                amsbound = group['amsbound']
-
                 state = self.state[p]

                 if len(state) == 0:
                     state['step'] = 0
                     state['exp_avg'] = torch.zeros_like(p)
                     state['exp_avg_sq'] = torch.zeros_like(p)
-                    if amsbound:
+                    if group['amsbound']:
                         state['max_exp_avg_sq'] = torch.zeros_like(p)

                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                if amsbound:
+                if group['amsbound']:
                     max_exp_avg_sq = state['max_exp_avg_sq']

                 state['step'] += 1
@@ -129,15 +131,19 @@ def step(self, closure: CLOSURE = None) -> LOSS:

                 exp_avg.mul_(beta1).add_(1 - beta1, grad)
                 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                if amsbound:
+                if group['amsbound']:
                     max_exp_avg_sq = torch.max(max_exp_avg_sq, exp_avg_sq)
                     denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                 else:
                     denom = exp_avg_sq.sqrt().add_(group['eps'])

                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']
-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                if group['adamd_debias_term']:
+                    step_size = group['lr'] * math.sqrt(bias_correction2)
+                else:
+                    step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                 final_lr = group['final_lr'] * group['lr'] / base_lr
                 lower_bound = final_lr * (1 - 1 / (group['gamma'] * state['step'] + 1))
pytorch_optimizer/adahessian.py

Lines changed: 8 additions & 1 deletion

@@ -32,6 +32,7 @@ def __init__(
         update_each: int = 1,
         num_samples: int = 1,
         average_conv_kernel: bool = False,
+        adamd_debias_term: bool = False,
         eps: float = 1e-8,
         seed: int = 2147483647,
     ):
@@ -44,6 +45,7 @@ def __init__(
         :param update_each: int. compute the hessian trace approximation only after *this* number of steps
         :param num_samples: int. how many times to sample `z` for the approximation of the hessian trace
         :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         :param seed: int.
         """
@@ -69,6 +71,7 @@ def __init__(
             eps=eps,
             weight_decay=weight_decay,
             hessian_power=hessian_power,
+            adamd_debias_term=adamd_debias_term,
         )
         super().__init__(params, defaults)

@@ -179,7 +182,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 hessian_power = group['hessian_power']
                 denom = (exp_hessian_diag_sq / bias_correction2).pow_(hessian_power / 2).add_(group['eps'])

-                step_size = group['lr'] / bias_correction1
+                if group['adamd_debias_term']:
+                    step_size = group['lr']
+                else:
+                    step_size = group['lr'] / bias_correction1
+
                 p.addcdiv_(exp_avg, denom, value=-step_size)

         return loss
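
Each optimizer also registers the new key via group.setdefault in __setstate__, so optimizer state saved before this change still loads. A minimal illustration of that pattern (hypothetical toy optimizer, not the library code):

    import torch


    class ToyOptimizer(torch.optim.Optimizer):
        def __init__(self, params, lr: float = 1e-3, adamd_debias_term: bool = False):
            defaults = dict(lr=lr, adamd_debias_term=adamd_debias_term)
            super().__init__(params, defaults)

        def __setstate__(self, state):
            super().__setstate__(state)
            for group in self.param_groups:
                # Old checkpoints have no 'adamd_debias_term' entry; default it
                # to False instead of raising KeyError later in step().
                group.setdefault('adamd_debias_term', False)


    param = torch.nn.Parameter(torch.zeros(3))
    opt = ToyOptimizer([param])

    old_style = opt.state_dict()
    del old_style['param_groups'][0]['adamd_debias_term']  # simulate a pre-0.2.1 checkpoint

    opt.load_state_dict(old_style)
    print(opt.param_groups[0]['adamd_debias_term'])  # False, filled in by __setstate__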

pytorch_optimizer/adamp.py

Lines changed: 8 additions & 3 deletions

@@ -35,6 +35,7 @@ def __init__(
         wd_ratio: float = 0.1,
         use_gc: bool = False,
         nesterov: bool = False,
+        adamd_debias_term: bool = False,
         eps: float = 1e-8,
     ):
         """
@@ -47,6 +48,7 @@ def __init__(
            on scale-variant parameters
         :param use_gc: bool. use gradient centralization
         :param nesterov: bool. enables Nesterov momentum
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         :param eps: float. term added to the denominator to improve numerical stability
         """
         self.lr = lr
@@ -65,6 +67,7 @@ def __init__(
             delta=delta,
             wd_ratio=wd_ratio,
             nesterov=nesterov,
+            adamd_debias_term=adamd_debias_term,
             eps=eps,
         )
         super().__init__(params, defaults)
@@ -157,10 +160,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                 denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
-                step_size = group['lr'] / bias_correction1
+                if group['adamd_debias_term']:
+                    step_size = group['lr']
+                else:
+                    step_size = group['lr'] / bias_correction1

-                nesterov = group['nesterov']
-                if nesterov:
+                if group['nesterov']:
                     perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
                 else:
                     perturb = exp_avg / denom
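
The Nesterov branch kept above blends the current gradient back into the momentum before dividing by the denominator. A small standalone sketch of the two perturb variants (hypothetical helper mirroring the lines above, not the library code):

    import torch


    def adamp_perturb(exp_avg: torch.Tensor, grad: torch.Tensor, denom: torch.Tensor,
                      beta1: float, nesterov: bool) -> torch.Tensor:
        if nesterov:
            # Look-ahead: mix the fresh gradient into the momentum before normalizing.
            return (beta1 * exp_avg + (1 - beta1) * grad) / denom
        # Standard Adam-style direction: normalized first moment only.
        return exp_avg / denom


    exp_avg = torch.tensor([0.5, -0.2])
    grad = torch.tensor([1.0, 0.0])
    denom = torch.tensor([0.1, 0.1])
    print(adamp_perturb(exp_avg, grad, denom, beta1=0.9, nesterov=False))  # [ 5.0, -2.0]
    print(adamp_perturb(exp_avg, grad, denom, beta1=0.9, nesterov=True))   # [ 5.5, -1.8]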

pytorch_optimizer/diffgrad.py

Lines changed: 9 additions & 2 deletions

@@ -29,13 +29,15 @@ def __init__(
         betas: BETAS = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
+        adamd_debias_term: bool = False,
     ):
         """
         :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate.
         :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
+        :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training
         """
         self.lr = lr
         self.eps = eps
@@ -44,7 +46,9 @@ def __init__(

         self.check_valid_parameters()

-        defaults: DEFAULTS = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        defaults: DEFAULTS = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, adamd_debias_term=adamd_debias_term
+        )
         super().__init__(params, defaults)

     def check_valid_parameters(self):
@@ -107,7 +111,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 # update momentum with dfc
                 exp_avg1 = exp_avg * dfc

-                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                if group['adamd_debias_term']:
+                    step_size = group['lr'] * math.sqrt(bias_correction2)
+                else:
+                    step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                 p.data.addcdiv_(-step_size, exp_avg1, denom)

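
A quick way to see the flag's effect end to end is to compare the size of the very first update with it off and on; a rough sketch, assuming DiffGrad is exported at the package top level like the other optimizers:

    import torch

    from pytorch_optimizer import DiffGrad  # assumed top-level export


    def first_update_distance(adamd_debias_term: bool) -> float:
        param = torch.nn.Parameter(torch.ones(4))
        optimizer = DiffGrad([param], lr=1e-2, adamd_debias_term=adamd_debias_term)
        (param ** 2).sum().backward()
        optimizer.step()
        return (param.detach() - 1.0).abs().max().item()


    # With the flag on, the first step is no longer divided by
    # (1 - beta1) = 0.1, so it should come out roughly 10x smaller.
    print(first_update_distance(False))
    print(first_update_distance(True))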
