Commit 7bce7c2

Merge pull request #109 from kozistr/feature/novograd-optimizer
[Feature] Implement NovoGrad optimizer
2 parents 82c10c3 + 0bfbbcc commit 7bce7c2

6 files changed: +148 -1 lines changed

README.rst

Lines changed: 4 additions & 0 deletions
@@ -120,6 +120,8 @@ Supported Optimizers
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Apollo | *An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization* | `github <https://github.com/XuezheMax/apollo>`__ | `https://arxiv.org/abs/2009.13586 <https://arxiv.org/abs/2009.13586>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| NovoGrad | *Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks* | `github <https://github.com/lonePatient/NovoGrad-pytorch>`__ | `https://arxiv.org/abs/1905.11286 <https://arxiv.org/abs/1905.11286>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------

@@ -319,6 +321,8 @@ Citations
 
 `Apollo <https://ui.adsabs.harvard.edu/abs/2020arXiv200913586M/exportcitation>`__
 
+`NovoGrad <https://ui.adsabs.harvard.edu/abs/2019arXiv190511286G/exportcitation>`__
+
 Citation
 --------
 
docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -256,3 +256,11 @@ Apollo
 
 .. autoclass:: pytorch_optimizer.Apollo
     :members:
+
+.. _NovoGrad:
+
+NovoGrad
+--------
+
+.. autoclass:: pytorch_optimizer.NovoGrad
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,7 @@
 from pytorch_optimizer.optimizer.lookahead import Lookahead
 from pytorch_optimizer.optimizer.madgrad import MADGRAD
 from pytorch_optimizer.optimizer.nero import Nero
+from pytorch_optimizer.optimizer.novograd import NovoGrad
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
 from pytorch_optimizer.optimizer.pnm import PNM
 from pytorch_optimizer.optimizer.radam import RAdam

@@ -96,6 +97,7 @@
     AdamS,
     AdaFactor,
     Apollo,
+    NovoGrad,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
 
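Because OPTIMIZERS keys each class by its lowercased __name__, appending NovoGrad to OPTIMIZER_LIST also makes it resolvable by the string 'novograd'. A minimal sketch of that lookup (not part of this commit; the toy model and learning rate are illustrative assumptions):

import torch

from pytorch_optimizer import OPTIMIZERS

model = torch.nn.Linear(10, 1)  # hypothetical toy model
optimizer_cls = OPTIMIZERS['novograd']  # resolves to the NovoGrad class registered above
optimizer = optimizer_cls(model.parameters(), lr=1e-3)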

pytorch_optimizer/optimizer/novograd.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+
+class NovoGrad(Optimizer, BaseOptimizer):
+    r"""Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param grad_averaging: bool. multiply ck (1 - momentum).
+    :param adamd_debias_term: bool. Only correct the denominator to avoid inflating step sizes early in training.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        betas: BETAS = (0.95, 0.98),
+        weight_decay: float = 0.0,
+        grad_averaging: bool = False,
+        adamd_debias_term: bool = False,
+        eps: float = 1e-8,
+    ):
+        self.lr = lr
+        self.betas = betas
+        self.weight_decay = weight_decay
+        self.grad_averaging = grad_averaging
+        self.adamd_debias_term = adamd_debias_term
+        self.eps = eps
+
+        self.validate_parameters()
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'betas': betas,
+            'weight_decay': weight_decay,
+            'eps': eps,
+        }
+        super().__init__(params, defaults)
+
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_betas(self.betas)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_epsilon(self.eps)
+
+    @property
+    def __str__(self) -> str:
+        return 'NovoGrad'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                grad = p.grad
+                g_2 = grad ** 2  # fmt: skip
+
+                state['step'] = 0
+                state['moments'] = grad.div(g_2.sqrt() + group['eps']) + group['weight_decay'] * p
+                state['grads_ema'] = g_2
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            beta1, beta2 = group['betas']
+            weight_decay = group['weight_decay']
+
+            bias_correction1 = 1.0 - beta1 ** group['step']
+            bias_correction2_sq = math.sqrt(1.0 - beta2 ** group['step'])
+
+            step_size: float = group['lr'] * bias_correction2_sq
+            if not self.adamd_debias_term:
+                step_size /= bias_correction1
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(self.__str__)
+
+                state = self.state[p]
+                g_2 = grad ** 2  # fmt: skip
+
+                if len(state) == 0:
+                    state['moments'] = grad.div(g_2.sqrt() + group['eps']) + weight_decay * p
+                    state['grads_ema'] = g_2
+
+                moments, grads_ema = state['moments'], state['grads_ema']
+
+                grads_ema.mul_(beta2).add_(g_2, alpha=1.0 - beta2)
+
+                de_nom = grads_ema.sqrt().add_(group['eps'])
+                grad.div_(de_nom)
+
+                if weight_decay > 0.0:
+                    grad.add_(p, alpha=weight_decay)
+
+                if self.grad_averaging:
+                    grad.mul_(1.0 - beta1)
+
+                moments.mul_(beta1).add_(grad)
+
+                p.add_(moments, alpha=-step_size)
+
+        return loss
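For context, a minimal usage sketch of the new optimizer (not part of this commit; the toy data, model, and hyperparameter values are illustrative assumptions, while the keyword arguments follow the signature defined above):

import torch
import torch.nn.functional as F

from pytorch_optimizer import NovoGrad

# hypothetical toy regression task
x, y = torch.randn(64, 10), torch.randn(64, 1)
model = torch.nn.Linear(10, 1)

# keyword arguments mirror the constructor introduced in this commit
optimizer = NovoGrad(model.parameters(), lr=1e-3, betas=(0.95, 0.98), weight_decay=1e-3, grad_averaging=True)

for _ in range(100):
    optimizer.zero_grad()  # clear gradients from the previous backward pass
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()       # apply the NovoGrad update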

tests/constants.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,7 @@
     DiffRGrad,
     Lamb,
     Nero,
+    NovoGrad,
     RAdam,
     RaLamb,
     Ranger,

@@ -68,6 +69,7 @@
     'dadaptadam',
     'adams',
     'adafactor',
+    'novograd',
 ]
 
 VALID_LR_SCHEDULER_NAMES: List[str] = [

@@ -158,6 +160,7 @@
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
     (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),
+    (NovoGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'grad_averaging': True}, 50),
 ]
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),

@@ -172,4 +175,5 @@
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True, 'num_iterations': 200}, 200),
     (AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
     (AdamS, {'lr': 2e1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
+    (NovoGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 50),
 ]
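Each new entry is an (optimizer class, keyword arguments, iteration count) tuple. A hedged sketch of how such a tuple can be exercised (the helper, toy objective, and final check are assumptions for illustration, not code from this repository's test suite):

import torch

from pytorch_optimizer import NovoGrad

def run_config(optimizer_cls, config, iterations):  # hypothetical helper
    torch.manual_seed(42)
    weight = torch.zeros(2, requires_grad=True)   # single trainable parameter
    target = torch.tensor([1.0, -1.0])
    optimizer = optimizer_cls([weight], **config)
    for _ in range(iterations):
        optimizer.zero_grad()
        loss = ((weight - target) ** 2).sum()     # simple quadratic objective
        loss.backward()
        optimizer.step()
    return weight.detach()

updated = run_config(NovoGrad, {'lr': 5e-1, 'weight_decay': 1e-3, 'grad_averaging': True}, 50)
assert not torch.equal(updated, torch.zeros(2))   # the optimizer moved the parameter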

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):
 
 
 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 26
+    assert len(get_supported_optimizers()) == 27
