
Commit fab0d29

Merge pull request #146 from kozistr/feature/qhadam-optimizer
[Feature] Implement QHAdam, QHM optimizers
2 parents 00598ec + a87b31a commit fab0d29

9 files changed: +299 -2 lines changed

9 files changed

+299
-2
lines changed

README.rst

Lines changed: 5 additions & 1 deletion
@@ -167,7 +167,9 @@ You can check the supported optimizers & lr schedulers.
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | AdaMod | *An Adaptive and Momental Bound Method for Stochastic Learning* | `github <https://github.com/lancopku/AdaMod>`__ | `https://arxiv.org/abs/1910.12249 <https://arxiv.org/abs/1910.12249>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
-| AggMo | *Aggregated Momentum: Stability Through Passive Damping* | `github <https://github.com/AtheMathmo/AggMo`__ | `https://arxiv.org/abs/1804.00325 <https://arxiv.org/abs/1804.00325>`__ |
+| AggMo | *Aggregated Momentum: Stability Through Passive Damping* | `github <https://github.com/AtheMathmo/AggMo>`__ | `https://arxiv.org/abs/1804.00325 <https://arxiv.org/abs/1804.00325>`__ |
++--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| QHAdam | *Quasi-hyperbolic momentum and Adam for deep learning* | `github <https://github.com/facebookresearch/qhoptim>`__ | `https://arxiv.org/abs/1810.06801 <https://arxiv.org/abs/1810.06801>`__ |
 +--------------+-------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

 Useful Resources
@@ -400,6 +402,8 @@ Citations

 `AggMo <https://ui.adsabs.harvard.edu/abs/2018arXiv180400325L/exportcitation>`__

+`QHAdam <https://github.com/facebookresearch/qhoptim#reference>`__
+
 Citation
 --------
docs/optimizer_api.rst

Lines changed: 16 additions & 0 deletions
@@ -376,3 +376,19 @@ AggMo

 .. autoclass:: pytorch_optimizer.AggMo
     :members:
+
+.. _QHAdam:
+
+QHAdam
+------
+
+.. autoclass:: pytorch_optimizer.QHAdam
+    :members:
+
+.. _QHM:
+
+QHM
+---
+
+.. autoclass:: pytorch_optimizer.QHM
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,8 @@
 from pytorch_optimizer.optimizer.novograd import NovoGrad
 from pytorch_optimizer.optimizer.pcgrad import PCGrad
 from pytorch_optimizer.optimizer.pnm import PNM
+from pytorch_optimizer.optimizer.qhadam import QHAdam
+from pytorch_optimizer.optimizer.qhm import QHM
 from pytorch_optimizer.optimizer.radam import RAdam
 from pytorch_optimizer.optimizer.ranger import Ranger
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
@@ -95,6 +97,8 @@
     DiffGrad,
     Lamb,
     LARS,
+    QHAdam,
+    QHM,
     MADGRAD,
     Nero,
     PNM,
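
With these exports in place, both optimizers are importable from the top-level package and, as the tests further down do, can be resolved by their lowercase names. A minimal sketch, assuming a build of pytorch-optimizer that includes this commit:

import pytorch_optimizer as po
from pytorch_optimizer import QHAdam, QHM

# Name-based lookup, mirroring the tests' load_optimizer('qhadam') / load_optimizer('qhm') calls.
assert po.load_optimizer('qhadam') is QHAdam
assert po.load_optimizer('qhm') is QHM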

pytorch_optimizer/base/optimizer.py

Lines changed: 12 additions & 0 deletions
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Tuple, Union

 import torch

@@ -141,6 +142,17 @@ def validate_amplifier(amplifier: float):
         if amplifier < 0.0:
             raise ValueError(f'[-] amplifier {amplifier} must be non-negative')

+    @staticmethod
+    def validate_nus(nus: Union[float, Tuple[float, float]]):
+        if isinstance(nus, float):
+            if not 0.0 <= nus <= 1.0:
+                raise ValueError(f'[-] nus {nus} must be in the range [0, 1]')
+        else:
+            if not 0.0 <= nus[0] <= 1.0:
+                raise ValueError(f'[-] nus1 {nus[0]} must be in the range [0, 1]')
+            if not 0.0 <= nus[1] <= 1.0:
+                raise ValueError(f'[-] nus2 {nus[1]} must be in the range [0, 1]')
+
     @abstractmethod
     def validate_parameters(self):
         raise NotImplementedError
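
The new validator accepts either a single nu (as QHM passes) or a (nu1, nu2) tuple (as QHAdam passes) and rejects values outside [0, 1]. A small illustrative sketch of the behaviour, not part of the commit:

from pytorch_optimizer.base.optimizer import BaseOptimizer

BaseOptimizer.validate_nus(0.7)         # scalar form, used by QHM: passes
BaseOptimizer.validate_nus((0.7, 1.0))  # tuple form, used by QHAdam: passes

try:
    BaseOptimizer.validate_nus((1.2, 0.5))  # nu1 outside [0, 1]
except ValueError as e:
    print(e)  # [-] nus1 1.2 must be in the range [0, 1]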
pytorch_optimizer/optimizer/qhadam.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
from typing import Tuple

import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS


class QHAdam(Optimizer, BaseOptimizer):
    r"""Quasi-hyperbolic momentum and Adam for deep learning.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param betas: BETAS. coefficients used for computing running averages of the gradient and its square.
    :param nus: Tuple[float, float]. immediate discount factors used to estimate the gradient and its square.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    :param eps: float. term added to the denominator to improve numerical stability.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        nus: Tuple[float, float] = (1.0, 1.0),
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
        eps: float = 1e-8,
    ):
        self.lr = lr
        self.betas = betas
        self.nus = nus
        self.weight_decay = weight_decay
        self.eps = eps

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'betas': betas,
            'nus': nus,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_betas(self.betas)
        self.validate_weight_decay(self.weight_decay)
        self.validate_epsilon(self.eps)
        self.validate_nus(self.nus)

    def __str__(self) -> str:
        return 'QHAdam'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['beta1_weight'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['beta2_weight'] = torch.zeros((1,), dtype=p.dtype, device=p.device)
                state['exp_avg'] = torch.zeros_like(p)
                state['exp_avg_sq'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            beta1, beta2 = group['betas']
            nu1, nu2 = group['nus']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['beta1_weight'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    state['beta2_weight'] = torch.zeros((1,), dtype=grad.dtype, device=grad.device)
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                if group['weight_decouple']:
                    p.mul_(1.0 - group['weight_decay'] * group['lr'])
                elif group['weight_decay'] > 0.0:
                    grad.add_(p, alpha=group['weight_decay'])

                # running (1 + beta + beta^2 + ...) weights used for bias correction
                beta1_weight, beta2_weight = state['beta1_weight'], state['beta2_weight']
                beta1_weight.mul_(beta1).add_(1.0)
                beta2_weight.mul_(beta2).add_(1.0)

                beta1_adj = 1.0 - (1.0 / beta1_weight)
                beta2_adj = 1.0 - (1.0 / beta2_weight)

                grad_p2 = grad.pow(2)

                # bias-corrected EMAs of the gradient and the squared gradient
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                exp_avg.mul_(beta1_adj).add_((1.0 - beta1_adj) * grad)
                exp_avg_sq.mul_(beta2_adj).add_((1.0 - beta2_adj) * grad_p2)

                # quasi-hyperbolic blends of the raw gradient (and its square) with the EMAs
                avg_grad = exp_avg.mul(nu1)
                if nu1 != 1.0:
                    avg_grad.add_(grad, alpha=1.0 - nu1)

                avg_grad_rms = exp_avg_sq.mul(nu2)
                if nu2 != 1.0:
                    avg_grad_rms.add_(grad_p2, alpha=1.0 - nu2)

                avg_grad_rms.sqrt_().add_(group['eps'])

                p.addcdiv_(avg_grad, avg_grad_rms, value=-group['lr'])

        return loss
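
The per-parameter step therefore divides a nu1-weighted blend of the raw gradient and its EMA by the square root of a nu2-weighted blend of the squared gradient and its EMA, so nus=(1.0, 1.0) falls back to Adam-style behaviour. A minimal usage sketch; the toy model, data, and the nus=(0.7, 1.0) setting below are illustrative, not part of the commit:

import torch
import torch.nn.functional as F

from pytorch_optimizer import QHAdam

model = torch.nn.Linear(10, 1)
optimizer = QHAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), nus=(0.7, 1.0))

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()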

pytorch_optimizer/optimizer/qhm.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
import torch
from torch.optim.optimizer import Optimizer

from pytorch_optimizer.base.exception import NoSparseGradientError
from pytorch_optimizer.base.optimizer import BaseOptimizer
from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class QHM(Optimizer, BaseOptimizer):
    r"""Quasi-hyperbolic momentum (QHM) optimization algorithm.

    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
    :param lr: float. learning rate.
    :param momentum: float. momentum factor.
    :param nu: float. immediate discount factor used to blend the momentum buffer and the raw gradient.
    :param weight_decay: float. weight decay (L2 penalty).
    :param weight_decouple: bool. the optimizer uses decoupled weight decay as in AdamW.
    """

    def __init__(
        self,
        params: PARAMETERS,
        lr: float = 1e-3,
        momentum: float = 0.0,
        nu: float = 1.0,
        weight_decay: float = 0.0,
        weight_decouple: bool = False,
    ):
        self.lr = lr
        self.momentum = momentum
        self.nu = nu
        self.weight_decay = weight_decay

        self.validate_parameters()

        defaults: DEFAULTS = {
            'lr': lr,
            'momentum': momentum,
            'nu': nu,
            'weight_decay': weight_decay,
            'weight_decouple': weight_decouple,
        }
        super().__init__(params, defaults)

    def validate_parameters(self):
        self.validate_learning_rate(self.lr)
        self.validate_momentum(self.momentum)
        self.validate_weight_decay(self.weight_decay)
        self.validate_nus(self.nu)

    def __str__(self) -> str:
        return 'QHM'

    @torch.no_grad()
    def reset(self):
        for group in self.param_groups:
            group['step'] = 0
            for p in group['params']:
                state = self.state[p]

                state['momentum_buffer'] = torch.zeros_like(p)

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                if grad.is_sparse:
                    raise NoSparseGradientError(str(self))

                state = self.state[p]

                if len(state) == 0:
                    state['momentum_buffer'] = torch.zeros_like(p)

                if group['weight_decouple']:
                    p.mul_(1.0 - group['weight_decay'] * group['lr'])
                elif group['weight_decay'] > 0.0:
                    grad.add_(p, alpha=group['weight_decay'])

                # exponentially weighted momentum buffer
                buf = state['momentum_buffer']
                buf.mul_(group['momentum']).add_(grad, alpha=1.0 - group['momentum'])

                # quasi-hyperbolic update: blend the buffer and the raw gradient with nu
                p.add_(buf, alpha=-group['lr'] * group['nu'])
                p.add_(grad, alpha=-group['lr'] * (1.0 - group['nu']))

        return loss
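
For readers skimming the diff, the momentum-buffer update and the two trailing p.add_ calls implement the quasi-hyperbolic rule; written out as a restatement of the code above (with learning rate \eta, momentum factor \beta, and g_t the possibly weight-decayed gradient):

\begin{aligned}
b_{t+1}      &= \beta\, b_t + (1 - \beta)\, g_t \\
\theta_{t+1} &= \theta_t - \eta \bigl[ \nu\, b_{t+1} + (1 - \nu)\, g_t \bigr]
\end{aligned}

Setting nu = 0 recovers plain SGD, while nu = 1 gives SGD with an exponentially weighted (normalized) momentum buffer.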

tests/constants.py

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,7 @@
     MSVAG,
     OPTIMIZERS,
     PNM,
+    QHM,
     SGDP,
     SGDW,
     SM3,
@@ -35,6 +36,7 @@
     Lion,
     Nero,
     NovoGrad,
+    QHAdam,
     RAdam,
     Ranger,
     Ranger21,
@@ -85,6 +87,9 @@
     'adanorm',
     'yogi',
     'swats',
+    'adamod',
+    'aggmo',
+    'qhadam',
 ]

 VALID_LR_SCHEDULER_NAMES: List[str] = [
@@ -337,6 +342,10 @@
     (AdaMod, {'lr': 5e1, 'weight_decay': 1e-3, 'weight_decouple': False}, 10),
     (AggMo, {'lr': 5e0, 'weight_decay': 1e-3}, 5),
     (AggMo, {'lr': 5e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
+    (QHAdam, {'lr': 1e0, 'nus': (0.9, 0.9), 'weight_decay': 1e-3}, 5),
+    (QHAdam, {'lr': 1e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
+    (QHM, {'lr': 1e0, 'weight_decay': 1e-3}, 5),
+    (QHM, {'lr': 1e0, 'weight_decay': 1e-3, 'weight_decouple': True}, 5),
 ]
 ADANORM_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'adanorm': True}, 10),

tests/test_general_optimizer_parameters.py

Lines changed: 16 additions & 0 deletions
@@ -35,6 +35,7 @@ def test_epsilon(optimizer_name):
         'fromage',
         'msvag',
         'aggmo',
+        'qhm',
     ):
         pytest.skip(f'skip {optimizer_name} optimizer')

@@ -218,3 +219,18 @@ def test_amplifier(optimizer_name):
     optimizer = load_optimizer(optimizer_name)
     with pytest.raises(ValueError):
         optimizer([simple_parameter(False)], amplifier=-1.0)
+
+
+@pytest.mark.parametrize('optimizer_name', ['qhadam', 'qhm'])
+def test_nus(optimizer_name):
+    optimizer = load_optimizer(optimizer_name)
+
+    if optimizer_name == 'qhadam':
+        with pytest.raises(ValueError):
+            optimizer([simple_parameter(False)], nus=(-0.1, 0.1))
+
+        with pytest.raises(ValueError):
+            optimizer([simple_parameter(False)], nus=(0.1, -0.1))
+    else:
+        with pytest.raises(ValueError):
+            optimizer([simple_parameter(False)], nu=-0.1)

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 40
+    assert len(get_supported_optimizers()) == 42
