Commit fa5a103

Merge pull request #17 from kozistr/feature/madgrad-optimizer
[Feature] MADGRAD optimizer
2 parents 6aa80a6 + 2e41dfa commit fa5a103

File tree: 9 files changed (+268 / -43 lines changed)

README.md

Lines changed: 16 additions & 0 deletions
@@ -17,6 +17,7 @@ $ pip3 install pytorch-optimizer
 | Optimizer | Description | Official Code | Paper |
 | :---: | :---: | :---: | :---: |
 | AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | [https://arxiv.org/abs/2006.08217](https://arxiv.org/abs/2006.08217) |
+| MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization* | [github](https://github.com/facebookresearch/madgrad) | [https://arxiv.org/abs/2101.11075](https://arxiv.org/abs/2101.11075) |
 | RAdam | *On the Variance of the Adaptive Learning Rate and Beyond* | [github](https://github.com/LiyuanLucasLiu/RAdam) | [https://arxiv.org/abs/1908.03265](https://arxiv.org/abs/1908.03265) |
 | Ranger | *a synergistic optimizer combining RAdam and LookAhead, and now GC in one optimizer* | [github](https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer) | |
 | Ranger21 | *a synergistic deep learning optimizer* | [github](https://github.com/lessw2020/Ranger21) | [https://arxiv.org/abs/2106.13731](https://arxiv.org/abs/2106.13731) |
@@ -287,6 +288,21 @@ Acceleration via Fractal Learning Rate Schedules
 
 </details>
 
+<details>
+
+<summary>MADGRAD</summary>
+
+```
+@article{defazio2021adaptivity,
+    title={Adaptivity without compromise: a momentumized, adaptive, dual averaged gradient method for stochastic optimization},
+    author={Defazio, Aaron and Jelassi, Samy},
+    journal={arXiv preprint arXiv:2101.11075},
+    year={2021}
+}
+```
+
 ## Author
 
 Hyeongchan Kim / [@kozistr](http://kozistr.tech/about)

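For quick reference, here is a minimal usage sketch of the optimizer this PR adds. It is not part of the diff: the import path follows the new module pytorch_optimizer/madgrad.py, the constructor arguments mirror the signature shown further below, and the model and data are placeholder toys.

```
import torch
import torch.nn.functional as F

from pytorch_optimizer.madgrad import MADGRAD  # module introduced in this commit

# placeholder model and data, purely illustrative
model = torch.nn.Linear(10, 1)
optimizer = MADGRAD(
    model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0, eps=1e-6
)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
```
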
pytorch_optimizer/adamp.py

Lines changed: 11 additions & 7 deletions
@@ -1,24 +1,26 @@
 import math
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Callable, List, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULT_PARAMETERS, LOSS
+
 
 class AdamP(Optimizer):
     def __init__(
         self,
         params,
         lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: BETAS = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         delta: float = 0.1,
         wd_ratio: float = 0.1,
         nesterov: bool = False,
     ):
-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             betas=betas,
             eps=eps,
@@ -39,7 +41,10 @@ def layer_view(x: torch.Tensor) -> torch.Tensor:
 
     @staticmethod
     def cosine_similarity(
-        x: torch.Tensor, y: torch.Tensor, eps: float, view_func: Callable
+        x: torch.Tensor,
+        y: torch.Tensor,
+        eps: float,
+        view_func: Callable[[torch.Tensor], torch.Tensor],
     ):
         x = view_func(x)
         y = view_func(y)
@@ -74,8 +79,8 @@ def projection(
 
         return perturb, wd
 
-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: Optional[float] = None
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
         if closure is not None:
             loss = closure()
 
@@ -114,7 +119,6 @@ def step(self, closure: Optional[Callable] = None) -> float:
                 else:
                     perturb = exp_avg / denom
 
-                # Projection
                 wd_ratio: float = 1
                 if len(p.shape) > 1:
                     perturb, wd_ratio = self.projection(

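The refactors in this file (and in lookahead.py, madgrad.py, and radam.py below) lean on new aliases from pytorch_optimizer.types, which is among the nine changed files but is not shown in this excerpt. A plausible sketch of those aliases, inferred purely from how they are used in the hunks here, might look like the following; treat every definition as an assumption rather than the committed code.

```
# Hypothetical reconstruction of pytorch_optimizer/types.py, inferred from usage only.
from typing import Any, Callable, Dict, List, Optional, Tuple

CLOSURE = Optional[Callable[[], float]]  # the `closure` argument accepted by step()
LOSS = Optional[float]                   # the value step() returns
BETAS = Tuple[float, float]              # (beta1, beta2), replacing Tuple[float, float]
DEFAULT_PARAMETERS = Dict[str, Any]      # `defaults` passed to Optimizer.__init__
PARAM_GROUP = Dict                       # a single parameter group
PARAM_GROUPS = List[PARAM_GROUP]         # optimizer.param_groups
STATE = Dict                             # optimizer state / state_dict containers
```
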
pytorch_optimizer/lookahead.py

Lines changed: 21 additions & 13 deletions
@@ -1,19 +1,27 @@
 from collections import defaultdict
-from typing import Callable, Dict, List, Optional
+from typing import Dict
 
 import torch
 from torch.optim import Optimizer
 
+from pytorch_optimizer.types import (
+    CLOSURE,
+    LOSS,
+    PARAM_GROUP,
+    PARAM_GROUPS,
+    STATE,
+)
+
 
 class Lookahead(Optimizer):
     def __init__(self, optimizer: Optimizer, k: int = 5, alpha: float = 0.5):
         self.optimizer = optimizer
         self.k = k
         self.alpha = alpha
 
-        self.param_groups: List[Dict] = self.optimizer.param_groups
-        self.fast_state: Dict = self.optimizer.state
-        self.state = defaultdict(dict)
+        self.param_groups: PARAM_GROUPS = self.optimizer.param_groups
+        self.fast_state: STATE = self.optimizer.state
+        self.state: STATE = defaultdict(dict)
 
         for group in self.param_groups:
             group['counter'] = 0
@@ -32,8 +40,8 @@ def update_lookahead(self):
         for group in self.param_groups:
             self.update(group)
 
-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: float = self.optimizer.step(closure)
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = self.optimizer.step(closure)
         for group in self.param_groups:
             if group['counter'] == 0:
                 self.update(group)
@@ -42,12 +50,12 @@ def step(self, closure: Optional[Callable] = None) -> float:
                 group['counter'] = 0
         return loss
 
-    def state_dict(self) -> Dict[str, torch.Tensor]:
-        fast_state_dict = self.optimizer.state_dict()
+    def state_dict(self) -> STATE:
+        fast_state_dict: STATE = self.optimizer.state_dict()
         fast_state = fast_state_dict['state']
         param_groups = fast_state_dict['param_groups']
 
-        slow_state: Dict[int, torch.Tensor] = {
+        slow_state: STATE = {
             (id(k) if isinstance(k, torch.Tensor) else k): v
             for k, v in self.state.items()
         }
@@ -58,12 +66,12 @@ def state_dict(self) -> Dict[str, torch.Tensor]:
             'param_groups': param_groups,
         }
 
-    def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
-        slow_state_dict: Dict[str, torch.Tensor] = {
+    def load_state_dict(self, state_dict: STATE):
+        slow_state_dict: STATE = {
             'state': state_dict['slow_state'],
             'param_groups': state_dict['param_groups'],
         }
-        fast_state_dict: Dict[str, torch.Tensor] = {
+        fast_state_dict: STATE = {
             'state': state_dict['fast_state'],
             'param_groups': state_dict['param_groups'],
         }
@@ -72,6 +80,6 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
         self.optimizer.load_state_dict(fast_state_dict)
         self.fast_state = self.optimizer.state
 
-    def add_param_group(self, param_group: Dict):
+    def add_param_group(self, param_group: PARAM_GROUP):
         param_group['counter'] = 0
         self.optimizer.add_param_group(param_group)

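As a usage note (not part of the diff): Lookahead wraps an inner optimizer, keeps a slow copy of the weights, and synchronizes every k fast steps with interpolation factor alpha. A minimal sketch, using AdamP from this package as the inner optimizer and placeholder data:

```
import torch
import torch.nn.functional as F

from pytorch_optimizer.adamp import AdamP
from pytorch_optimizer.lookahead import Lookahead

model = torch.nn.Linear(10, 2)

base_optimizer = AdamP(model.parameters(), lr=1e-3)
optimizer = Lookahead(base_optimizer, k=5, alpha=0.5)  # sync slow weights every 5 steps

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
base_optimizer.zero_grad()
loss = F.cross_entropy(model(x), y)
loss.backward()
optimizer.step()  # typed as step(closure: CLOSURE = None) -> LOSS after this diff
```
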
pytorch_optimizer/madgrad.py

Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+import math
+
+import torch
+from torch.optim import Optimizer
+
+from pytorch_optimizer.types import CLOSURE, DEFAULT_PARAMETERS, LOSS
+
+
+class MADGRAD(Optimizer):
+    """
+    A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
+    Reference : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
+    """
+
+    def __init__(
+        self,
+        params,
+        lr: float = 1e-3,
+        momentum: float = 0.9,
+        weight_decay: float = 0.0,
+        eps: float = 1e-6,
+    ):
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.eps = eps
+
+        self.check_valid_parameters()
+
+        defaults: DEFAULT_PARAMETERS = dict(
+            lr=lr, eps=eps, momentum=momentum, weight_decay=weight_decay
+        )
+        super().__init__(params, defaults)
+
+    def check_valid_parameters(self):
+        if 0.0 > self.lr:
+            raise ValueError(f'Invalid learning rate : {self.lr}')
+        if 0.0 > self.eps:
+            raise ValueError(f'Invalid eps : {self.eps}')
+        if 0.0 > self.weight_decay:
+            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
+        if 0.0 > self.momentum or 1.0 <= self.momentum:
+            raise ValueError(f'Invalid momentum : {self.momentum}')
+
+    @property
+    def supports_memory_efficient_fp16(self) -> bool:
+        return False
+
+    @property
+    def supports_flat_params(self) -> bool:
+        return True
+
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss: LOSS = None
+        if closure is not None:
+            loss = closure()
+
+        # step counter must be stored in state to ensure correct behavior under
+        # optimizer sharding
+        if 'k' not in self.state:
+            self.state['k'] = torch.tensor([0], dtype=torch.long)
+
+        k = self.state['k'].item()
+
+        for group in self.param_groups:
+            eps = group['eps']
+            lr = group['lr'] + eps
+            decay = group['weight_decay']
+            momentum = group['momentum']
+
+            ck: float = 1.0 - momentum
+            _lambda = lr * math.pow(k + 1, 0.5)
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad.data
+                state = self.state[p]
+
+                if 'grad_sum_sq' not in state:
+                    state['grad_sum_sq'] = torch.zeros_like(p.data).detach()
+                    state['s'] = torch.zeros_like(p.data).detach()
+                    if momentum != 0:
+                        state['x0'] = torch.clone(p.data).detach()
+
+                if momentum != 0.0 and grad.is_sparse:
+                    raise RuntimeError(
+                        'momentum != 0 is not compatible with sparse gradients'
+                    )
+
+                grad_sum_sq = state['grad_sum_sq']
+                s = state['s']
+
+                if decay != 0:
+                    if grad.is_sparse:
+                        raise RuntimeError(
+                            'weight_decay option is not compatible with sparse gradients'
+                        )
+
+                    grad.add_(p.data, alpha=decay)
+
+                if grad.is_sparse:
+                    grad = grad.coalesce()
+                    grad_val = grad._values()
+
+                    p_masked = p.sparse_mask(grad)
+                    grad_sum_sq_masked = grad_sum_sq.sparse_mask(grad)
+                    s_masked = s.sparse_mask(grad)
+
+                    # Compute x_0 from other known quantities
+                    rms_masked_vals = (
+                        grad_sum_sq_masked._values().pow(1 / 3).add_(eps)
+                    )
+                    x0_masked_vals = p_masked._values().addcdiv(
+                        s_masked._values(), rms_masked_vals, value=1
+                    )
+
+                    # Dense + sparse op
+                    grad_sq = grad * grad
+                    grad_sum_sq.add_(grad_sq, alpha=_lambda)
+                    grad_sum_sq_masked.add_(grad_sq, alpha=_lambda)
+
+                    rms_masked_vals = (
+                        grad_sum_sq_masked._values().pow_(1 / 3).add_(eps)
+                    )
+
+                    s.add_(grad, alpha=_lambda)
+                    s_masked._values().add_(grad_val, alpha=_lambda)
+
+                    # update masked copy of p
+                    p_kp1_masked_values = x0_masked_vals.addcdiv(
+                        s_masked._values(), rms_masked_vals, value=-1
+                    )
+
+                    # Copy updated masked p to dense p using an add operation
+                    p_masked._values().add_(p_kp1_masked_values, alpha=-1)
+                    p.data.add_(p_masked, alpha=-1)
+                else:
+                    if momentum == 0:
+                        # Compute x_0 from other known quantities
+                        rms = grad_sum_sq.pow(1 / 3).add_(eps)
+                        x0 = p.data.addcdiv(s, rms, value=1)
+                    else:
+                        x0 = state['x0']
+
+                    # Accumulate second moments
+                    grad_sum_sq.addcmul_(grad, grad, value=_lambda)
+                    rms = grad_sum_sq.pow(1 / 3).add_(eps)
+
+                    # Update s
+                    s.data.add_(grad, alpha=_lambda)
+
+                    # Step
+                    if momentum == 0:
+                        p.data.copy_(x0.addcdiv(s, rms, value=-1))
+                    else:
+                        z = x0.addcdiv(s, rms, value=-1)
+
+                        # p is a moving average of z
+                        p.data.mul_(1 - ck).add_(z, alpha=ck)
+
+        self.state['k'] += 1
+
+        return loss

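For orientation, the dense (non-sparse) branch of step() above implements the following dual-averaged update, writing grad_sum_sq as \nu, ck as c, and the global counter as k. This is only a restatement of the code shown, in my own notation.

```
\begin{aligned}
\lambda_k &= \mathrm{lr}\,\sqrt{k+1}                                    && \text{(\_lambda)} \\
\nu_{k+1} &= \nu_k + \lambda_k\, g_k \odot g_k                          && \text{(grad\_sum\_sq.addcmul\_)} \\
s_{k+1}   &= s_k + \lambda_k\, g_k                                      && \text{(s.add\_)} \\
z_{k+1}   &= x_0 - \frac{s_{k+1}}{\nu_{k+1}^{1/3} + \varepsilon}        && \text{(x0.addcdiv, value=-1)} \\
x_{k+1}   &= (1 - c)\,x_k + c\,z_{k+1}, \qquad c = 1 - \text{momentum}  && \text{(p.data.mul\_(1 - ck).add\_(z, alpha=ck))}
\end{aligned}
```

When momentum is 0, x_0 is instead recomputed from the current parameters each step as x_0 = x_k + s_k / (\nu_k^{1/3} + \varepsilon), and the parameters are set directly to z_{k+1}.
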
pytorch_optimizer/radam.py

Lines changed: 7 additions & 5 deletions
@@ -1,9 +1,11 @@
 import math
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Dict
 
 import torch
 from torch.optim.optimizer import Optimizer
 
+from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULT_PARAMETERS, LOSS
+
 
 class RAdam(Optimizer):
     """
@@ -15,7 +17,7 @@ def __init__(
         self,
         params,
         lr: float = 1e-3,
-        betas: Tuple[float, float] = (0.9, 0.999),
+        betas: BETAS = (0.9, 0.999),
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         n_sma_threshold: int = 5,
@@ -42,7 +44,7 @@ def __init__(
         ):
             param['buffer'] = [[None, None, None] for _ in range(10)]
 
-        defaults: Dict[str, Any] = dict(
+        defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             betas=betas,
             eps=eps,
@@ -67,8 +69,8 @@ def check_valid_parameters(self):
     def __setstate__(self, state: Dict):
         super().__setstate__(state)
 
-    def step(self, closure: Optional[Callable] = None) -> float:
-        loss: Optional[float] = None
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
         if closure is not None:
             loss = closure()

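Since every step() in this PR now accepts closure: CLOSURE, a closure-style call works uniformly across the optimizers. A small sketch with the refactored RAdam follows; the model and data are placeholders, and it assumes RAdam accepts a plain parameter iterator with default arguments, like the other optimizers here.

```
import torch
import torch.nn.functional as F

from pytorch_optimizer.radam import RAdam

model = torch.nn.Linear(4, 1)
optimizer = RAdam(model.parameters(), lr=1e-3)
x, y = torch.randn(16, 4), torch.randn(16, 1)

def closure():
    # re-evaluates the model and returns the loss, matching the CLOSURE contract
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)
```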