@@ -130,6 +130,18 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 class ScalableShampoo(Optimizer, BaseOptimizer):
     r"""Scalable Preconditioned Stochastic Tensor Optimization.
 
+    This version of the Scalable Shampoo optimizer targets a single-GPU environment, not a distributed environment
+    or XLA devices. The original design computes the pre-conditioners asynchronously on distributed CPUs, whereas
+    this implementation computes them synchronously on the GPU, where they account for roughly 99% of the time.
+
+    Still, it is much faster than the previous Shampoo optimizer, because it uses a coupled Newton iteration to
+    compute the G^{-1/p} matrices, while the previous implementation relies on SVD, which is very slow.
+
+    This implementation also offers
+        1. many plug-ins (e.g. gradient grafting, choice of pre-conditioning type),
+        2. features not yet implemented in the official PyTorch code,
+        3. readable, organized, clean code.
+
     Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -151,6 +163,7 @@ class ScalableShampoo(Optimizer, BaseOptimizer):
     :param block_size: int. Block size for large layers (if > 0).
         Block size = 1 ==> Adagrad (Don't do this, extremely inefficient!)
         Block size should be as large as feasible under memory/time constraints.
+    :param skip_preconditioning_rank_lt: int. Skips preconditioning for parameters with rank less than this value.
     :param no_preconditioning_for_layers_with_dim_gt: int. avoid preconditioning large layers to reduce overall memory.
     :param shape_interpretation: bool. Automatic shape interpretation (e.g. [4, 3, 1024, 512] would
         result in 12 x [1024, 512] L and R statistics. Disabled by default which results in Shampoo constructing
@@ -176,10 +189,11 @@ def __init__(
         decoupled_weight_decay: bool = False,
         decoupled_learning_rate: bool = True,
         inverse_exponent_override: int = 0,
-        start_preconditioning_step: int = 5,
-        preconditioning_compute_steps: int = 1,
+        start_preconditioning_step: int = 25,
+        preconditioning_compute_steps: int = 1000,
         statistics_compute_steps: int = 1,
         block_size: int = 256,
+        skip_preconditioning_rank_lt: int = 1,
         no_preconditioning_for_layers_with_dim_gt: int = 8192,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
@@ -200,6 +214,7 @@ def __init__(
         self.preconditioning_compute_steps = preconditioning_compute_steps
         self.statistics_compute_steps = statistics_compute_steps
         self.block_size = block_size
+        self.skip_preconditioning_rank_lt = skip_preconditioning_rank_lt
         self.no_preconditioning_for_layers_with_dim_gt = no_preconditioning_for_layers_with_dim_gt
         self.shape_interpretation = shape_interpretation
         self.graft_type = graft_type
@@ -230,20 +245,21 @@ def __str__(self) -> str:
     @torch.no_grad()
     def reset(self):
         for group in self.param_groups:
+            group['step'] = 0
             for p in group['params']:
                 state = self.state[p]
 
-                state['step'] = 0
                 state['momentum'] = torch.zeros_like(p)
                 state['pre_conditioner'] = PreConditioner(
                     p,
                     group['betas'][1],  # beta2
                     self.inverse_exponent_override,
                     self.block_size,
+                    self.skip_preconditioning_rank_lt,
                     self.no_preconditioning_for_layers_with_dim_gt,
                     self.shape_interpretation,
-                    self.matrix_eps,
                     self.pre_conditioner_type,
+                    self.matrix_eps,
                     self.use_svd,
                 )
                 state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)
@@ -259,6 +275,14 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 loss = closure()
 
         for group in self.param_groups:
+            if 'step' in group:
+                group['step'] += 1
+            else:
+                group['step'] = 1
+
+            is_precondition_step: bool = self.is_precondition_step(group['step'])
+            pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
+
             beta1, beta2 = group['betas']
             for p in group['params']:
                 if p.grad is None:
@@ -270,41 +294,37 @@ def step(self, closure: CLOSURE = None) -> LOSS:
 
                 state = self.state[p]
                 if len(state) == 0:
-                    state['step'] = 0
                     state['momentum'] = torch.zeros_like(p)
                     state['pre_conditioner'] = PreConditioner(
                         p,
                         beta2,
                         self.inverse_exponent_override,
                         self.block_size,
+                        self.skip_preconditioning_rank_lt,
                         self.no_preconditioning_for_layers_with_dim_gt,
                         self.shape_interpretation,
-                        self.matrix_eps,
                         self.pre_conditioner_type,
+                        self.matrix_eps,
                         self.use_svd,
                     )
                     state['graft'] = build_graft(p, self.graft_type, self.diagonal_eps)
 
-                state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
                 graft.add_statistics(grad, beta2)
-                if state['step'] % self.statistics_compute_steps == 0:
+                if group['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
-                if state['step'] % self.preconditioning_compute_steps == 0:
+                if group['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
-                is_precondition_step: bool = self.is_precondition_step(state['step'])
-                pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
-
                 graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = (
                     pre_conditioner.preconditioned_grad(grad) if is_precondition_step else grad
                 )
 
                 if self.graft_type != LayerWiseGrafting.NONE:
-                    graft_norm = torch.norm(graft_grad)
-                    shampoo_norm = torch.norm(shampoo_grad)
+                    graft_norm = torch.linalg.norm(graft_grad)
+                    shampoo_norm = torch.linalg.norm(shampoo_grad)
 
                     shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
@@ -319,15 +339,12 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['momentum'].mul_(beta1).add_(shampoo_grad)
                 graft_momentum = graft.update_momentum(grad, beta1)
 
-                if is_precondition_step:
-                    momentum_update = state['momentum']
-                    wd_update = shampoo_grad
-                else:
-                    momentum_update = graft_momentum
-                    wd_update = graft_grad
+                momentum_update = state['momentum'] if is_precondition_step else graft_momentum
 
                 if self.nesterov:
                     w: float = (1.0 - beta1) if self.moving_average_for_momentum else 1.0
+
+                    wd_update = shampoo_grad if is_precondition_step else graft_grad
                     wd_update.mul_(w)
 
                     momentum_update.mul_(beta1).add_(wd_update)
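
The new docstring above says the pre-conditioners G^{-1/p} are obtained with a coupled Newton iteration rather than SVD. As a rough illustration of that technique, here is a minimal, self-contained sketch; it is not the commit's actual routine, the function name `matrix_inverse_root` and the iteration budget are ours, and the step-size heuristic follows the referenced scalable_shampoo code.

```python
import torch


def matrix_inverse_root(mat_g: torch.Tensor, p: int, num_iters: int = 100,
                        ridge_eps: float = 1e-6, tol: float = 1e-6) -> torch.Tensor:
    """Approximate mat_g^{-1/p} for a symmetric PSD matrix via a coupled Newton iteration (sketch only)."""
    identity = torch.eye(mat_g.shape[0], dtype=mat_g.dtype, device=mat_g.device)
    mat_g = mat_g + ridge_eps * identity  # ridge term keeps the inverse root well-defined

    alpha = -1.0 / p
    # Step-size heuristic from the referenced implementation: z = (1 + p) / (2 * ||G||).
    z = (1.0 + p) / (2.0 * torch.linalg.norm(mat_g))

    # Invariant: mat_m == mat_root^p @ mat_g, so mat_m -> I implies mat_root -> mat_g^{-1/p}.
    mat_root = identity * z ** (1.0 / p)
    mat_m = mat_g * z

    for _ in range(num_iters):
        if torch.max(torch.abs(mat_m - identity)) < tol:
            break
        tmp = (1.0 - alpha) * identity + alpha * mat_m
        mat_root = mat_root @ tmp
        mat_m = torch.linalg.matrix_power(tmp, p) @ mat_m

    return mat_root
```

As a quick sanity check, for a symmetric PSD statistic `g`, `torch.linalg.matrix_power(matrix_inverse_root(g, 4), 4) @ g` should come out close to the identity; only matrix products are needed per iteration, which is why it is much cheaper than a full SVD.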
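To put the changed defaults (`start_preconditioning_step=25`, `preconditioning_compute_steps=1000`) and the new `skip_preconditioning_rank_lt` argument in context, a hedged usage sketch follows; the import path, learning rate, and toy model are assumptions for illustration, not part of this commit.

```python
import torch
from pytorch_optimizer import ScalableShampoo  # import path assumed; adjust to your installation

model = torch.nn.Linear(512, 256)

optimizer = ScalableShampoo(
    model.parameters(),
    lr=1e-3,                             # illustrative value, not a recommendation from this commit
    start_preconditioning_step=25,       # new default: warm up before preconditioned updates start
    preconditioning_compute_steps=1000,  # new default: recompute G^{-1/p} every 1000 steps
    skip_preconditioning_rank_lt=1,      # new option: skip preconditioning for parameters of rank < 1
)

for _ in range(10):
    loss = model(torch.randn(8, 512)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```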