
Commit 06dce18

Merge pull request #104 from kozistr/fix/svd
[Fix] singular value in `compute_power_svd()`
2 parents: 19c3df6 + 55dcb36

4 files changed, with 37 additions and 38 deletions.


pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.4.0"
+version = "2.4.1"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]
```

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 7 additions & 9 deletions

```diff
@@ -80,7 +80,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
             loss = closure()

         for group in self.param_groups:
-            momentum = group['momentum']
+            momentum, weight_decay = group['momentum'], group['weight_decay']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -100,29 +100,26 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         state[f'pre_cond_{dim_id}'] = self.matrix_eps * torch.eye(dim, out=grad.new(dim, dim))
                         state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()

-                state['step'] += 1
-
                 if momentum > 0.0:
                     grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)

-                if group['weight_decay'] > 0.0:
-                    grad.add_(p, alpha=group['weight_decay'])
+                if weight_decay > 0.0:
+                    grad.add_(p, alpha=weight_decay)

                 order: int = grad.ndimension()
                 original_size: int = grad.size()
                 for dim_id, dim in enumerate(grad.size()):
-                    pre_cond = state[f'pre_cond_{dim_id}']
-                    inv_pre_cond = state[f'inv_pre_cond_{dim_id}']
+                    pre_cond, inv_pre_cond = state[f'pre_cond_{dim_id}'], state[f'inv_pre_cond_{dim_id}']

                     grad = grad.transpose_(0, dim_id).contiguous()
                     transposed_size = grad.size()

                     grad = grad.view(dim, -1)
-
                     grad_t = grad.t()
+
                     pre_cond.add_(grad @ grad_t)
                     if state['step'] % self.preconditioning_compute_steps == 0:
-                        inv_pre_cond.copy_(compute_power_svd(pre_cond, -1.0 / order))
+                        inv_pre_cond = compute_power_svd(pre_cond, -1.0 / order)

                     if dim_id == order - 1:
                         grad = grad_t @ inv_pre_cond
@@ -131,6 +128,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         grad = inv_pre_cond @ grad
                     grad = grad.view(transposed_size)

+                state['step'] += 1
                 state['momentum_buffer'] = grad

                 p.add_(grad, alpha=-group['lr'])
```
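One effect of moving `state['step'] += 1` below the preconditioning loop is that the counter now reads 0, not 1, when the `state['step'] % self.preconditioning_compute_steps == 0` check first runs, so an inverse preconditioner is computed on the very first update instead of applying the zero-initialized `inv_pre_cond` buffer. A minimal sketch of that scheduling change (plain Python; the names here are illustrative, not the library's code):

```python
# Illustrative only: compare which iterations trigger a preconditioner
# recompute under the old ordering (increment before the check) and the
# new ordering (increment after the check).
preconditioning_compute_steps = 4  # same role as the attribute used in step()

old = [(1 + i) % preconditioning_compute_steps == 0 for i in range(8)]  # step starts at 1
new = [(0 + i) % preconditioning_compute_steps == 0 for i in range(8)]  # step starts at 0

print(old)  # [False, False, False, True, ...] -> first update sees no computed inverse
print(new)  # [True, False, False, False, ...] -> inverse computed on the first update
```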

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -488,7 +488,8 @@ def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
     :param power: float. -1.0 / order.
     """
     u, s, vh = torch.linalg.svd(matrix, full_matrices=False)
-    return u @ s.pow_(power).diag_embed() @ vh
+    s.pow_(power)
+    return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh


 def merge_small_dims(shape_to_merge: List[int], max_dim: int) -> List[int]:
```
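The patched helper applies the exponent to the singular values in place, then uses `s.diag()` when the input is a plain 2-D matrix and `s.diag_embed()` otherwise. A standalone sanity check of the identity it relies on, namely that for a symmetric positive-definite `G` the product `U · diag(s^p) · Vᵀ` reproduces `G^p` (the function body below mirrors the diff; the test matrix `g` is made up for the demo):

```python
import torch

def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
    # Same body as the patched version in shampoo_utils.py.
    u, s, vh = torch.linalg.svd(matrix, full_matrices=False)
    s.pow_(power)
    return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh

g = torch.randn(4, 4, dtype=torch.float64)
g = g @ g.t() + torch.eye(4, dtype=torch.float64)  # symmetric positive-definite

inv_root = compute_power_svd(g, -0.5)  # approximates G^(-1/2)
print(torch.allclose(inv_root @ g @ inv_root, torch.eye(4, dtype=torch.float64)))  # True
```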

tests/constants.py

Lines changed: 27 additions & 27 deletions

```diff
@@ -133,17 +133,17 @@
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'weight_decouple': False}, 10),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'fixed_decay': True}, 10),
     (AdaBelief, {'lr': 5e-1, 'weight_decay': 1e-3, 'rectify': False}, 10),
-    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 100),
-    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'fixed_decay': True}, 100),
-    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'weight_decouple': False}, 100),
-    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True}, 100),
-    (Adai, {'lr': 1e-1, 'weight_decay': 0.0}, 150),
-    (Adai, {'lr': 1e-1, 'weight_decay': 0.0, 'use_gc': True}, 150),
-    (Adai, {'lr': 1e-1, 'weight_decay': 0.0, 'dampening': 0.9}, 150),
-    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False}, 100),
-    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True}, 100),
-    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False, 'use_stable_weight_decay': True}, 100),
-    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True, 'use_stable_weight_decay': True}, 100),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3}, 75),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'fixed_decay': True}, 75),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'weight_decouple': False}, 75),
+    (AdaBound, {'lr': 5e-1, 'gamma': 0.1, 'weight_decay': 1e-3, 'amsbound': True}, 75),
+    (Adai, {'lr': 2e-1, 'weight_decay': 0.0}, 50),
+    (Adai, {'lr': 2e-1, 'weight_decay': 0.0, 'use_gc': True}, 75),
+    (Adai, {'lr': 2e-1, 'weight_decay': 0.0, 'dampening': 0.9}, 50),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False}, 50),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True}, 50),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': False, 'use_stable_weight_decay': True}, 50),
+    (Adai, {'lr': 1e-1, 'weight_decay': 1e-4, 'weight_decouple': True, 'use_stable_weight_decay': True}, 50),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3}, 10),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 10),
     (AdamP, {'lr': 5e-1, 'weight_decay': 1e-3, 'nesterov': True}, 10),
@@ -156,18 +156,18 @@
     (Lamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'pre_norm': True, 'eps': 1e-8}, 100),
     (LARS, {'lr': 1e-1, 'weight_decay': 1e-3}, 100),
     (LARS, {'lr': 1e-1, 'nesterov': True}, 100),
-    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 100),
-    (RaLamb, {'lr': 1e-2, 'weight_decay': 1e-3, 'pre_norm': True}, 100),
-    (RaLamb, {'lr': 1e-2, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 100),
-    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3}, 100),
-    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3, 'eps': 0.0}, 100),
-    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3, 'momentum': 0.0}, 100),
-    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3, 'decouple_decay': True}, 100),
-    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 100),
-    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 100),
+    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3}, 50),
+    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'pre_norm': True}, 50),
+    (RaLamb, {'lr': 1e-1, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 50),
+    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3}, 50),
+    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3, 'eps': 0.0}, 50),
+    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3, 'momentum': 0.0}, 50),
+    (MADGRAD, {'lr': 1e-2, 'weight_decay': 1e-3, 'decouple_decay': True}, 50),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3}, 50),
+    (RAdam, {'lr': 1e-1, 'weight_decay': 1e-3, 'degenerated_to_sgd': True}, 50),
     (SGDP, {'lr': 5e-2, 'weight_decay': 1e-4}, 50),
     (SGDP, {'lr': 5e-2, 'weight_decay': 1e-4, 'nesterov': True}, 50),
-    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3}, 200),
+    (Ranger, {'lr': 5e-1, 'weight_decay': 1e-3}, 150),
     (Ranger21, {'lr': 5e-1, 'weight_decay': 1e-3, 'num_iterations': 500}, 200),
     (Shampoo, {'lr': 5e-1, 'weight_decay': 1e-3, 'momentum': 0.1}, 10),
     (ScalableShampoo, {'lr': 1e-1, 'weight_decay': 1e-3, 'graft_type': 0}, 10),
@@ -188,12 +188,12 @@
     (AdaPNM, {'lr': 3e-1, 'weight_decay': 1e-3, 'amsgrad': False}, 50),
     (Nero, {'lr': 5e-1}, 50),
     (Nero, {'lr': 5e-1, 'constraints': False}, 50),
-    (Adan, {'lr': 5e-1}, 100),
-    (Adan, {'lr': 5e-1, 'max_grad_norm': 1.0}, 100),
-    (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 150),
-    (Adan, {'lr': 1e-1, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 100),
-    (DAdaptAdaGrad, {'lr': 1.0, 'weight_decay': 1e-2}, 150),
-    (DAdaptAdaGrad, {'lr': 1.0, 'weight_decay': 1e-2, 'momentum': 0.1}, 150),
+    (Adan, {'lr': 5e-1}, 75),
+    (Adan, {'lr': 5e-1, 'max_grad_norm': 1.0}, 75),
+    (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True}, 100),
+    (Adan, {'lr': 5e-1, 'weight_decay': 1e-3, 'use_gc': True, 'weight_decouple': True}, 75),
+    (DAdaptAdaGrad, {'lr': 1.0, 'weight_decay': 1e-3}, 150),
+    (DAdaptAdaGrad, {'lr': 1.0, 'weight_decay': 1e-3, 'momentum': 0.1}, 150),
     (DAdaptAdam, {'lr': 1.0, 'weight_decay': 1e-2}, 50),
     (DAdaptAdam, {'lr': 1.0, 'weight_decay': 1e-2, 'weight_decouple': True}, 50),
     (DAdaptSGD, {'lr': 1.0, 'weight_decay': 1e-2}, 30),
```
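Each entry in these constants is an `(optimizer_class, config_dict, num_iterations)` tuple, and the commit mostly trims the iteration budgets. A hedged sketch of how such a tuple is presumably consumed (the actual harness lives elsewhere under `tests/` and is not part of this diff; `run_case` and the toy objective are made up for illustration):

```python
import torch

def run_case(optimizer_cls, config: dict, num_iterations: int) -> float:
    # Drive a single parameter toward zero with the given optimizer and
    # budget, returning the final loss a test could assert is small enough.
    w = torch.nn.Parameter(torch.tensor([5.0]))
    optimizer = optimizer_cls([w], **config)
    for _ in range(num_iterations):
        optimizer.zero_grad()
        loss = (w ** 2).sum()  # simple convex objective
        loss.backward()
        optimizer.step()
    return float(loss)

# Example with a built-in optimizer standing in for the library's classes:
print(run_case(torch.optim.SGD, {'lr': 1e-1, 'weight_decay': 1e-3}, 75))
```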
