 
 from pytorch_optimizer.base.exception import NoSparseGradientError
 from pytorch_optimizer.base.optimizer import BaseOptimizer
-from pytorch_optimizer.base.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS
-from pytorch_optimizer.optimizer.shampoo_utils import AdagradGraft, Graft, LayerWiseGrafting, PreConditioner, SGDGraft
+from pytorch_optimizer.base.types import BETAS, CLOSURE, DEFAULTS, LOSS, PARAMETERS
+from pytorch_optimizer.optimizer.shampoo_utils import (
+    AdagradGraft,
+    Graft,
+    LayerWiseGrafting,
+    PreConditioner,
+    PreConditionerType,
+    RMSPropGraft,
+    SGDGraft,
+    SQRTNGraft,
+)
 
 
 class Shampoo(Optimizer, BaseOptimizer):
@@ -14,9 +23,11 @@ class Shampoo(Optimizer, BaseOptimizer):
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
     :param lr: float. learning rate.
-    :param momentum: float. momentum.
-    :param beta2: float. beta2.
+    :param betas: BETAS. beta1, beta2.
+    :param moving_average_for_momentum: bool. perform moving_average for momentum (beta1).
     :param weight_decay: float. weight decay (L2 penalty).
+    :param decoupled_weight_decay: bool. use decoupled weight_decay.
+    :param decoupled_learning_rate: bool. use decoupled lr, otherwise couple it w/ preconditioned gradient.
     :param inverse_exponent_override: int. fixed exponent for pre-conditioner, if > 0.
     :param start_preconditioning_step: int.
     :param preconditioning_compute_steps: int. performance tuning params for controlling memory and compute
@@ -28,7 +39,8 @@ class Shampoo(Optimizer, BaseOptimizer):
     :param shape_interpretation: bool. Automatic shape interpretation (for eg: [4, 3, 1024, 512] would
         result in 12 x [1024, 512] L and R statistics. Disabled by default which results in Shampoo constructing
         statistics [4, 4], [3, 3], [1024, 1024], [512, 512].
-    :param graft_type: bool. Type of grafting (SGD or AdaGrad).
+    :param graft_type: int. type of grafting (SGD or AdaGrad or RMSProp or SQRT_N or None).
+    :param pre_conditioner_type: int. type of pre-conditioner.
     :param nesterov: bool. Nesterov momentum.
     :param diagonal_eps: float. term added to the denominator to improve numerical stability.
     :param matrix_eps: float. term added to the denominator to improve numerical stability.
@@ -38,31 +50,37 @@ def __init__(
         self,
         params: PARAMETERS,
         lr: float = 1e-3,
-        momentum: float = 0.0,
-        beta2: float = 1.0,
+        betas: BETAS = (0.9, 0.999),
+        moving_average_for_momentum: bool = False,
         weight_decay: float = 0.0,
+        decoupled_weight_decay: bool = False,
+        decoupled_learning_rate: bool = True,
         inverse_exponent_override: int = 0,
-        start_preconditioning_step: int = 1,
+        start_preconditioning_step: int = 5,
         preconditioning_compute_steps: int = 1,
         statistics_compute_steps: int = 1,
         block_size: int = 128,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
+        pre_conditioner_type: int = PreConditionerType.ALL,
         nesterov: bool = True,
-        diagonal_eps: float = 1e-6,
-        matrix_eps: float = 1e-12,
+        diagonal_eps: float = 1e-10,
+        matrix_eps: float = 1e-6,
     ):
         self.lr = lr
-        self.momentum = momentum
-        self.beta2 = beta2
+        self.betas = betas
+        self.moving_average_for_momentum = moving_average_for_momentum
         self.weight_decay = weight_decay
+        self.decoupled_weight_decay = decoupled_weight_decay
+        self.decoupled_learning_rate = decoupled_learning_rate
         self.inverse_exponent_override = inverse_exponent_override
         self.start_preconditioning_step = start_preconditioning_step
         self.preconditioning_compute_steps = preconditioning_compute_steps
         self.statistics_compute_steps = statistics_compute_steps
         self.block_size = block_size
         self.shape_interpretation = shape_interpretation
         self.graft_type = graft_type
+        self.pre_conditioner_type = pre_conditioner_type
         self.nesterov = nesterov
         self.diagonal_eps = diagonal_eps
         self.matrix_eps = matrix_eps
@@ -71,14 +89,14 @@ def __init__(
 
         defaults: DEFAULTS = {
             'lr': lr,
-            'momentum': momentum,
+            'betas': betas,
             'weight_decay': weight_decay,
         }
         super().__init__(params, defaults)
 
     def validate_parameters(self):
         self.validate_learning_rate(self.lr)
-        self.validate_momentum(self.momentum)
+        self.validate_betas(self.betas)
         self.validate_weight_decay(self.weight_decay)
         self.validate_update_frequency(self.start_preconditioning_step)
         self.validate_update_frequency(self.statistics_compute_steps)
@@ -100,16 +118,21 @@ def reset(self):
                 state['momentum'] = torch.zeros_like(p)
                 state['pre_conditioner'] = PreConditioner(
                     p,
-                    self.beta2,
+                    group['betas'][1],  # beta2
                     self.inverse_exponent_override,
                     self.block_size,
                     self.shape_interpretation,
                     self.matrix_eps,
+                    self.pre_conditioner_type,
                 )
                 if self.graft_type == LayerWiseGrafting.ADAGRAD:
                     state['graft'] = AdagradGraft(p, self.diagonal_eps)
+                elif self.graft_type == LayerWiseGrafting.RMSPROP:
+                    state['graft'] = RMSPropGraft(p, self.diagonal_eps)
                 elif self.graft_type == LayerWiseGrafting.SGD:
                     state['graft'] = SGDGraft(p)
+                elif self.graft_type == LayerWiseGrafting.SQRTN:
+                    state['graft'] = SQRTNGraft(p)
                 else:
                     state['graft'] = Graft(p)
 
@@ -121,6 +144,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
+            beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
                     continue
@@ -135,48 +159,59 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     state['momentum'] = torch.zeros_like(p)
                     state['pre_conditioner'] = PreConditioner(
                         p,
-                        self.beta2,
+                        beta2,
                         self.inverse_exponent_override,
                         self.block_size,
                         self.shape_interpretation,
                         self.matrix_eps,
+                        self.pre_conditioner_type,
                     )
                     if self.graft_type == LayerWiseGrafting.ADAGRAD:
                         state['graft'] = AdagradGraft(p, self.diagonal_eps)
+                    elif self.graft_type == LayerWiseGrafting.RMSPROP:
+                        state['graft'] = RMSPropGraft(p, self.diagonal_eps)
                     elif self.graft_type == LayerWiseGrafting.SGD:
                         state['graft'] = SGDGraft(p)
+                    elif self.graft_type == LayerWiseGrafting.SQRTN:
+                        state['graft'] = SQRTNGraft(p)
                     else:
                         state['graft'] = Graft(p)
 
                 state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
                 # gather statistics, compute pre-conditioners
-                graft.add_statistics(grad)
+                graft.add_statistics(grad, beta2)
                 if state['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
                 if state['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
                 # pre-condition gradients
-                graft_grad: torch.Tensor = graft.precondition_gradient(grad)
+                pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
+                graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = grad
                 if state['step'] >= self.start_preconditioning_step:
                     shampoo_grad = pre_conditioner.preconditioned_grad(grad)
 
                 # grafting
                 graft_norm = torch.norm(graft_grad)
                 shampoo_norm = torch.norm(shampoo_grad)
-                shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
+                if self.graft_type != LayerWiseGrafting.NONE:
+                    shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
                 # apply weight decay (adam style)
                 if group['weight_decay'] > 0.0:
-                    shampoo_grad.add_(p, alpha=group['weight_decay'])
-                    graft_grad.add_(p, alpha=group['weight_decay'])
+                    if not self.decoupled_weight_decay:
+                        shampoo_grad.add_(p, alpha=group['weight_decay'])
+                        graft_grad.add_(p, alpha=group['weight_decay'])
+                    else:
+                        shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
+                        graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
 
                 # Momentum and Nesterov momentum, if needed
-                state['momentum'].mul_(group['momentum']).add_(shampoo_grad)
-                graft_momentum = graft.update_momentum(grad, group['momentum'])
+                state['momentum'].mul_(beta1).add_(shampoo_grad)
+                graft_momentum = graft.update_momentum(grad, beta1)
 
                 if state['step'] >= self.start_preconditioning_step:
                     momentum_update = state['momentum']
@@ -186,7 +221,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                     wd_update = graft_grad
 
                 if self.nesterov:
-                    momentum_update.mul_(group['momentum']).add_(wd_update)
+                    w: float = (1.0 - beta1) if self.moving_average_for_momentum else 1.0
+                    wd_update.mul_(w)
+
+                    momentum_update.mul_(beta1).add_(wd_update)
 
                 p.add_(momentum_update, alpha=-group['lr'])
 
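For reference, a minimal usage sketch of the new interface (not part of the commit): it exercises the betas tuple, decoupled weight decay, and the newly supported RMSProp grafting. The toy model and data are illustrative, and it assumes Shampoo is re-exported at the package top level (otherwise import it from the module this diff modifies); LayerWiseGrafting and PreConditionerType are the enums imported in this diff.

import torch
from torch import nn

from pytorch_optimizer import Shampoo  # assumed top-level export of the class patched here
from pytorch_optimizer.optimizer.shampoo_utils import LayerWiseGrafting, PreConditionerType

model = nn.Linear(16, 4)

optimizer = Shampoo(
    model.parameters(),
    lr=1e-3,
    # betas replaces the former momentum/beta2 pair: betas[0] drives the momentum
    # buffers, betas[1] drives the statistics / grafting moving averages.
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    # per this diff, decoupled mode scales the update by (1 - lr * weight_decay)
    # instead of adding weight_decay * p to it.
    decoupled_weight_decay=True,
    # RMSProp grafting is one of the graft types added in this commit.
    graft_type=LayerWiseGrafting.RMSPROP,
    pre_conditioner_type=PreConditionerType.ALL,
    start_preconditioning_step=5,
)

x, y = torch.randn(8, 16), torch.randn(8, 4)
loss = nn.functional.mse_loss(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()

With the default decoupled_learning_rate=True, the learning rate is applied only once, in the final p.add_(momentum_update, alpha=-group['lr']) call; setting it to False additionally scales the gradient fed into the grafting statistics by the learning rate, coupling it with the preconditioned update.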