Commit 00598ec

Merge pull request #145 from kozistr/feature/aggmo-optimizer
[Feature] Implement AggMo optimizer
2 parents deb21a0 + 42aa017 commit 00598ec

7 files changed: +117 -2 lines changed

README.rst

Lines changed: 4 additions & 0 deletions

@@ -167,6 +167,8 @@ You can check the supported optimizers & lr schedulers.
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | AdaMod | *An Adaptive and Momental Bound Method for Stochastic Learning* | `github <https://github.com/lancopku/AdaMod>`__ | `https://arxiv.org/abs/1910.12249 <https://arxiv.org/abs/1910.12249>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| AggMo | *Aggregated Momentum: Stability Through Passive Damping* | `github <https://github.com/AtheMathmo/AggMo>`__ | `https://arxiv.org/abs/1804.00325 <https://arxiv.org/abs/1804.00325>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

 Useful Resources
 ----------------

@@ -396,6 +398,8 @@ Citations

 `AdaMod <https://github.com/lancopku/AdaMod#citation>`__

+`AggMo <https://ui.adsabs.harvard.edu/abs/2018arXiv180400325L/exportcitation>`__
+
 Citation
 --------


docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions

@@ -368,3 +368,11 @@ AdaMod

 .. autoclass:: pytorch_optimizer.AdaMod
     :members:
+
+.. _AggMo:
+
+AggMo
+-----
+
+.. autoclass:: pytorch_optimizer.AggMo
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@
 from pytorch_optimizer.optimizer.adanorm import AdaNorm
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.agc import agc
+from pytorch_optimizer.optimizer.aggmo import AggMo
 from pytorch_optimizer.optimizer.alig import AliG
 from pytorch_optimizer.optimizer.apollo import Apollo
 from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptSGD

@@ -106,6 +107,7 @@
     ScalableShampoo,
     DAdaptAdaGrad,
     Fromage,
+    AggMo,
     DAdaptAdam,
     DAdaptSGD,
     DAdaptAdan,

pytorch_optimizer/optimizer/aggmo.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class AggMo(Optimizer, BaseOptimizer):
    r"""Aggregated Momentum: Stability Through Passive Damping.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. damping coefficients; one momentum (velocity) buffer is kept per beta.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.0, 0.9, 0.99),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
    ):
        self.lr = lr
        self.betas = betas
        self.weight_decay = weight_decay

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)

    def __str__(self) -> str:
        return 'AggMo'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in group['betas']}

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            betas = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = {beta: torch.zeros_like(p) for beta in betas}

                if group['weight_decouple']:
                    p.mul_(1.0 - group['weight_decay'] * group['lr'])
                elif group['weight_decay'] > 0.0:
                    grad.add_(p, alpha=group['weight_decay'])

                # update each velocity buffer with its own damping factor, then apply
                # the average of the buffers to the parameters
                for beta in betas:
                    buf = state['momentum_buffer'][beta]
                    buf.mul_(beta).add_(grad)

                    p.add_(buf, alpha=-group['lr'] / len(betas))

        return loss
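
Below is a minimal usage sketch of the new optimizer, based on the constructor and `step` shown above: one velocity buffer is kept per beta (v_beta <- beta * v_beta + grad), and each parameter step applies the average of those buffers scaled by the learning rate. The tiny model, loss, and data here are illustrative placeholders, not part of the PR.

import torch
from torch import nn

from pytorch_optimizer import AggMo

model = nn.Linear(10, 1)
criterion = nn.MSELoss()

# defaults from the implementation above: betas=(0.0, 0.9, 0.99), weight_decay=0.0
optimizer = AggMo(model.parameters(), lr=1e-2, betas=(0.0, 0.9, 0.99), weight_decay=1e-3)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()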

tests/constants.py

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,7 @@
     Adan,
     AdaNorm,
     AdaPNM,
+    AggMo,
     AliG,
     Apollo,
     DAdaptAdaGrad,

@@ -334,6 +335,8 @@
     (MSVAG, {'lr': 5e-1}, 10),
     (AdaMod, {'lr': 5e1, 'weight_decay': 1e-3}, 10),
     (AdaMod, {'lr': 5e1, 'weight_decay': 1e-3, 'weight_decouple': False}, 10),
+    (AggMo, {'lr': 5e0, 'weight_decay': 1e-3}, 5),
+    (AggMo, {'lr': 5e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_general_optimizer_parameters.py

Lines changed: 2 additions & 1 deletion

@@ -34,6 +34,7 @@ def test_epsilon(optimizer_name):
         'sgdw',
         'fromage',
         'msvag',
+        'aggmo',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')

@@ -134,7 +135,7 @@ def test_betas(optimizer_name):
     config1.update({'num_iterations': 100})
     config2.update({'num_iterations': 100})

-    if optimizer_name not in ('adapnm', 'adan', 'adamod'):
+    if optimizer_name not in ('adapnm', 'adan', 'adamod', 'aggmo'):
         with pytest.raises(ValueError):
             optimizer(None, **config1)

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion

@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 39
+    assert len(get_supported_optimizers()) == 40
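
For context, a hedged sketch of how the enlarged registry could be checked after this commit; it assumes get_supported_optimizers() returns the list of registered optimizer classes (the test above only relies on its length).

from pytorch_optimizer import AggMo, get_supported_optimizers

supported = get_supported_optimizers()
print(len(supported))        # expected: 40 after this commit
print(AggMo in supported)    # assumes the registry holds optimizer classes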
