
Commit 50f2bac

Merge pull request #24 from kozistr/feature/improve-madgrad-optimizer
[Feature] Improve MADGRAD optimizer
2 parents 8fb1be6 + e7af455 commit 50f2bac

File tree: 2 files changed (+11, -12 lines)

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -13,4 +13,4 @@
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP

-__VERSION__ = '0.0.6'
+__VERSION__ = '0.0.7'

pytorch_optimizer/madgrad.py

Lines changed: 10 additions & 11 deletions
@@ -13,7 +13,8 @@
 class MADGRAD(Optimizer):
     """
-    Reference : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
+    Reference 1 : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
+    Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/madgrad/madgrad_wd.py
     Example :
         from pytorch_optimizer import MADGRAD
         ...
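The usage example in the docstring is truncated above ('...'). A minimal sketch of how the optimizer would be instantiated and stepped, assuming only the keyword arguments visible in this diff (lr, momentum, weight_decay, eps); the model and data below are placeholders, not part of the repository:

import torch

from pytorch_optimizer import MADGRAD

model = torch.nn.Linear(10, 1)  # placeholder model
optimizer = MADGRAD(model.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0, eps=1e-6)

x, y = torch.randn(8, 10), torch.randn(8, 1)  # stand-in for a real dataloader batch
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()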
@@ -35,11 +36,13 @@ def __init__(
         weight_decay: float = 0.0,
         eps: float = 1e-6,
     ):
-        """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
+        """A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified)
         :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
         :param lr: float. learning rate.
         :param eps: float. term added to the denominator to improve numerical stability
         :param weight_decay: float. weight decay (L2 penalty)
+            MADGRAD optimizer requires less weight decay than other methods, often as little as zero.
+            On sparse problems both weight_decay and momentum should be set to 0.
         """
         self.lr = lr
         self.momentum = momentum
@@ -72,11 +75,6 @@ def supports_flat_params(self) -> bool:
         return True

     def step(self, closure: CLOSURE = None) -> LOSS:
-        """Performs a single optimization step.
-        Arguments:
-            closure (callable, optional): A closure that reevaluates the model
-                and returns the loss.
-        """
         loss: LOSS = None
         if closure is not None:
             loss = closure()
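The removed docstring described the standard torch.optim closure protocol, which step() still follows in the surviving lines: if a closure is passed, it is called to recompute and return the loss. A self-contained sketch of that calling convention (model, data, and hyperparameters are placeholders):

import torch

from pytorch_optimizer import MADGRAD

model = torch.nn.Linear(4, 1)  # placeholder model
optimizer = MADGRAD(model.parameters(), lr=1e-2)
x, y = torch.randn(16, 4), torch.randn(16, 1)  # placeholder batch

def closure():
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # step() returns the closure's loss, or None when no closure is given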
@@ -124,7 +122,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         'weight_decay option is not compatible with sparse gradients'
                     )

-                grad.add_(p.data, alpha=decay)
+                # original implementation
+                # grad.add_(p.data, alpha=decay)
+
+                # Apply weight decay - L2 / AdamW style
+                p.data.mul_(1 - lr * decay)

                 if grad.is_sparse:
                     grad = grad.coalesce()
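The change above replaces L2-style regularization (folding decay * p into the gradient, where it is then scaled by the adaptive denominator) with AdamW-style decoupled weight decay (shrinking the weights directly by lr * decay before the gradient step). A standalone sketch of the difference between the two rules, using plain tensors and illustrative values rather than the optimizer's internals:

import torch

lr, decay = 0.1, 0.01
p = torch.randn(3)
grad = torch.randn(3)

# L2-style: add the penalty to the gradient, so it is also subject to any
# adaptive scaling the optimizer applies afterwards.
p_l2 = p - lr * (grad + decay * p)

# AdamW-style (decoupled): shrink the weights first, then take the plain
# gradient step. This is the form the diff above switches to.
p_decoupled = (1 - lr * decay) * p - lr * grad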
@@ -174,16 +176,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 grad_sum_sq.addcmul_(grad, grad, value=_lambda)
                 rms = grad_sum_sq.pow(1 / 3).add_(eps)

-                # Update s
                 s.data.add_(grad, alpha=_lambda)

-                # Step
                 if momentum == 0:
                     p.data.copy_(x0.addcdiv(s, rms, value=-1))
                 else:
                     z = x0.addcdiv(s, rms, value=-1)

-                    # p is a moving average of z
                     p.data.mul_(1 - ck).add_(z, alpha=ck)

         self.state['k'] += 1
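The hunk above contains the core MADGRAD update: accumulate _lambda * grad**2 into grad_sum_sq, use its cube root plus eps as the adaptive denominator, accumulate the dual-averaged gradient s, and (with momentum) move p toward the dual-averaged iterate z with weight ck. A standalone numeric sketch of one such update; the tensors and the _lambda / ck values are illustrative, not the optimizer's actual state:

import torch

eps, _lambda, ck = 1e-6, 0.1, 0.9
p = torch.randn(3)
x0 = p.clone()                 # starting point kept per parameter
grad = torch.randn(3)
grad_sum_sq = torch.zeros(3)   # running sum of _lambda * grad^2
s = torch.zeros(3)             # running sum of _lambda * grad

grad_sum_sq.addcmul_(grad, grad, value=_lambda)
rms = grad_sum_sq.pow(1 / 3).add_(eps)  # cube-root adaptive denominator

s.add_(grad, alpha=_lambda)

z = x0.addcdiv(s, rms, value=-1)        # dual-averaged iterate: x0 - s / rms
p.mul_(1 - ck).add_(z, alpha=ck)        # p is a moving average of z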
