Skip to content

Commit 090adb1

Browse files
committed
refactor: Shampoo optimizer
1 parent 9dae3f8 commit 090adb1

File tree

1 file changed

+7
-9
lines changed

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
8080
loss = closure()
8181

8282
for group in self.param_groups:
83-
momentum = group['momentum']
83+
momentum, weight_decay = group['momentum'], group['weight_decay']
8484
for p in group['params']:
8585
if p.grad is None:
8686
continue
@@ -100,29 +100,26 @@ def step(self, closure: CLOSURE = None) -> LOSS:
100100
state[f'pre_cond_{dim_id}'] = self.matrix_eps * torch.eye(dim, out=grad.new(dim, dim))
101101
state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()
102102

103-
state['step'] += 1
104-
105103
if momentum > 0.0:
106104
grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)
107105

108-
if group['weight_decay'] > 0.0:
109-
grad.add_(p, alpha=group['weight_decay'])
106+
if weight_decay > 0.0:
107+
grad.add_(p, alpha=weight_decay)
110108

111109
order: int = grad.ndimension()
112110
original_size: int = grad.size()
113111
for dim_id, dim in enumerate(grad.size()):
114-
pre_cond = state[f'pre_cond_{dim_id}']
115-
inv_pre_cond = state[f'inv_pre_cond_{dim_id}']
112+
pre_cond, inv_pre_cond = state[f'pre_cond_{dim_id}'], state[f'inv_pre_cond_{dim_id}']
116113

117114
grad = grad.transpose_(0, dim_id).contiguous()
118115
transposed_size = grad.size()
119116

120117
grad = grad.view(dim, -1)
121-
122118
grad_t = grad.t()
119+
123120
pre_cond.add_(grad @ grad_t)
124121
if state['step'] % self.preconditioning_compute_steps == 0:
125-
inv_pre_cond.copy_(compute_power_svd(pre_cond, -1.0 / order))
122+
inv_pre_cond = compute_power_svd(pre_cond, -1.0 / order)
126123

127124
if dim_id == order - 1:
128125
grad = grad_t @ inv_pre_cond
@@ -131,6 +128,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
131128
grad = inv_pre_cond @ grad
132129
grad = grad.view(transposed_size)
133130

131+
state['step'] += 1
134132
state['momentum_buffer'] = grad
135133

136134
p.add_(grad, alpha=-group['lr'])

0 commit comments

Comments (0)