Commit ec2d693

Merge pull request #129 from kozistr/update/scalable-shampoo-optimizer
[Update] yet another tweak for Scalable Shampoo optimizer
2 parents 41701ea + 02aaecb commit ec2d693

7 files changed (+283, -87 lines changed)


docs/util_api.rst

Lines changed: 4 additions & 4 deletions
@@ -138,12 +138,12 @@ PreConditioner
 .. autoclass:: pytorch_optimizer.PreConditioner
     :members:
 
-.. _power_iter:
+.. _power_iteration:
 
-power_iter
-----------
+power_iteration
+---------------
 
-.. autoclass:: pytorch_optimizer.power_iter
+.. autoclass:: pytorch_optimizer.power_iteration
     :members:
 
 .. _compute_power_schur_newton:
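The rename is purely cosmetic: power_iter becomes power_iteration, and the helper still serves to estimate the dominant eigenvalue of a (symmetric) statistics matrix before the matrix-root computation. As a rough, self-contained illustration of the underlying idea only (the function name and signature below are placeholders, not the library's API):

import torch


def power_iteration_sketch(mat: torch.Tensor, num_iters: int = 100) -> torch.Tensor:
    # Illustrative only: estimate the largest eigenvalue of a symmetric PSD matrix
    # by repeatedly applying the matrix and renormalizing the vector.
    v = torch.randn(mat.shape[0], dtype=mat.dtype, device=mat.device)
    for _ in range(num_iters):
        v = mat @ v
        v = v / torch.linalg.norm(v)
    # The Rayleigh quotient of the (near-)converged vector approximates the top eigenvalue.
    return v @ (mat @ v)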

pytorch_optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@
     compute_power_schur_newton,
     compute_power_svd,
     merge_small_dims,
-    power_iter,
+    power_iteration,
 )
 from pytorch_optimizer.optimizer.utils import (
     clip_grad_norm,

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 37 additions & 20 deletions
@@ -130,6 +130,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 class ScalableShampoo(Optimizer, BaseOptimizer):
     r"""Scalable Preconditioned Stochastic Tensor Optimization.
 
+    This version of Scalable Shampoo Optimizer aims for a single GPU environment, not for a distributed environment
+    or XLA devices. So, the original intention is to compute pre-conditioners asynchronously on the distributed
+    CPUs, but this implementation calculates them which takes 99% of the optimization time on a GPU synchronously.
+
+    Still, it is much faster than the previous Shampoo Optimizer because using coupled Newton iteration when
+    computing G^{-1/p} matrices while the previous one uses SVD which is really slow.
+
+    Also, this implementation offers
+    1. lots of plug-ins (e.g. gradient grafting, type of pre-conditioning, etc)
+    2. not-yet implemented features in the official Pytorch code.
+    3. readable, organized, clean code.
+
     Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -151,6 +163,7 @@ class ScalableShampoo(Optimizer, BaseOptimizer):
     :param block_size: int. Block size for large layers (if > 0).
         Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
        Block size should be as large as feasible under memory/time constraints.
+    :param skip_preconditioning_rank_lt: int. Skips preconditioning for parameters with rank less than this value.
     :param no_preconditioning_for_layers_with_dim_gt: int. avoid preconditioning large layers to reduce overall memory.
     :param shape_interpretation: bool. Automatic shape interpretation (for eg: [4, 3, 1024, 512] would
         result in 12 x [1024, 512] L and R statistics. Disabled by default which results in Shampoo constructing
@@ -176,10 +189,11 @@ def __init__(
         decoupled_weight_decay: bool = False,
         decoupled_learning_rate: bool = True,
         inverse_exponent_override: int = 0,
-        start_preconditioning_step: int = 5,
-        preconditioning_compute_steps: int = 1,
+        start_preconditioning_step: int = 25,
+        preconditioning_compute_steps: int = 1000,
         statistics_compute_steps: int = 1,
         block_size: int = 256,
+        skip_preconditioning_rank_lt: int = 1,
         no_preconditioning_for_layers_with_dim_gt: int = 8192,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
@@ -200,6 +214,7 @@ def __init__(
         self.preconditioning_compute_steps = preconditioning_compute_steps
         self.statistics_compute_steps = statistics_compute_steps
         self.block_size = block_size
+        self.skip_preconditioning_rank_lt = skip_preconditioning_rank_lt
         self.no_preconditioning_for_layers_with_dim_gt = no_preconditioning_for_layers_with_dim_gt
         self.shape_interpretation = shape_interpretation
         self.graft_type = graft_type
@@ -230,20 +245,21 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]
 
-                state['step'] = 0
                 state['momentum'] = torch.zeros_like(p)
                 state['pre_conditioner'] = PreConditioner(
                     p,
                     group['betas'][1],  # beta2
                     self.inverse_exponent_override,
                     self.block_size,
+                    self.skip_preconditioning_rank_lt,
                     self.no_preconditioning_for_layers_with_dim_gt,
                     self.shape_interpretation,
-                    self.matrix_eps,
                     self.pre_conditioner_type,
+                    self.matrix_eps,
                     self.use_svd,
                 )
                 state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)
@@ -259,6 +275,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             loss = closure()
 
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            is_precondition_step: bool = self.is_precondition_step(group['step'])
+            pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
+
             beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
@@ -270,41 +294,37 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['momentum'] = torch.zeros_like(p)
                     state['pre_conditioner'] = PreConditioner(
                         p,
                         beta2,
                         self.inverse_exponent_override,
                         self.block_size,
+                        self.skip_preconditioning_rank_lt,
                         self.no_preconditioning_for_layers_with_dim_gt,
                         self.shape_interpretation,
-                        self.matrix_eps,
                         self.pre_conditioner_type,
+                        self.matrix_eps,
                         self.use_svd,
                     )
                     state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)
 
-                state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
                 graft.add_statistics(grad, beta2)
-                if state['step'] % self.statistics_compute_steps == 0:
+                if group['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
-                if state['step'] % self.preconditioning_compute_steps == 0:
+                if group['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
-                is_precondition_step: bool = self.is_precondition_step(state['step'])
-                pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
-
                 graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = (
                     pre_conditioner.preconditioned_grad(grad) if is_precondition_step else grad
                 )
 
                 if self.graft_type != LayerWiseGrafting.NONE:
-                    graft_norm = torch.norm(graft_grad)
-                    shampoo_norm = torch.norm(shampoo_grad)
+                    graft_norm = torch.linalg.norm(graft_grad)
+                    shampoo_norm = torch.linalg.norm(shampoo_grad)
 
                     shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
@@ -319,15 +339,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['momentum'].mul_(beta1).add_(shampoo_grad)
                 graft_momentum = graft.update_momentum(grad, beta1)
 
-                if is_precondition_step:
-                    momentum_update = state['momentum']
-                    wd_update = shampoo_grad
-                else:
-                    momentum_update = graft_momentum
-                    wd_update = graft_grad
+                momentum_update = state['momentum'] if is_precondition_step else graft_momentum
 
                 if self.nesterov:
                     w: float = (1.0 - beta1) if self.moving_average_for_momentum else 1.0
+
+                    wd_update = shampoo_grad if is_precondition_step else graft_grad
                     wd_update.mul_(w)
 
                     momentum_update.mul_(beta1).add_(wd_update)
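With the step counter now stored per parameter group, the expensive pre-conditioner recomputation runs only every preconditioning_compute_steps steps (default raised from 1 to 1000) and preconditioning starts after start_preconditioning_step steps (default raised from 5 to 25), while statistics are still accumulated every step. A minimal usage sketch based on the constructor arguments shown above; the model, data, and learning rate are placeholders, and the constructor also accepts further options (grafting type, weight decay, etc.):

import torch
from pytorch_optimizer import ScalableShampoo

model = torch.nn.Linear(512, 10)  # placeholder model; any nn.Module works

optimizer = ScalableShampoo(
    model.parameters(),
    lr=1e-3,                             # placeholder learning rate
    start_preconditioning_step=25,       # new default: start Shampoo preconditioning after 25 steps
    preconditioning_compute_steps=1000,  # new default: recompute G^{-1/p} pre-conditioners every 1000 steps
    statistics_compute_steps=1,          # accumulate L/R statistics every step
    block_size=256,
    skip_preconditioning_rank_lt=1,      # new parameter: skip preconditioning for parameters below this rank
)

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 512)).sum()
    loss.backward()
    optimizer.step()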
