
Commit 44c423a

Merge pull request #99 from kozistr/update/shampoo-optimizer
[Feature] Implement more Shampoo features
2 parents 40f34df + 5da4074 commit 44c423a

8 files changed, +219 -68 lines changed


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "pytorch_optimizer"
-version = "2.3.0"
+version = "2.3.1"
 description = "optimizer & lr scheduler implementations in PyTorch with clean-code, strict types. Also, including useful optimization ideas."
 license = "Apache-2.0"
 authors = ["kozistr <[email protected]>"]

pytorch_optimizer/optimizer/shampoo.py

Lines changed: 62 additions & 24 deletions
@@ -3,8 +3,17 @@
 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft, Graft, LayerWiseGrafting, PreConditioner, SGDGraft
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.optimizer.shampoo_utils import (
+    AdagradGraft,
+    Graft,
+    LayerWiseGrafting,
+    PreConditioner,
+    PreConditionerType,
+    RMSPropGraft,
+    SGDGraft,
+    SQRTNGraft,
+)
 
 
 class Shampoo(Optimizer, BaseOptimizer):
@@ -14,9 +23,11 @@ class Shampoo(Optimizer, BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param momentum: float. momentum.
-    :param beta2: float. beta2.
+    :param betas: BETAS. beta1, beta2.
+    :param moving_average_for_momentum: bool. perform moving_average for momentum (beta1).
     :param weight_decay: float. weight decay (L2 penalty).
+    :param decoupled_weight_decay: bool. use decoupled weight_decay.
+    :param decoupled_learning_rate: bool. use decoupled lr, otherwise couple it w/ preconditioned gradient.
     :param inverse_exponent_override: int. fixed exponent for pre-conditioner, if > 0.
     :param start_preconditioning_step: int.
     :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
@@ -28,7 +39,8 @@ class Shampoo(Optimizer, BaseOptimizer):
     :param shape_interpretation: bool. Automatic shape interpretation (for eg: [4, 3, 1024, 512] would
         result in 12 x [1024, 512] L and R statistics. Disabled by default which results in Shampoo constructing
         statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
-    :param graft_type: bool. Type of grafting (SGD or AdaGrad).
+    :param graft_type: int. type of grafting (SGD or AdaGrad or RMSProp or SQRT_N or None).
+    :param pre_conditioner_type: int. type of pre-conditioner.
     :param nesterov: bool. Nesterov momentum.
     :param diagonal_eps: float. term added to the denominator to improve numerical stability.
     :param matrix_eps: float. term added to the denominator to improve numerical stability.
@@ -38,31 +50,37 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
-        momentum: float = 0.0,
-        beta2: float = 1.0,
+        betas: BETAS = (0.9, 0.999),
+        moving_average_for_momentum: bool = False,
         weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False,
+        decoupled_learning_rate: bool = True,
         inverse_exponent_override: int = 0,
-        start_preconditioning_step: int = 1,
+        start_preconditioning_step: int = 5,
         preconditioning_compute_steps: int = 1,
         statistics_compute_steps: int = 1,
         block_size: int = 128,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
+        pre_conditioner_type: int = PreConditionerType.ALL,
         nesterov: bool = True,
-        diagonal_eps: float = 1e-6,
-        matrix_eps: float = 1e-12,
+        diagonal_eps: float = 1e-10,
+        matrix_eps: float = 1e-6,
     ):
         self.lr = lr
-        self.momentum = momentum
-        self.beta2 = beta2
+        self.betas = betas
+        self.moving_average_for_momentum = moving_average_for_momentum
         self.weight_decay = weight_decay
+        self.decoupled_weight_decay = decoupled_weight_decay
+        self.decoupled_learning_rate = decoupled_learning_rate
         self.inverse_exponent_override = inverse_exponent_override
         self.start_preconditioning_step = start_preconditioning_step
         self.preconditioning_compute_steps = preconditioning_compute_steps
         self.statistics_compute_steps = statistics_compute_steps
         self.block_size = block_size
         self.shape_interpretation = shape_interpretation
         self.graft_type = graft_type
+        self.pre_conditioner_type = pre_conditioner_type
         self.nesterov = nesterov
         self.diagonal_eps = diagonal_eps
         self.matrix_eps = matrix_eps
@@ -71,14 +89,14 @@ def __init__(
 
         defaults: DEFAULTS = {
             'lr': lr,
-            'momentum': momentum,
+            'betas': betas,
             'weight_decay': weight_decay,
         }
         super().__init__(params, defaults)
 
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
-        self.validate_momentum(self.momentum)
+        self.validate_betas(self.betas)
         self.validate_weight_decay(self.weight_decay)
         self.validate_update_frequency(self.start_preconditioning_step)
         self.validate_update_frequency(self.statistics_compute_steps)
@@ -100,16 +118,21 @@ def reset(self):
                 state['momentum'] = torch.zeros_like(p)
                 state['pre_conditioner'] = PreConditioner(
                     p,
-                    self.beta2,
+                    group['betas'][1],  # beta2
                     self.inverse_exponent_override,
                     self.block_size,
                     self.shape_interpretation,
                     self.matrix_eps,
+                    self.pre_conditioner_type,
                 )
                 if self.graft_type == LayerWiseGrafting.ADAGRAD:
                     state['graft'] = AdagradGraft(p, self.diagonal_eps)
+                elif self.graft_type == LayerWiseGrafting.RMSPROP:
+                    state['graft'] = RMSPropGraft(p, self.diagonal_eps)
                 elif self.graft_type == LayerWiseGrafting.SGD:
                     state['graft'] = SGDGraft(p)
+                elif self.graft_type == LayerWiseGrafting.SQRTN:
+                    state['graft'] = SQRTNGraft(p)
                 else:
                     state['graft'] = Graft(p)
 
@@ -121,6 +144,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
+            beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -135,48 +159,59 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['momentum'] = torch.zeros_like(p)
                     state['pre_conditioner'] = PreConditioner(
                         p,
-                        self.beta2,
+                        beta2,
                         self.inverse_exponent_override,
                         self.block_size,
                         self.shape_interpretation,
                         self.matrix_eps,
+                        self.pre_conditioner_type,
                     )
                     if self.graft_type == LayerWiseGrafting.ADAGRAD:
                         state['graft'] = AdagradGraft(p, self.diagonal_eps)
+                    elif self.graft_type == LayerWiseGrafting.RMSPROP:
+                        state['graft'] = RMSPropGraft(p, self.diagonal_eps)
                     elif self.graft_type == LayerWiseGrafting.SGD:
                         state['graft'] = SGDGraft(p)
+                    elif self.graft_type == LayerWiseGrafting.SQRTN:
+                        state['graft'] = SQRTNGraft(p)
                     else:
                         state['graft'] = Graft(p)
 
                 state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
                 # gather statistics, compute pre-conditioners
-                graft.add_statistics(grad)
+                graft.add_statistics(grad, beta2)
                 if state['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
                 if state['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
                 # pre-condition gradients
-                graft_grad: torch.Tensor = graft.precondition_gradient(grad)
+                pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
+                graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = grad
                 if state['step'] >= self.start_preconditioning_step:
                     shampoo_grad = pre_conditioner.preconditioned_grad(grad)
 
                 # grafting
                 graft_norm = torch.norm(graft_grad)
                 shampoo_norm = torch.norm(shampoo_grad)
-                shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
+                if self.graft_type != LayerWiseGrafting.NONE:
+                    shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
                 # apply weight decay (adam style)
                 if group['weight_decay'] > 0.0:
-                    shampoo_grad.add_(p, alpha=group['weight_decay'])
-                    graft_grad.add_(p, alpha=group['weight_decay'])
+                    if not self.decoupled_weight_decay:
+                        shampoo_grad.add_(p, alpha=group['weight_decay'])
+                        graft_grad.add_(p, alpha=group['weight_decay'])
+                    else:
+                        shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
 
                 # Momentum and Nesterov momentum, if needed
-                state['momentum'].mul_(group['momentum']).add_(shampoo_grad)
-                graft_momentum = graft.update_momentum(grad, group['momentum'])
+                state['momentum'].mul_(beta1).add_(shampoo_grad)
+                graft_momentum = graft.update_momentum(grad, beta1)
 
                 if state['step'] >= self.start_preconditioning_step:
                     momentum_update = state['momentum']
@@ -186,7 +221,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     wd_update = graft_grad
 
                 if self.nesterov:
-                    momentum_update.mul_(group['momentum']).add_(wd_update)
+                    w: float = (1.0 - beta1) if self.moving_average_for_momentum else 1.0
+                    wd_update.mul_(w)
+
+                    momentum_update.mul_(beta1).add_(wd_update)
 
                 p.add_(momentum_update, alpha=-group['lr'])
 
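
For context, a minimal usage sketch of the optimizer after this change. It is based only on the constructor signature and the enums visible in the diff above; the toy nn.Linear model, the chosen hyperparameter values, and the training-loop scaffolding are illustrative assumptions, not part of this commit.

# Illustrative usage sketch (toy model and hyperparameter values are assumptions).
import torch
from torch import nn

from pytorch_optimizer.optimizer.shampoo import Shampoo
from pytorch_optimizer.optimizer.shampoo_utils import LayerWiseGrafting, PreConditionerType

model = nn.Linear(512, 10)
optimizer = Shampoo(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),                           # replaces the former momentum / beta2 arguments
    moving_average_for_momentum=True,             # weights the nesterov update by (1 - beta1)
    weight_decay=1e-2,
    decoupled_weight_decay=True,                  # (1 - lr * wd) scaling instead of the adam-style add
    graft_type=LayerWiseGrafting.RMSPROP,         # newly supported alongside SGD / AdaGrad / SQRT_N / None
    pre_conditioner_type=PreConditionerType.ALL,
    start_preconditioning_step=5,                 # new default introduced by this commit
)

x, y = torch.randn(8, 512), torch.randint(0, 10, (8,))

optimizer.zero_grad()
nn.functional.cross_entropy(model(x), y).backward()
optimizer.step()

With decoupled_learning_rate left at its default (True), the learning rate is applied only in the final p.add_(momentum_update, alpha=-group['lr']) update; setting it to False folds the learning rate into the grafted gradient via pre_conditioner_multiplier, as shown in the diff.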
