
Commit 54e68a4

Merge pull request #59 from kozistr/feature/optimizers
[Feature] Remove AdaHessian optimizer
2 parents: a4b594b + 356dce0

File tree

13 files changed (+194 lines, −222 lines)


README.rst

Lines changed: 13 additions & 0 deletions
@@ -76,6 +76,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Shampoo | *Preconditioned Stochastic Tensor Optimization* | `github <https://github.com/moskomule/shampoo.pytorch>`__ | `https://arxiv.org/abs/1802.09568 <https://arxiv.org/abs/1802.09568>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Nero | *Learning by Turning: Neural Architecture Aware Optimisation* | `github <https://github.com/jxbz/nero>`__ | `https://arxiv.org/abs/2102.07227 <https://arxiv.org/abs/2102.07227>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+

 Useful Resources
 ----------------
@@ -482,6 +484,17 @@ Shampoo: Preconditioned Stochastic Tensor Optimization
         organization={PMLR}
     }

+Nero: Learning by Turning: Neural Architecture Aware Optimisation
+
+::
+
+    @misc{nero2021,
+        title={Learning by Turning: Neural Architecture Aware Optimisation},
+        author={Yang Liu and Jeremy Bernstein and Markus Meister and Yisong Yue},
+        year={2021},
+        eprint={arXiv:2102.07227}
+    }
+
 Author
 ------

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "1.0.0"
+version = "1.1.0"
 description = "Bunch of optimizer implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,6 @@
 # pylint: disable=unused-import
 from pytorch_optimizer.adabelief import AdaBelief
 from pytorch_optimizer.adabound import AdaBound
-from pytorch_optimizer.adahessian import AdaHessian
 from pytorch_optimizer.adamp import AdamP
 from pytorch_optimizer.adapnm import AdaPNM
 from pytorch_optimizer.agc import agc
@@ -14,6 +13,7 @@
 from pytorch_optimizer.lars import LARS
 from pytorch_optimizer.lookahead import Lookahead
 from pytorch_optimizer.madgrad import MADGRAD
+from pytorch_optimizer.nero import Nero
 from pytorch_optimizer.optimizers import load_optimizers
 from pytorch_optimizer.pcgrad import PCGrad
 from pytorch_optimizer.pnm import PNM
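
With this export change, Nero is importable from the package root while the old AdaHessian import path is gone. A minimal sketch of the effect on user code, assuming only the imports shown in this diff (the torch.nn.Linear model is a placeholder for illustration, not part of the commit):

    import torch

    from pytorch_optimizer import Nero  # newly re-exported in this commit
    # from pytorch_optimizer import AdaHessian  # removed; this import now fails

    model = torch.nn.Linear(10, 2)  # placeholder model for illustration
    optimizer = Nero(model.parameters(), lr=0.01)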

pytorch_optimizer/adahessian.py

Lines changed: 0 additions & 198 deletions
This file was deleted.

pytorch_optimizer/base_optimizer.py

Lines changed: 5 additions & 5 deletions
@@ -11,6 +11,11 @@ def validate_learning_rate(learning_rate: float):
         if learning_rate < 0.0:
             raise ValueError(f'[-] learning rate {learning_rate} must be positive')

+    @staticmethod
+    def validate_beta(beta: float):
+        if not 0.0 <= beta <= 1.0:
+            raise ValueError(f'[-] beta {beta} must be in the range [0, 1]')
+
     @staticmethod
     def validate_beta0(beta_0: float):
         if not 0.0 <= beta_0 <= 1.0:
@@ -39,11 +44,6 @@ def validate_weight_decay_ratio(weight_decay_ratio: float):
         if not 0.0 <= weight_decay_ratio < 1.0:
             raise ValueError(f'[-] weight_decay_ratio {weight_decay_ratio} must be in the range [0, 1)')

-    @staticmethod
-    def validate_hessian_power(hessian_power: float):
-        if not 0.0 <= hessian_power <= 1.0:
-            raise ValueError(f'[-] hessian_power {hessian_power} must be in the range [0, 1]')
-
     @staticmethod
     def validate_trust_coefficient(trust_coefficient: float):
         if trust_coefficient < 0.0:
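
The new validate_beta check replaces the deleted validate_hessian_power and is wired into Nero.validate_parameters() in the file below. A small sketch of how the check surfaces to callers, assuming that wiring (the toy model is illustrative only):

    import torch

    from pytorch_optimizer import Nero

    model = torch.nn.Linear(4, 2)  # placeholder model for illustration

    try:
        # beta outside [0, 1] is rejected by BaseOptimizer.validate_beta()
        Nero(model.parameters(), beta=1.5)
    except ValueError as err:
        print(err)  # prints: [-] beta 1.5 must be in the range [0, 1]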

pytorch_optimizer/nero.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base_optimizer import BaseOptimizer
+from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.utils import neuron_mean, neuron_norm
+
+
+class Nero(Optimizer, BaseOptimizer):
+    """
+    Reference : https://github.com/jxbz/nero
+    Example :
+        from pytorch_optimizer import Nero
+        ...
+        model = YourModel()
+        optimizer = Nero(model.parameters())
+        ...
+        for input, output in data:
+            optimizer.zero_grad()
+            loss = loss_function(output, model(input))
+            loss.backward()
+            optimizer.step()
+    """
+
+    def __init__(self, params: PARAMETERS, lr: float = 0.01, beta: float = 0.999, constraints: bool = True):
+        """Nero optimizer
+        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
+        :param lr: float. learning rate
+        :param beta: float. coefficient used for computing the running average of the squared per-neuron gradient norm
+        :param constraints: bool. whether to keep each neuron's weights zero-mean and unit-norm (re-projected after every step)
+        """
+        self.lr = lr
+        self.beta = beta
+
+        self.validate_parameters()
+
+        defaults: DEFAULTS = dict(lr=lr, constraints=constraints)
+        super().__init__(params, defaults)
+
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_beta(self.beta)
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                if group['constraints'] and p.dim() > 1:
+                    p.sub_(neuron_mean(p))
+                    p.div_(neuron_norm(p))
+
+                state = self.state[p]
+
+                state['step'] = 0
+                state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))
+                state['scale'] = neuron_norm(p).mean()
+
+                if state['scale'] == 0.0:
+                    state['scale'] = 0.01
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Nero does not support sparse gradients')
+
+                state = self.state[p]
+                if len(state) == 0:
+                    if group['constraints'] and p.dim() > 1:
+                        p.sub_(neuron_mean(p))
+                        p.div_(neuron_norm(p))
+
+                    state['step'] = 0
+                    state['exp_avg_sq'] = torch.zeros_like(neuron_norm(p))
+                    state['scale'] = neuron_norm(p).mean()
+                    if state['scale'] == 0.0:
+                        state['scale'] = 0.01
+
+                state['step'] += 1
+
+                bias_correction: float = 1.0 - self.beta ** state['step']
+                state['exp_avg_sq'] = self.beta * state['exp_avg_sq'] + (1.0 - self.beta) * neuron_norm(grad) ** 2
+
+                grad_normed = grad / (state['exp_avg_sq'] / bias_correction).sqrt()
+                grad_normed[torch.isnan(grad_normed)] = 0.0
+
+                p.sub_(group['lr'] * state['scale'] * grad_normed)
+
+                if group['constraints'] and p.dim() > 1:
+                    p.sub_(neuron_mean(p))
+                    p.div_(neuron_norm(p))
+
+        return loss
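
For completeness, an end-to-end usage sketch of the new optimizer; the model, data, and training loop are invented for illustration and are not part of this commit, while the Nero(...) call mirrors the constructor defined above:

    import torch

    from pytorch_optimizer import Nero

    # Toy regression setup, purely illustrative.
    model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
    loss_fn = torch.nn.MSELoss()
    optimizer = Nero(model.parameters(), lr=0.01, beta=0.999, constraints=True)

    x, y = torch.randn(64, 8), torch.randn(64, 1)
    for _ in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        # Per-neuron normalized update, then re-projection onto the zero-mean, unit-norm constraint set.
        optimizer.step()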
