
Commit 82c10c3

Merge pull request #108 from kozistr/feature/apollo-optimizer
[Feature] Implement Apollo optimizer
2 parents 32b590f + ed33e31 commit 82c10c3

File tree: 11 files changed, +267 / -110 lines


README.rst

Lines changed: 43 additions & 39 deletions
Large diffs are not rendered by default.

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -248,3 +248,11 @@ AdaFactor

 .. autoclass:: pytorch_optimizer.AdaFactor
     :members:
+
+.. _Apollo:
+
+Apollo
+------
+
+.. autoclass:: pytorch_optimizer.Apollo
+    :members:

pytorch_optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 from pytorch_optimizer.optimizer.adan import Adan
 from pytorch_optimizer.optimizer.adapnm import AdaPNM
 from pytorch_optimizer.optimizer.agc import agc
+from pytorch_optimizer.optimizer.apollo import Apollo
 from pytorch_optimizer.optimizer.dadapt import DAdaptAdaGrad, DAdaptAdam, DAdaptSGD
 from pytorch_optimizer.optimizer.diffgrad import DiffGrad
 from pytorch_optimizer.optimizer.diffrgrad import DiffRGrad
@@ -94,6 +95,7 @@
     DAdaptSGD,
     AdamS,
     AdaFactor,
+    Apollo,
 ]
 OPTIMIZERS: Dict[str, OPTIMIZER] = {str(optimizer.__name__).lower(): optimizer for optimizer in OPTIMIZER_LIST}
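
Because Apollo is appended to OPTIMIZER_LIST, it also becomes reachable through the lowercase-keyed OPTIMIZERS registry built on the last line above. A minimal sketch of resolving it by name; the torch.nn.Linear model and the hyper-parameters are illustrative, not part of this commit:

import torch

from pytorch_optimizer import OPTIMIZERS

model = torch.nn.Linear(4, 1)

# registry keys are str(cls.__name__).lower(), so the new entry is 'apollo'
optimizer_class = OPTIMIZERS['apollo']
optimizer = optimizer_class(model.parameters(), lr=1e-3)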

pytorch_optimizer/base/optimizer.py

Lines changed: 12 additions & 0 deletions
@@ -42,6 +42,13 @@ def validate_weight_decay(weight_decay: float):
         if weight_decay < 0.0:
             raise ValueError(f'[-] weight_decay {weight_decay} must be non-negative')

+    @staticmethod
+    def validate_weight_decay_type(weight_decay_type: str):
+        if weight_decay_type not in ('l2', 'decoupled', 'stable'):
+            raise ValueError(
+                f'[-] weight_decay_type {weight_decay_type} must be one of (\'l2\', \'decoupled\', \'stable\')'
+            )
+
     @staticmethod
     def validate_weight_decay_ratio(weight_decay_ratio: float):
         if not 0.0 <= weight_decay_ratio < 1.0:
@@ -99,6 +106,11 @@ def validate_norm(norm: float):
         if norm < 0.0:
             raise ValueError(f'[-] norm {norm} must be positive')

+    @staticmethod
+    def validate_rebound(rebound: str):
+        if rebound not in ('constant', 'belief'):
+            raise ValueError(f'[-] rebound {rebound} must be one of (\'constant\' or \'belief\')')
+
     @abstractmethod
     def validate_parameters(self):
         raise NotImplementedError
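
Both new validators are invoked from Apollo's validate_parameters() during construction, so an unsupported option string fails fast with a ValueError. A small sketch of that behavior; the test name and the dummy parameter are illustrative:

import pytest
import torch

from pytorch_optimizer import Apollo


def test_apollo_rejects_invalid_options():
    params = [torch.nn.Parameter(torch.zeros(1))]

    # rebound must be 'constant' or 'belief'
    with pytest.raises(ValueError):
        Apollo(params, rebound='exact')

    # weight_decay_type must be 'l2', 'decoupled', or 'stable'
    with pytest.raises(ValueError):
        Apollo(params, weight_decay_type='l1')
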
pytorch_optimizer/optimizer/apollo.py

Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
+from typing import Optional
+
+import numpy as np
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base.exception import NoSparseGradientError
+from pytorch_optimizer.base.optimizer import BaseOptimizer
+from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+
+
+class Apollo(Optimizer, BaseOptimizer):
+    r"""An Adaptive Parameter-wise Diagonal Quasi-Newton Method for Nonconvex Stochastic Optimization.
+
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param init_lr: Optional[float]. initial learning rate (default lr / 1000).
+    :param beta: float. coefficient used for computing running averages of gradient.
+    :param rebound: str. rectified bound for diagonal hessian. (constant, belief).
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param weight_decay_type: str. type of weight decay. (l2, decoupled, stable).
+    :param warmup_steps: int. number of warmup steps.
+    :param eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        init_lr: Optional[float] = None,
+        beta: float = 0.9,
+        rebound: str = 'constant',
+        weight_decay: float = 0.0,
+        weight_decay_type: str = 'l2',
+        warmup_steps: int = 500,
+        eps: float = 1e-4,
+    ):
+        self.lr = lr
+        self.beta = beta
+        self.rebound = rebound
+        self.weight_decay = weight_decay
+        self.weight_decay_type = weight_decay_type
+        self.warmup_steps = warmup_steps
+        self.eps = eps
+
+        self.validate_parameters()
+
+        self.init_lr: float = init_lr if init_lr is not None else lr / 1000.0
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'init_lr': self.init_lr,
+            'beta': beta,
+            'weight_decay': weight_decay,
+            'eps': eps,
+        }
+        super().__init__(params, defaults)
+
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_beta(self.beta)
+        self.validate_rebound(self.rebound)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_weight_decay_type(self.weight_decay_type)
+        self.validate_epsilon(self.eps)
+
+    @property
+    def __str__(self) -> str:
+        return 'Apollo'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            group['step'] = 0
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg_grad'] = torch.zeros_like(p)
+                state['approx_hessian'] = torch.zeros_like(p)
+                state['update'] = torch.zeros_like(p)
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            current_lr: float = (
+                group['lr']
+                if group['step'] >= self.warmup_steps
+                else (self.lr - group['init_lr']) * group['step'] / self.warmup_steps + group['init_lr']
+            )
+
+            weight_decay, eps = group['weight_decay'], group['eps']
+
+            bias_correction: float = 1.0 - group['beta'] ** group['step']
+            alpha: float = (1.0 - group['beta']) / bias_correction
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(self.__str__)
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['exp_avg_grad'] = torch.zeros_like(p)
+                    state['approx_hessian'] = torch.zeros_like(p)
+                    state['update'] = torch.zeros_like(p)
+
+                exp_avg_grad, b, d_p = state['exp_avg_grad'], state['approx_hessian'], state['update']
+
+                if weight_decay > 0.0 and self.weight_decay_type == 'l2':
+                    grad.add_(p, alpha=weight_decay)
+
+                delta_grad = grad - exp_avg_grad
+                if self.rebound == 'belief':
+                    rebound = delta_grad.norm(p=np.inf)
+                else:
+                    rebound = 1e-2
+                    eps /= rebound
+
+                exp_avg_grad.add_(delta_grad, alpha=alpha)
+
+                de_nom = d_p.norm(p=4).add(eps)
+                d_p.div_(de_nom)
+
+                v_sq = d_p.mul(d_p)
+                delta = delta_grad.div_(de_nom).mul_(d_p).sum().mul(-alpha) - b.mul(v_sq).sum()
+
+                b.addcmul_(v_sq, delta)
+
+                de_nom = b.abs().clamp_(min=rebound)
+                if self.rebound == 'belief':
+                    de_nom.add_(eps / alpha)
+
+                d_p.copy_(exp_avg_grad.div(de_nom))
+
+                if weight_decay > 0.0 and self.weight_decay_type != 'l2':
+                    if self.weight_decay_type == 'stable':
+                        weight_decay /= de_nom.mean().item()
+
+                    d_p.add_(p, alpha=weight_decay)
+
+                p.add_(d_p, alpha=-current_lr)
+
+        return loss
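
A minimal usage sketch for the new optimizer; the model, data, and hyper-parameter values below are illustrative and not part of the commit:

import torch

from pytorch_optimizer import Apollo

model = torch.nn.Linear(8, 1)
optimizer = Apollo(
    model.parameters(),
    lr=1e-2,
    weight_decay=1e-3,
    weight_decay_type='stable',  # one of 'l2', 'decoupled', 'stable'
    rebound='belief',            # or 'constant'
    warmup_steps=10,
)

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

During the first warmup_steps steps the effective learning rate ramps linearly from init_lr (lr / 1000 unless given) up to lr, matching the current_lr computation in step() above.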

pytorch_optimizer/optimizer/madgrad.py

Lines changed: 3 additions & 7 deletions
@@ -89,12 +89,9 @@ def step(self, closure: CLOSURE = None) -> LOSS:
         k = self.state['k']

         for group in self.param_groups:
-            eps = group['eps']
+            weight_decay, momentum, eps = group['weight_decay'], group['momentum'], group['eps']
             lr = group['lr'] + eps
-            weight_decay = group['weight_decay']
-            momentum = group['momentum']

-            ck: float = 1.0 - momentum
             _lambda = lr * math.pow(k + 1, 0.5)

             for p in group['params']:
@@ -113,8 +110,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 if momentum > 0.0 and grad.is_sparse:
                     raise NoSparseGradientError(self.__str__, note='momentum > 0.0')

-                grad_sum_sq = state['grad_sum_sq']
-                s = state['s']
+                grad_sum_sq, s = state['grad_sum_sq'], state['s']

                 if weight_decay > 0.0 and not self.decouple_decay:
                     if grad.is_sparse:
@@ -176,7 +172,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     p.copy_(x0.addcdiv(s, rms, value=-1))
                 else:
                     z = x0.addcdiv(s, rms, value=-1)
-                    p.mul_(1.0 - ck).add_(z, alpha=ck)
+                    p.mul_(momentum).add_(z, alpha=1.0 - momentum)

                 if weight_decay > 0.0 and self.decouple_decay:
                     p.add_(p_old, alpha=-lr * weight_decay)
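
The last hunk only folds away the intermediate ck: with ck = 1.0 - momentum, p.mul_(1.0 - ck).add_(z, alpha=ck) and p.mul_(momentum).add_(z, alpha=1.0 - momentum) compute the same weighted average. A quick numerical check, with illustrative tensors:

import torch

momentum = 0.9
ck = 1.0 - momentum
p, z = torch.randn(5), torch.randn(5)

# old form vs. refactored form of the MADGRAD momentum update
old_update = p.clone().mul_(1.0 - ck).add_(z, alpha=ck)
new_update = p.clone().mul_(momentum).add_(z, alpha=1.0 - momentum)

assert torch.allclose(old_update, new_update)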

tests/constants.py

Lines changed: 13 additions & 60 deletions
@@ -3,6 +3,7 @@
 from pytorch_optimizer import (
     LARS,
     MADGRAD,
+    OPTIMIZERS,
     PNM,
     SGDP,
     AdaBelief,
@@ -13,6 +14,7 @@
     AdamS,
     Adan,
     AdaPNM,
+    Apollo,
     DAdaptAdaGrad,
     DAdaptAdam,
     DAdaptSGD,
@@ -32,63 +34,7 @@
 ADAPTIVE_FLAGS: List[bool] = [True, False]
 PULLBACK_MOMENTUM: List[str] = ['none', 'reset', 'pullback']

-SPARSE_OPTIMIZERS: List[str] = [
-    'madgrad',
-    'dadaptadagrad',
-]
-NO_SPARSE_OPTIMIZERS: List[str] = [
-    'adamp',
-    'sgdp',
-    'madgrad',
-    'ranger',
-    'ranger21',
-    'radam',
-    'adabound',
-    'adabelief',
-    'diffgrad',
-    'diffrgrad',
-    'lamb',
-    'ralamb',
-    'lars',
-    'shampoo',
-    'scalableshampoo',
-    'nero',
-    'adan',
-    'adai',
-    'adapnm',
-    'pnm',
-    'dadaptadam',
-    'dadaptsgd',
-    'adams',
-    'adafactor',
-]
-VALID_OPTIMIZER_NAMES: List[str] = [
-    'adamp',
-    'adan',
-    'sgdp',
-    'madgrad',
-    'ranger',
-    'ranger21',
-    'radam',
-    'adabound',
-    'adabelief',
-    'diffgrad',
-    'diffrgrad',
-    'lamb',
-    'ralamb',
-    'lars',
-    'shampoo',
-    'scalableshampoo',
-    'pnm',
-    'adapnm',
-    'nero',
-    'adai',
-    'dadaptadagrad',
-    'dadaptadam',
-    'dadaptsgd',
-    'adams',
-    'adafactor',
-]
+VALID_OPTIMIZER_NAMES: List[str] = list(OPTIMIZERS.keys())
 INVALID_OPTIMIZER_NAMES: List[str] = [
     'asam',
     'sam',
@@ -97,6 +43,12 @@
     'adamd',
     'lookahead',
 ]
+
+SPARSE_OPTIMIZERS: List[str] = ['madgrad', 'dadaptadagrad']
+NO_SPARSE_OPTIMIZERS: List[str] = [
+    optimizer for optimizer in VALID_OPTIMIZER_NAMES if optimizer not in SPARSE_OPTIMIZERS
+]
+
 BETA_OPTIMIZER_NAMES: List[str] = [
     'adabelief',
     'adabound',
@@ -126,9 +78,7 @@
     'CyclicLR',
     'OneCycleLR',
 ]
-INVALID_LR_SCHEDULER_NAMES: List[str] = [
-    'dummy',
-]
+INVALID_LR_SCHEDULER_NAMES: List[str] = ['dummy']

 OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
@@ -205,6 +155,9 @@
     (AdamS, {'lr': 1.0, 'weight_decay': 1e-3}, 30),
     (AdamS, {'lr': 1.0, 'weight_decay': 1e-3, 'amsgrad': True}, 30),
     (AdaFactor, {'lr': 5e-1, 'weight_decay': 1e-2, 'scale_parameter': False}, 100),
+    (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
+    (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'rebound': 'belief'}, 10),
+    (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decay_type': 'stable', 'warmup_steps': 0}, 50),
 ]
 ADAMD_SUPPORTED_OPTIMIZERS: List[Tuple[Any, Dict[str, Union[float, bool, int]], int]] = [
     (build_lookahead, {'lr': 5e-1, 'weight_decay': 1e-3, 'adamd_debias_term': True}, 10),
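
Each tuple in the OPTIMIZERS constant above pairs an optimizer class with a constructor config and an iteration budget. A hedged sketch of how one of the new Apollo tuples might drive a smoke test; the toy objective and the loop are illustrative, not the repository's actual test code:

import torch

from pytorch_optimizer import Apollo

# one of the new entries: (Apollo, {'lr': 5e-1, 'weight_decay': 1e-3}, 10)
config = {'lr': 5e-1, 'weight_decay': 1e-3}
iterations = 10

# toy convex objective: drive a two-element weight vector toward zero
weight = torch.nn.Parameter(torch.ones(2))
optimizer = Apollo([weight], **config)

for _ in range(iterations):
    optimizer.zero_grad()
    loss = (weight ** 2).sum()
    loss.backward()
    optimizer.step()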

tests/test_gradients.py

Lines changed: 9 additions & 1 deletion
@@ -59,7 +59,15 @@ def test_sparse_supported(sparse_optimizer):
     optimizer.zero_grad()
     optimizer.step()

-    optimizer = opt([param], momentum=0.0, weight_decay=1e-3)
+    if sparse_optimizer == 'madgrad':
+        optimizer = opt([param], momentum=0.0, weight_decay=1e-3, decouple_decay=False)
+        optimizer.reset()
+        optimizer.zero_grad()
+
+        with pytest.raises(NoSparseGradientError):
+            optimizer.step()
+
+    optimizer = opt([param], momentum=0.9, weight_decay=1e-3)
     optimizer.reset()
     optimizer.zero_grad()

tests/test_load_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -16,4 +16,4 @@ def test_load_optimizers_invalid(invalid_optimizer_names):


 def test_get_supported_optimizers():
-    assert len(get_supported_optimizers()) == 25
+    assert len(get_supported_optimizers()) == 26
