
Commit 19c3df6

Merge pull request #103 from kozistr/update/shampoo-optimizer
[Update] Support SVD method to calculate `M^{-1/p}`
2 parents de06f63 + 01b5c5a commit 19c3df6

11 files changed (+301, -91 lines)

docs/optimizer_api.rst

Lines changed: 8 additions & 0 deletions
@@ -193,6 +193,14 @@ Shampoo
 .. autoclass:: pytorch_optimizer.Shampoo
     :members:
 
+.. _ScalableShampoo:
+
+ScalableShampoo
+---------------
+
+.. autoclass:: pytorch_optimizer.ScalableShampoo
+    :members:
+
 .. _GSAM:
 
 GSAM
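
For context, a minimal usage sketch of the newly documented ScalableShampoo. The toy model and training loop here are placeholders, not part of this commit; constructor arguments other than params and lr are left at their defaults.

import torch

from pytorch_optimizer import ScalableShampoo

# placeholder model and data, only to show the optimizer wiring
model = torch.nn.Linear(16, 4)
optimizer = ScalableShampoo(model.parameters(), lr=1e-3)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(32, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()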

docs/util_api.rst

Lines changed: 12 additions & 4 deletions
@@ -154,12 +154,20 @@ matrix_power
 .. autoclass:: pytorch_optimizer.matrix_power
     :members:
 
-.. _compute_power:
+.. _compute_power_schur_newton:
 
-compute_power
--------------
+compute_power_schur_newton
+--------------------------
 
-.. autoclass:: pytorch_optimizer.compute_power
+.. autoclass:: pytorch_optimizer.compute_power_schur_newton
+    :members:
+
+.. _compute_power_svd:
+
+compute_power_svd
+-----------------
+
+.. autoclass:: pytorch_optimizer.compute_power_svd
     :members:
 
 .. _merge_small_dims:
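
The renamed compute_power_schur_newton and the new compute_power_svd both produce a fractional matrix power M^{-1/p}; the shampoo.py diff below calls the latter as compute_power_svd(pre_cond, -1.0 / order). An illustrative standalone sketch of the SVD route (not the library's actual implementation):

import torch

def matrix_power_via_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
    # For a symmetric PSD matrix, M = U diag(S) Vh, so M^power = U diag(S^power) Vh.
    u, s, vh = torch.linalg.svd(matrix)
    return u @ torch.diag(s.pow(power)) @ vh

g = torch.randn(4, 16, dtype=torch.float64)
m = g @ g.t() + 1e-6 * torch.eye(4, dtype=torch.float64)  # PSD statistics matrix, as in Shampoo
inv_root = matrix_power_via_svd(m, -1.0 / 2)               # M^{-1/2}
print(torch.allclose(inv_root @ m @ inv_root, torch.eye(4, dtype=torch.float64), atol=1e-6))  # True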

pytorch_optimizer/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -41,7 +41,7 @@
 from pytorch_optimizer.optimizer.ranger21 import Ranger21
 from pytorch_optimizer.optimizer.sam import SAM
 from pytorch_optimizer.optimizer.sgdp import SGDP
-from pytorch_optimizer.optimizer.shampoo import Shampoo
+from pytorch_optimizer.optimizer.shampoo import ScalableShampoo, Shampoo
 from pytorch_optimizer.optimizer.shampoo_utils import (
     AdaGradGraft,
     BlockPartitioner,
@@ -52,7 +52,8 @@
     RMSPropGraft,
     SGDGraft,
     SQRTNGraft,
-    compute_power,
+    compute_power_schur_newton,
+    compute_power_svd,
     matrix_power,
     merge_small_dims,
     power_iter,
@@ -86,6 +87,7 @@
     Ranger21,
     SGDP,
     Shampoo,
+    ScalableShampoo,
     DAdaptAdaGrad,
     DAdaptAdam,
     DAdaptSGD,

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 142 additions & 11 deletions
@@ -13,12 +13,134 @@
     RMSPropGraft,
     SGDGraft,
     SQRTNGraft,
+    compute_power_svd,
 )
 
 
 class Shampoo(Optimizer, BaseOptimizer):
     r"""Preconditioned Stochastic Tensor Optimization.
 
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param momentum: float. momentum.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
+        requirements. How often to compute pre-conditioner.
+    :param matrix_eps: float. term added to the denominator to improve numerical stability.
+    """
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        preconditioning_compute_steps: int = 1,
+        matrix_eps: float = 1e-6,
+    ):
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.preconditioning_compute_steps = preconditioning_compute_steps
+        self.matrix_eps = matrix_eps
+
+        self.validate_parameters()
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'momentum': momentum,
+            'weight_decay': weight_decay,
+        }
+        super().__init__(params, defaults)
+
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_momentum(self.momentum)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_update_frequency(self.preconditioning_compute_steps)
+        self.validate_epsilon(self.matrix_eps)
+
+    @property
+    def __str__(self) -> str:
+        return 'Shampoo'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            momentum = group['momentum']
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(self.__str__)
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['step'] = 0
+
+                    if momentum > 0.0:
+                        state['momentum_buffer'] = grad.clone()
+
+                    for dim_id, dim in enumerate(grad.size()):
+                        state[f'pre_cond_{dim_id}'] = self.matrix_eps * torch.eye(dim, out=grad.new(dim, dim))
+                        state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()
+
+                state['step'] += 1
+
+                if momentum > 0.0:
+                    grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)
+
+                if group['weight_decay'] > 0.0:
+                    grad.add_(p, alpha=group['weight_decay'])
+
+                order: int = grad.ndimension()
+                original_size: int = grad.size()
+                for dim_id, dim in enumerate(grad.size()):
+                    pre_cond = state[f'pre_cond_{dim_id}']
+                    inv_pre_cond = state[f'inv_pre_cond_{dim_id}']
+
+                    grad = grad.transpose_(0, dim_id).contiguous()
+                    transposed_size = grad.size()
+
+                    grad = grad.view(dim, -1)
+
+                    grad_t = grad.t()
+                    pre_cond.add_(grad @ grad_t)
+                    if state['step'] % self.preconditioning_compute_steps == 0:
+                        inv_pre_cond.copy_(compute_power_svd(pre_cond, -1.0 / order))
+
+                    if dim_id == order - 1:
+                        grad = grad_t @ inv_pre_cond
+                        grad = grad.view(original_size)
+                    else:
+                        grad = inv_pre_cond @ grad
+                        grad = grad.view(transposed_size)
+
+                state['momentum_buffer'] = grad
+
+                p.add_(grad, alpha=-group['lr'])
+
+        return loss
+
+
+class ScalableShampoo(Optimizer, BaseOptimizer):
+    r"""Scalable Preconditioned Stochastic Tensor Optimization.
+
     Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
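
For a 2-D gradient, the per-dimension loop in Shampoo.step above roughly amounts to left- and right-multiplying by inverse roots of the accumulated statistics. A simplified standalone sketch (it ignores that the loop accumulates the second statistic from the already left-preconditioned gradient, and uses a local SVD helper as a stand-in for compute_power_svd):

import torch

def inv_root_via_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
    # Stand-in for compute_power_svd: M^power through the SVD of a symmetric PSD matrix.
    u, s, vh = torch.linalg.svd(matrix)
    return u @ torch.diag(s.pow(power)) @ vh

eps = 1e-6
grad = torch.randn(8, 4, dtype=torch.float64)                     # a 2-D gradient, so order = 2
left = eps * torch.eye(8, dtype=torch.float64) + grad @ grad.t()  # pre_cond_0 after one step
right = eps * torch.eye(4, dtype=torch.float64) + grad.t() @ grad # pre_cond_1 (simplified)

# power -1/order = -1/2 for a matrix, applied on both sides
preconditioned = inv_root_via_svd(left, -0.5) @ grad @ inv_root_via_svd(right, -0.5)
print(preconditioned.shape)  # torch.Size([8, 4])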
@@ -45,6 +167,10 @@ class Shampoo(Optimizer, BaseOptimizer):
     :param nesterov: bool. Nesterov momentum.
     :param diagonal_eps: float. term added to the denominator to improve numerical stability.
     :param matrix_eps: float. term added to the denominator to improve numerical stability.
+    :param use_svd: bool. use SVD instead of the Schur-Newton method to calculate M^{-1/p}.
+        Theoretically, the Schur-Newton method is faster than SVD at calculating M^{-1/p}.
+        In practice, however, its loop-based implementation is inefficient enough that SVD ends up much faster.
+        see https://github.com/kozistr/pytorch_optimizer/pull/103
     """
 
     def __init__(
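
The SVD route described above is selected at construction time; a minimal sketch (the model is a placeholder):

import torch

from pytorch_optimizer import ScalableShampoo

model = torch.nn.Linear(16, 4)  # placeholder module
# use_svd defaults to False (Schur-Newton); True switches the M^{-1/p} computation to SVD.
optimizer = ScalableShampoo(model.parameters(), lr=1e-3, use_svd=True)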
@@ -60,14 +186,15 @@ def __init__(
         start_preconditioning_step: int = 5,
         preconditioning_compute_steps: int = 1,
         statistics_compute_steps: int = 1,
-        block_size: int = 128,
+        block_size: int = 256,
         no_preconditioning_for_layers_with_dim_gt: int = 8192,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
         pre_conditioner_type: int = PreConditionerType.ALL,
         nesterov: bool = True,
         diagonal_eps: float = 1e-10,
         matrix_eps: float = 1e-6,
+        use_svd: bool = False,
     ):
         self.lr = lr
         self.betas = betas
@@ -87,6 +214,7 @@ def __init__(
         self.nesterov = nesterov
         self.diagonal_eps = diagonal_eps
         self.matrix_eps = matrix_eps
+        self.use_svd = use_svd
 
         self.validate_parameters()
 
@@ -109,7 +237,7 @@ def validate_parameters(self):
 
     @property
     def __str__(self) -> str:
-        return 'Shampoo'
+        return 'ScalableShampoo'
 
     @torch.no_grad()
     def reset(self):
@@ -128,6 +256,7 @@ def reset(self):
                     self.shape_interpretation,
                     self.matrix_eps,
                     self.pre_conditioner_type,
+                    self.use_svd,
                 )
                 if self.graft_type == LayerWiseGrafting.ADAGRAD:
                     state['graft'] = AdaGradGraft(p, self.diagonal_eps)
@@ -140,6 +269,9 @@ def reset(self):
                 else:
                     state['graft'] = Graft(p)
 
+    def is_precondition_step(self, step: int) -> bool:
+        return step >= self.start_preconditioning_step
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
@@ -170,6 +302,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         self.shape_interpretation,
                         self.matrix_eps,
                         self.pre_conditioner_type,
+                        self.use_svd,
                     )
                     if self.graft_type == LayerWiseGrafting.ADAGRAD:
                         state['graft'] = AdaGradGraft(p, self.diagonal_eps)
@@ -185,27 +318,26 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
-                # gather statistics, compute pre-conditioners
+                is_precondition_step: bool = self.is_precondition_step(state['step'])
+
                 graft.add_statistics(grad, beta2)
                 if state['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
                 if state['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
-                # pre-condition gradients
                 pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
                 graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = grad
-                if state['step'] >= self.start_preconditioning_step:
+                if is_precondition_step:
                     shampoo_grad = pre_conditioner.preconditioned_grad(grad)
 
-                # grafting
-                graft_norm = torch.norm(graft_grad)
-                shampoo_norm = torch.norm(shampoo_grad)
                 if self.graft_type != LayerWiseGrafting.NONE:
+                    graft_norm = torch.norm(graft_grad)
+                    shampoo_norm = torch.norm(shampoo_grad)
+
                     shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
-                # apply weight decay (adam style)
                 if group['weight_decay'] > 0.0:
                     if not self.decoupled_weight_decay:
                         shampoo_grad.add_(p, alpha=group['weight_decay'])
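
The grafting branch above keeps Shampoo's update direction but rescales it to the norm of the grafted optimizer's update. A standalone sketch of that rescaling (graft_grad is a random stand-in for whatever the configured graft would produce):

import torch

shampoo_grad = torch.randn(8, 4)
graft_grad = torch.randn(8, 4)  # stand-in for the grafted optimizer's update (e.g. SGD or AdaGrad)

graft_norm = torch.norm(graft_grad)
shampoo_norm = torch.norm(shampoo_grad)
grafted = shampoo_grad * (graft_norm / (shampoo_norm + 1e-16))

# direction comes from Shampoo, magnitude from the graft
print(torch.allclose(torch.norm(grafted), graft_norm))  # True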
@@ -214,11 +346,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
                         graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
 
-                # Momentum and Nesterov momentum, if needed
                 state['momentum'].mul_(beta1).add_(shampoo_grad)
                 graft_momentum = graft.update_momentum(grad, beta1)
 
-                if state['step'] >= self.start_preconditioning_step:
+                if is_precondition_step:
                     momentum_update = state['momentum']
                     wd_update = shampoo_grad
                 else:
