
Commit f66a612

Merge pull request #55 from kozistr/feature/shampoo-optimizer
[Feature] Shampoo optimizer
2 parents 51fdbce + c41b72b commit f66a612

15 files changed: +257 -59 lines changed

README.rst

Lines changed: 15 additions & 0 deletions
@@ -74,6 +74,8 @@ Supported Optimizers
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 | Lamb | *Large Batch Optimization for Deep Learning* | `github <https://github.com/cybertronai/pytorch-lamb>`__ | `https://arxiv.org/abs/1904.00962 <https://arxiv.org/abs/1904.00962>`__ |
 +--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
+| Shampoo | *Preconditioned Stochastic Tensor Optimization* | `github <https://github.com/moskomule/shampoo.pytorch>`__ | `https://arxiv.org/abs/1802.09568 <https://arxiv.org/abs/1802.09568>`__ |
++--------------+----------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------+
 
 Useful Resources
 ----------------
@@ -467,6 +469,19 @@ AdamD: Improved bias-correction in Adam
         year={2021}
     }
 
+Shampoo: Preconditioned Stochastic Tensor Optimization
+
+::
+
+    @inproceedings{gupta2018shampoo,
+        title={Shampoo: Preconditioned stochastic tensor optimization},
+        author={Gupta, Vineet and Koren, Tomer and Singer, Yoram},
+        booktitle={International Conference on Machine Learning},
+        pages={1842--1850},
+        year={2018},
+        organization={PMLR}
+    }
+
 Author
 ------

pytorch_optimizer/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -21,4 +21,11 @@
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sam import SAM
 from pytorch_optimizer.sgdp import SGDP
-from pytorch_optimizer.utils import clip_grad_norm, get_optimizer_parameters, normalize_gradient, unit_norm
+from pytorch_optimizer.shampoo import Shampoo
+from pytorch_optimizer.utils import (
+    clip_grad_norm,
+    get_optimizer_parameters,
+    matrix_power,
+    normalize_gradient,
+    unit_norm,
+)

pytorch_optimizer/adabelief.py

Lines changed: 6 additions & 7 deletions
@@ -126,13 +126,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['max_exp_avg_var'] = torch.zeros_like(p)
 
                 if self.weight_decouple:
-                    if not self.fixed_decay:
-                        p_fp32.mul_(1.0 - group['lr'] * group['weight_decay'])
-                    else:
-                        p_fp32.mul_(1.0 - group['weight_decay'])
-                else:
-                    if group['weight_decay'] != 0:
-                        grad.add_(p_fp32, alpha=group['weight_decay'])
+                    decay: float = (
+                        group['lr'] * group['weight_decay'] if not self.fixed_decay else group['weight_decay']
+                    )
+                    p_fp32.mul_(1.0 - decay)
+                elif group['weight_decay'] != 0:
+                    grad.add_(p_fp32, alpha=group['weight_decay'])
 
                 exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var']

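The refactor above only restructures the weight-decay logic; its behavior is unchanged. The two paths it distinguishes are sketched below with toy tensors and illustrative hyper-parameter values (not code from this repository):

import torch

# Toy parameter/gradient and illustrative hyper-parameters.
p = torch.tensor([1.0, -2.0])
grad = torch.tensor([0.1, 0.3])
lr, weight_decay, fixed_decay = 1e-3, 1e-2, False

# Decoupled path (weight_decouple=True): shrink the weights directly.
# decay is lr * weight_decay unless fixed_decay is set, exactly as in the new one-liner.
decay = lr * weight_decay if not fixed_decay else weight_decay
p_decoupled = p * (1.0 - decay)

# Coupled path (the new elif branch): fold the L2 penalty into the gradient instead.
grad_coupled = grad + weight_decay * p
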
pytorch_optimizer/base_optimizer.py

Lines changed: 5 additions & 0 deletions
@@ -80,6 +80,11 @@ def validate_reduction(reduction: str):
         if reduction not in ('mean', 'sum'):
             raise ValueError(f'[-] reduction {reduction} must be one of (\'mean\' or \'sum\')')
 
+    @staticmethod
+    def validate_update_frequency(update_frequency: int):
+        if update_frequency < 1:
+            raise ValueError(f'[-] update_frequency {update_frequency} must be positive')
+
     @abstractmethod
     def validate_parameters(self):
         raise NotImplementedError

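The new validator follows the same pattern as the existing static checks; Shampoo calls it from its validate_parameters(). A quick check of its contract, assuming the module layout shown in this diff:

from pytorch_optimizer.base_optimizer import BaseOptimizer

# update_freq must be a positive integer; anything below 1 is rejected.
BaseOptimizer.validate_update_frequency(1)  # passes silently

try:
    BaseOptimizer.validate_update_frequency(0)
except ValueError as err:
    print(err)  # [-] update_frequency 0 must be positive
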
pytorch_optimizer/optimizers.py

Lines changed: 31 additions & 30 deletions
@@ -12,40 +12,41 @@
 from pytorch_optimizer.ranger import Ranger
 from pytorch_optimizer.ranger21 import Ranger21
 from pytorch_optimizer.sgdp import SGDP
+from pytorch_optimizer.shampoo import Shampoo
 
 
 def load_optimizers(optimizer: str):
     optimizer: str = optimizer.lower()
 
     if optimizer == 'adamp':
-        opt = AdamP
-    elif optimizer == 'ranger':
-        opt = Ranger
-    elif optimizer == 'ranger21':
-        opt = Ranger21
-    elif optimizer == 'sgdp':
-        opt = SGDP
-    elif optimizer == 'radam':
-        opt = RAdam
-    elif optimizer == 'adabelief':
-        opt = AdaBelief
-    elif optimizer == 'adabound':
-        opt = AdaBound
-    elif optimizer == 'madgrad':
-        opt = MADGRAD
-    elif optimizer == 'diffgrad':
-        opt = DiffGrad
-    elif optimizer == 'diffrgrad':
-        opt = DiffRGrad
-    elif optimizer == 'adahessian':
-        opt = AdaHessian
-    elif optimizer == 'lamb':
-        opt = Lamb
-    elif optimizer == 'ralamb':
-        opt = RaLamb
-    elif optimizer == 'lars':
-        opt = LARS
-    else:
-        raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')
+        return AdamP
+    if optimizer == 'ranger':
+        return Ranger
+    if optimizer == 'ranger21':
+        return Ranger21
+    if optimizer == 'sgdp':
+        return SGDP
+    if optimizer == 'radam':
+        return RAdam
+    if optimizer == 'adabelief':
+        return AdaBelief
+    if optimizer == 'adabound':
+        return AdaBound
+    if optimizer == 'madgrad':
+        return MADGRAD
+    if optimizer == 'diffgrad':
+        return DiffGrad
+    if optimizer == 'diffrgrad':
+        return DiffRGrad
+    if optimizer == 'adahessian':
+        return AdaHessian
+    if optimizer == 'lamb':
+        return Lamb
+    if optimizer == 'ralamb':
+        return RaLamb
+    if optimizer == 'lars':
+        return LARS
+    if optimizer == 'shampoo':
+        return Shampoo
 
-    return opt
+    raise NotImplementedError(f'[-] not implemented optimizer : {optimizer}')

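With the early-return rewrite, load_optimizers resolves 'shampoo' like every other name and raises only when nothing matches. A short usage sketch (the linear model and learning rate are illustrative placeholders):

import torch

from pytorch_optimizer.optimizers import load_optimizers

# Look up the optimizer class by name (lower-cased internally), then instantiate it.
optimizer_class = load_optimizers('Shampoo')

model = torch.nn.Linear(4, 2)  # toy model
optimizer = optimizer_class(model.parameters(), lr=1e-3)
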
pytorch_optimizer/shampoo.py

Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
+import torch
+from torch.optim.optimizer import Optimizer
+
+from pytorch_optimizer.base_optimizer import BaseOptimizer
+from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.utils import matrix_power
+
+
+class Shampoo(Optimizer, BaseOptimizer):
+    """
+    Reference : https://github.com/moskomule/shampoo.pytorch/blob/master/shampoo.py
+    Example :
+        from pytorch_optimizer import Shampoo
+        ...
+        model = YourModel()
+        optimizer = Shampoo(model.parameters())
+        ...
+        for input, output in data:
+            optimizer.zero_grad()
+            loss = loss_function(output, model(input))
+            loss.backward()
+            optimizer.step()
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        update_freq: int = 1,
+        eps: float = 1e-4,
+    ):
+        """Shampoo optimizer
+        :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
+        :param lr: float. learning rate
+        :param momentum: float. momentum
+        :param weight_decay: float. weight decay (L2 penalty)
+        :param update_freq: int. update frequency to compute inverse
+        :param eps: float. term added to the denominator to improve numerical stability
+        """
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.update_freq = update_freq
+        self.eps = eps
+
+        self.validate_parameters()
+
+        defaults: DEFAULTS = dict(
+            lr=lr,
+            momentum=momentum,
+            weight_decay=weight_decay,
+            update_freq=update_freq,
+            eps=eps,
+        )
+        super().__init__(params, defaults)
+
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_momentum(self.momentum)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_update_frequency(self.update_freq)
+        self.validate_epsilon(self.eps)
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise RuntimeError('Shampoo does not support sparse gradients')
+
+                momentum = group['momentum']
+                state = self.state[p]
+                if len(state) == 0:
+                    state['step'] = 0
+
+                    if momentum > 0.0:
+                        state['momentum_buffer'] = grad.clone()
+
+                    # pre-condition matrices
+                    for dim_id, dim in enumerate(grad.size()):
+                        state[f'pre_cond_{dim_id}'] = group['eps'] * torch.eye(dim, out=grad.new(dim, dim))
+                        state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()
+
+                if momentum > 0.0:
+                    grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)
+
+                weight_decay = group['weight_decay']
+                if weight_decay > 0.0:
+                    grad.add_(p, alpha=weight_decay)
+
+                order: int = grad.ndimension()
+                original_size: int = grad.size()
+                for dim_id, dim in enumerate(grad.size()):
+                    pre_cond = state[f'pre_cond_{dim_id}']
+                    inv_pre_cond = state[f'inv_pre_cond_{dim_id}']
+
+                    grad = grad.transpose_(0, dim_id).contiguous()
+                    transposed_size = grad.size()
+
+                    grad = grad.view(dim, -1)
+
+                    grad_t = grad.t()
+                    pre_cond.add_(grad @ grad_t)
+                    if state['step'] % group['update_freq'] == 0:
+                        inv_pre_cond.copy_(matrix_power(pre_cond, -1 / order))
+
+                    if dim_id == order - 1:
+                        grad = grad_t @ inv_pre_cond
+                        grad = grad.view(original_size)
+                    else:
+                        grad = inv_pre_cond @ grad
+                        grad = grad.view(transposed_size)
+
+                state['step'] += 1
+                state['momentum_buffer'] = grad
+
+                p.add_(grad, alpha=-group['lr'])
+
+        return loss

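In each step, and separately for every dimension of a parameter's gradient tensor, the optimizer matricizes the gradient along that dimension, accumulates a preconditioner (pre_cond += G @ G.T), refreshes its inverse order-th root every update_freq steps via matrix_power, and multiplies the gradient by it before the plain p -= lr * grad update. A minimal end-to-end sketch following the usage pattern in the class docstring (the toy model, data, and hyper-parameter values are placeholders):

import torch
import torch.nn.functional as F

from pytorch_optimizer import Shampoo

# Toy regression problem; model, data, and hyper-parameters are illustrative only.
model = torch.nn.Linear(10, 1)
optimizer = Shampoo(model.parameters(), lr=1e-3, momentum=0.9, update_freq=1)

x, y = torch.randn(32, 10), torch.randn(32, 1)

for _ in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
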
pytorch_optimizer/utils.py

Lines changed: 10 additions & 0 deletions
@@ -156,3 +156,13 @@ def get_optimizer_parameters(
         },
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in wd_ban_list)], 'weight_decay': 0.0},
     ]
+
+
+def matrix_power(matrix: torch.Tensor, power: float) -> torch.Tensor:
+    matrix_device = matrix.device
+
+    # use CPU for svd for speed up
+    u, s, vh = torch.linalg.svd(matrix.cpu(), full_matrices=False)
+    v = vh.transpose(-2, -1).conj()
+
+    return (u @ s.pow_(power).diag() @ v.t()).to(matrix_device)

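matrix_power computes a fractional matrix power through an SVD performed on the CPU and moves the result back to the original device; Shampoo calls it with power = -1 / order to obtain the inverse order-th root of each preconditioner. A small sanity-check sketch (the symmetric positive-definite test matrix is arbitrary):

import torch

from pytorch_optimizer.utils import matrix_power

# Arbitrary symmetric positive-definite test matrix.
a = torch.randn(5, 5)
spd = a @ a.t() + 1e-3 * torch.eye(5)

inv_sqrt = matrix_power(spd, -0.5)  # approximately spd^(-1/2)

# For an SPD matrix the SVD coincides with the eigendecomposition,
# so inv_sqrt @ spd @ inv_sqrt should be close to the identity.
print(torch.dist(inv_sqrt @ spd @ inv_sqrt, torch.eye(5)))
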
pytorch_optimizer/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__VERSION__ = '0.4.2'
+__VERSION__ = '0.5.0'

setup.py

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ def read_version() -> str:
         'lamb',
         'ralamb',
         'lars',
+        'shampoo',
     ]
 )

tests/test_load_optimizers.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     'lamb',
     'ralamb',
     'lars',
+    'shampoo',
 ]
 
 INVALID_OPTIMIZER_NAMES: List[str] = [
