
Commit 29a9dd3

Merge pull request #19 from kozistr/feature/adahessian-optimizer
[Feature] Implement AdaHessian optimizer
2 parents 3c4ee7f + f421254 commit 29a9dd3

10 files changed: +442 −23 lines changed


README.md

Lines changed: 34 additions & 0 deletions
@@ -12,10 +12,28 @@ Bunch of optimizer implementations in PyTorch with clean-code, strict types. Hig
 $ pip3 install pytorch-optimizer
 ```
 
+### Simple Usage
+
+```
+from pytorch_optimizer import Ranger21
+
+...
+model = YourModel()
+optimizer = Ranger21(model.parameters())
+...
+
+for input, output in data:
+  optimizer.zero_grad()
+  loss = loss_function(output, model(input))
+  loss.backward()
+  optimizer.step()
+```
+
 ## Supported Optimizers
 
 | Optimizer | Description | Official Code | Paper |
 | :---: | :---: | :---: | :---: |
+| AdaHessian | *An Adaptive Second Order Optimizer for Machine Learning* | [github](https://github.com/amirgholami/adahessian) | [https://arxiv.org/abs/2006.00719](https://arxiv.org/abs/2006.00719) |
 | AdamP | *Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights* | [github](https://github.com/clovaai/AdamP) | [https://arxiv.org/abs/2006.08217](https://arxiv.org/abs/2006.08217) |
 | MADGRAD | *A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic* | [github](https://github.com/facebookresearch/madgrad) | [https://arxiv.org/abs/2101.11075](https://arxiv.org/abs/2101.11075) |
 | RAdam | *On the Variance of the Adaptive Learning Rate and Beyond* | [github](https://github.com/LiyuanLucasLiu/RAdam) | [https://arxiv.org/abs/1908.03265](https://arxiv.org/abs/1908.03265) |

@@ -303,6 +321,22 @@ Acceleration via Fractal Learning Rate Schedules
 
 </details>
 
+<details>
+
+<summary>AdaHessian</summary>
+
+```
+@article{yao2020adahessian,
+  title={ADAHESSIAN: An adaptive second order optimizer for machine learning},
+  author={Yao, Zhewei and Gholami, Amir and Shen, Sheng and Mustafa, Mustafa and Keutzer, Kurt and Mahoney, Michael W},
+  journal={arXiv preprint arXiv:2006.00719},
+  year={2020}
+}
+```
+
+</details>
+
+
 ## Author
 
 Hyeongchan Kim / [@kozistr](http://kozistr.tech/about)
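
Note: the generic usage shown in the README's Simple Usage block calls plain `loss.backward()`. The AdaHessian optimizer added in this commit additionally needs the backward pass to build the graph for second-order gradients (see the class docstring in `pytorch_optimizer/adahessian.py` below). A minimal sketch, reusing the placeholder `YourModel`, `data`, and `loss_function` names from the README example:

```
from pytorch_optimizer import AdaHessian

model = YourModel()                        # placeholder model from the README example
optimizer = AdaHessian(model.parameters())

for input, output in data:                 # placeholder data iterable from the README example
    optimizer.zero_grad()
    loss = loss_function(output, model(input))
    loss.backward(create_graph=True)       # AdaHessian needs grads of grads (Hessian-vector products)
    optimizer.step()
```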

pytorch_optimizer/__init__.py

Lines changed: 13 additions & 1 deletion
@@ -1 +1,13 @@
-__VERSION__ = '0.0.2'
+from pytorch_optimizer.adahessian import AdaHessian
+from pytorch_optimizer.adamp import AdamP
+from pytorch_optimizer.agc import agc
+from pytorch_optimizer.chebyshev_schedule import get_chebyshev_schedule
+from pytorch_optimizer.gc import centralize_gradient
+from pytorch_optimizer.lookahead import Lookahead
+from pytorch_optimizer.madgrad import MADGRAD
+from pytorch_optimizer.radam import RAdam
+from pytorch_optimizer.ranger import Ranger
+from pytorch_optimizer.ranger21 import Ranger21
+from pytorch_optimizer.sgdp import SGDP
+
+__VERSION__ = '0.0.3'
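
With the re-exports above, the new optimizer is reachable from the package root rather than its submodule. A minimal sketch (the `torch.nn.Linear` model here is just a stand-in to have parameters to optimize):

```
import torch

from pytorch_optimizer import AdaHessian, __VERSION__

print(__VERSION__)  # '0.0.3' after this commit

model = torch.nn.Linear(8, 1)  # stand-in model
optimizer = AdaHessian(model.parameters(), lr=1e-3)
```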

pytorch_optimizer/adahessian.py

Lines changed: 219 additions & 0 deletions
@@ -0,0 +1,219 @@
import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import (
    BETAS,
    CLOSURE,
    DEFAULT_PARAMETERS,
    LOSS,
    PARAMS,
)


class AdaHessian(Optimizer):
    """
    Reference : https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
    Example :
        from pytorch_optimizer import AdaHessian
        ...
        model = YourModel()
        optimizer = AdaHessian(model.parameters())
        ...
        for input, output in data:
            optimizer.zero_grad()
            loss = loss_function(output, model(input))
            loss.backward(create_graph=True)  # this is the important line!
            optimizer.step()
    """

    def __init__(
        self,
        params: PARAMS,
        lr: float = 1e-3,
        betas: BETAS = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        hessian_power: float = 1.0,
        update_each: int = 1,
        n_samples: int = 1,
        average_conv_kernel: bool = False,
        seed: int = 2147483647,
    ):
        """
        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
        :param lr: float. learning rate.
        :param betas: BETAS. coefficients used for computing running averages of gradient and the squared hessian trace
        :param eps: float. term added to the denominator to improve numerical stability
        :param weight_decay: float. weight decay (L2 penalty)
        :param hessian_power: float. exponent of the hessian trace
        :param update_each: int. compute the hessian trace approximation only after *this* number of steps
        :param n_samples: int. how many times to sample `z` for the approximation of the hessian trace
        :param average_conv_kernel: bool. average out the hessian traces of convolutional kernels as in the paper.
        :param seed: int.
        """
        self.lr = lr
        self.eps = eps
        self.betas = betas
        self.weight_decay = weight_decay
        self.hessian_power = hessian_power
        self.n_samples = n_samples
        self.update_each = update_each
        self.average_conv_kernel = average_conv_kernel
        self.seed = seed

        self.check_valid_parameters()

        # use a separate generator that deterministically generates the same `z`s across all GPUs
        # in case of distributed training
        self.generator: torch.Generator = torch.Generator().manual_seed(
            self.seed
        )

        defaults: DEFAULT_PARAMETERS = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            hessian_power=hessian_power,
        )
        super().__init__(params, defaults)

        for p in self.get_params():
            p.hess = 0.0
            self.state[p]['hessian_step'] = 0

    def check_valid_parameters(self):
        if 0.0 > self.lr:
            raise ValueError(f'Invalid learning rate : {self.lr}')
        if 0.0 > self.eps:
            raise ValueError(f'Invalid eps : {self.eps}')
        if 0.0 > self.weight_decay:
            raise ValueError(f'Invalid weight_decay : {self.weight_decay}')
        if not 0.0 <= self.betas[0] < 1.0:
            raise ValueError(f'Invalid beta_0 : {self.betas[0]}')
        if not 0.0 <= self.betas[1] < 1.0:
            raise ValueError(f'Invalid beta_1 : {self.betas[1]}')
        if not 0.0 <= self.hessian_power <= 1.0:
            raise ValueError(f'Invalid hessian_power : {self.hessian_power}')

    def get_params(self):
        """Gets all parameters in all param_groups with gradients"""
        return (
            p
            for group in self.param_groups
            for p in group['params']
            if p.requires_grad
        )

    def zero_hessian(self):
        """Zeros out the accumulated hessian traces."""
        for p in self.get_params():
            if (
                not isinstance(p.hess, float)
                and self.state[p]['hessian_step'] % self.update_each == 0
            ):
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter"""
        params = []
        for p in filter(
            lambda param: param.grad is not None, self.get_params()
        ):
            # compute the trace only each `update_each` step
            if self.state[p]['hessian_step'] % self.update_each == 0:
                params.append(p)
            self.state[p]['hessian_step'] += 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:
            # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(
                self.seed
            )

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            # Rademacher distribution {-1.0, 1.0}
            zs = [
                torch.randint(
                    0, 2, p.size(), generator=self.generator, device=p.device
                )
                * 2.0
                - 1.0
                for p in params
            ]
            h_zs = torch.autograd.grad(
                grads,
                params,
                grad_outputs=zs,
                only_inputs=True,
                retain_graph=i < self.n_samples - 1,
            )
            for h_z, z, p in zip(h_zs, zs, params):
                # approximate the expected values of z * (H@z)
                p.hess += h_z * z / self.n_samples

    @torch.no_grad()
    def step(self, closure: CLOSURE = None) -> LOSS:
        loss: LOSS = None
        if closure is not None:
            loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.hess is None:
                    continue

                if self.average_conv_kernel and p.dim() == 4:
                    p.hess = (
                        torch.abs(p.hess)
                        .mean(dim=[2, 3], keepdim=True)
                        .expand_as(p.hess)
                        .clone()
                    )

                # Perform correct step-weight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                if len(state) == 1:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_hessian_diag_sq = (
                    state['exp_avg'],
                    state['exp_hessian_diag_sq'],
                )
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(
                    p.hess, p.hess, value=1 - beta2
                )

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                k = group['hessian_power']
                denom = (
                    (exp_hessian_diag_sq / bias_correction2)
                    .pow_(k / 2)
                    .add_(group['eps'])
                )

                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
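
For reference, `step` above implements the following update, where g_t is the stochastic gradient, D_t is the Hutchinson diagonal estimate accumulated in `set_hessian` (z drawn from a Rademacher distribution, H_t the Hessian), k is `hessian_power`, gamma the learning rate, and lambda the weight decay:

```latex
\begin{aligned}
D_t &\approx \mathbb{E}_z\!\left[ z \odot (H_t z) \right], \quad z_i \in \{-1, +1\} \\
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, D_t^2 \\
\theta_t &\leftarrow \theta_{t-1}\,(1 - \gamma \lambda) \\
\theta_t &\leftarrow \theta_t - \frac{\gamma}{1 - \beta_1^t} \cdot
  \frac{m_t}{\left( v_t / (1 - \beta_2^t) \right)^{k/2} + \epsilon}
\end{aligned}
```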

pytorch_optimizer/adamp.py

Lines changed: 35 additions & 2 deletions
@@ -5,13 +5,34 @@
 import torch.nn.functional as F
 from torch.optim.optimizer import Optimizer
 
-from pytorch_optimizer.types import BETAS, CLOSURE, DEFAULT_PARAMETERS, LOSS
+from pytorch_optimizer.types import (
+    BETAS,
+    CLOSURE,
+    DEFAULT_PARAMETERS,
+    LOSS,
+    PARAMS,
+)
 
 
 class AdamP(Optimizer):
+    """
+    Reference : https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
+    Example :
+        from pytorch_optimizer import AdamP
+        ...
+        model = YourModel()
+        optimizer = AdamP(model.parameters())
+        ...
+        for input, output in data:
+            optimizer.zero_grad()
+            loss = loss_function(output, model(input))
+            loss.backward()
+            optimizer.step()
+    """
+
     def __init__(
         self,
-        params,
+        params: PARAMS,
         lr: float = 1e-3,
         betas: BETAS = (0.9, 0.999),
         eps: float = 1e-8,
@@ -20,6 +41,18 @@ def __init__(
         wd_ratio: float = 0.1,
         nesterov: bool = False,
     ):
+        """
+        :param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
+        :param lr: float. learning rate.
+        :param betas: BETAS. coefficients used for computing running averages of the gradient and its square
+        :param eps: float. term added to the denominator to improve numerical stability
+        :param weight_decay: float. weight decay (L2 penalty)
+        :param delta: float. threshold that determines whether a set of parameters is scale invariant or not
+        :param wd_ratio: float. relative weight decay applied on scale-invariant parameters compared to that applied
+            on scale-variant parameters
+        :param nesterov: bool. enables Nesterov momentum
+        """
+
         defaults: DEFAULT_PARAMETERS = dict(
             lr=lr,
             betas=betas,
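
The `delta` and `wd_ratio` parameters documented above control AdamP's handling of scale-invariant weights. As a rough summary of the cited paper (not a transcription of this repository's code): a parameter w is treated as scale invariant when the cosine similarity between w and its gradient g falls below delta / sqrt(dim(w)); for such parameters the radial component of the Adam-style update u is projected out, and their weight decay is scaled by `wd_ratio`:

```latex
\Pi_w(u) = u - \frac{\langle w, u \rangle}{\lVert w \rVert^2}\, w
\qquad \text{applied when} \qquad
\cos\angle(w, g_t) < \frac{\delta}{\sqrt{\dim(w)}}
```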

pytorch_optimizer/agc.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 
 def agc(p, agc_eps: float, agc_clip_val: float, eps: float = 1e-6):
     """clip gradient values in excess of the unit-wise norm.
+    :param p: parameter.
     :param agc_eps: float.
     :param agc_clip_val: float.
     :param eps: float. simple stop from div by zero and no relation to standard optimizer eps
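
For context, `agc` implements unit-wise adaptive gradient clipping: a unit's gradient is rescaled whenever its norm exceeds `agc_clip_val` times the (floored) norm of the corresponding weights. A self-contained sketch of the technique under those assumptions — the helper names and default values here are illustrative, not this repository's exact implementation:

```
import torch


def unit_norm(x: torch.Tensor) -> torch.Tensor:
    """L2 norm over all dims except the first (unit-wise), kept broadcastable."""
    if x.ndim <= 1:
        return x.abs()  # biases / scalars: element-wise magnitude
    return x.norm(p=2, dim=tuple(range(1, x.ndim)), keepdim=True)


def agc_sketch(p: torch.Tensor, agc_eps: float = 1e-3, agc_clip_val: float = 1e-2, eps: float = 1e-6) -> None:
    """Rescale p.grad in place wherever its unit-wise norm exceeds
    agc_clip_val * max(unit-wise weight norm, agc_eps)."""
    if p.grad is None:
        return
    w_norm = unit_norm(p.detach()).clamp_(min=agc_eps)  # floor tiny weight norms at agc_eps
    g_norm = unit_norm(p.grad.detach())
    max_norm = w_norm * agc_clip_val
    # scale < 1 only where the gradient is too large; eps keeps the division safe
    scale = (max_norm / g_norm.clamp(min=eps)).clamp(max=1.0)
    p.grad.detach().mul_(scale)
```

Called once per parameter just before `optimizer.step()`, a routine like this leaves well-behaved gradients untouched and shrinks only the units whose gradient-to-weight ratio is too large.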
