     RMSPropGraft,
     SGDGraft,
     SQRTNGraft,
+    compute_power_svd,
 )
 
 
 class Shampoo(Optimizer, BaseOptimizer):
     r"""Preconditioned Stochastic Tensor Optimization.
 
+    :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
+    :param lr: float. learning rate.
+    :param momentum: float. momentum.
+    :param weight_decay: float. weight decay (L2 penalty).
+    :param preconditioning_compute_steps: int. performance tuning parameter for controlling memory and compute
+        requirements; how often to compute the pre-conditioner.
+    :param matrix_eps: float. term added to the denominator to improve numerical stability.
+    """
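+
+    # Usage sketch (illustrative, not part of this PR; the model and loss below are placeholders):
+    #
+    #     model = torch.nn.Linear(10, 2)
+    #     optimizer = Shampoo(model.parameters(), lr=1e-3, momentum=0.9)
+    #     loss = model(torch.randn(8, 10)).sum()
+    #     loss.backward()
+    #     optimizer.step()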
+
+    def __init__(
+        self,
+        params: PARAMETERS,
+        lr: float = 1e-3,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        preconditioning_compute_steps: int = 1,
+        matrix_eps: float = 1e-6,
+    ):
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.preconditioning_compute_steps = preconditioning_compute_steps
+        self.matrix_eps = matrix_eps
+
+        self.validate_parameters()
+
+        defaults: DEFAULTS = {
+            'lr': lr,
+            'momentum': momentum,
+            'weight_decay': weight_decay,
+        }
+        super().__init__(params, defaults)
+
+    def validate_parameters(self):
+        self.validate_learning_rate(self.lr)
+        self.validate_momentum(self.momentum)
+        self.validate_weight_decay(self.weight_decay)
+        self.validate_update_frequency(self.preconditioning_compute_steps)
+        self.validate_epsilon(self.matrix_eps)
+
+    @property
+    def __str__(self) -> str:
+        return 'Shampoo'
+
+    @torch.no_grad()
+    def reset(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+
+                state['step'] = 0
+
+    @torch.no_grad()
+    def step(self, closure: CLOSURE = None) -> LOSS:
+        loss: LOSS = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            momentum = group['momentum']
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                grad = p.grad
+                if grad.is_sparse:
+                    raise NoSparseGradientError(self.__str__)
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['step'] = 0
+
+                    if momentum > 0.0:
+                        state['momentum_buffer'] = grad.clone()
+
+                    # one pre-conditioner (and its inverse root) per tensor dimension
+                    for dim_id, dim in enumerate(grad.size()):
+                        state[f'pre_cond_{dim_id}'] = self.matrix_eps * torch.eye(dim, out=grad.new(dim, dim))
+                        state[f'inv_pre_cond_{dim_id}'] = grad.new(dim, dim).zero_()
+
+                state['step'] += 1
+
+                if momentum > 0.0:
+                    grad.mul_(1.0 - momentum).add_(state['momentum_buffer'], alpha=momentum)
+
+                if group['weight_decay'] > 0.0:
+                    grad.add_(p, alpha=group['weight_decay'])
+
+                order: int = grad.ndimension()
+                original_size: torch.Size = grad.size()
+                # precondition along each dimension: accumulate grad @ grad^T statistics and apply
+                # the inverse `order`-th root (see the illustrative SVD sketch after this class)
+                for dim_id, dim in enumerate(grad.size()):
+                    pre_cond = state[f'pre_cond_{dim_id}']
+                    inv_pre_cond = state[f'inv_pre_cond_{dim_id}']
+
+                    grad = grad.transpose_(0, dim_id).contiguous()
+                    transposed_size = grad.size()
+
+                    grad = grad.view(dim, -1)
+
+                    grad_t = grad.t()
+                    pre_cond.add_(grad @ grad_t)
+                    if state['step'] % self.preconditioning_compute_steps == 0:
+                        inv_pre_cond.copy_(compute_power_svd(pre_cond, -1.0 / order))
+
+                    if dim_id == order - 1:
+                        grad = grad_t @ inv_pre_cond
+                        grad = grad.view(original_size)
+                    else:
+                        grad = inv_pre_cond @ grad
+                        grad = grad.view(transposed_size)
+
+                state['momentum_buffer'] = grad
+
+                p.add_(grad, alpha=-group['lr'])
+
+        return loss
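+
+
+# Illustrative sketch, not part of this PR: `compute_power_svd(pre_cond, -1.0 / order)` above is
+# assumed to return the matrix power M^{-1/order} of the symmetric PSD statistics matrix. A minimal
+# SVD-based realization of such a power could look like the helper below; the helper name and the
+# singular-value clamp are assumptions, not the library's `shampoo_utils.compute_power_svd`.
+def _matrix_power_via_svd_sketch(matrix: torch.Tensor, power: float) -> torch.Tensor:
+    # for a symmetric PSD matrix, M^{power} = U diag(s^{power}) U^T
+    u, s, _ = torch.linalg.svd(matrix)
+    return u @ torch.diag(s.clamp_min(1e-16).pow(power)) @ u.t()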
+
+
+class ScalableShampoo(Optimizer, BaseOptimizer):
+    r"""Scalable Preconditioned Stochastic Tensor Optimization.
+
     Reference : https://github.com/google-research/google-research/blob/master/scalable_shampoo/pytorch/shampoo.py.
 
     :param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups.
@@ -45,6 +167,10 @@ class Shampoo(Optimizer, BaseOptimizer):
     :param nesterov: bool. Nesterov momentum.
     :param diagonal_eps: float. term added to the denominator to improve numerical stability.
     :param matrix_eps: float. term added to the denominator to improve numerical stability.
+    :param use_svd: bool. use SVD instead of the Schur-Newton method to calculate M^{-1/p}.
+        Theoretically, the Schur-Newton method is faster than SVD at computing M^{-1/p}; in practice,
+        however, due to the inefficiency of its loop implementation, SVD is much faster.
+        see https://github.com/kozistr/pytorch_optimizer/pull/103 (a usage sketch follows this docstring).
     """
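+
+    # Usage sketch (illustrative, not part of this PR; `model` is a placeholder): `use_svd` defaults
+    # to False, which keeps the previous Schur-Newton behaviour; pass `use_svd=True` to compute the
+    # inverse roots via SVD instead:
+    #
+    #     optimizer = ScalableShampoo(model.parameters(), lr=1e-3, use_svd=True)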
 
     def __init__(
@@ -60,14 +186,15 @@ def __init__(
         start_preconditioning_step: int = 5,
         preconditioning_compute_steps: int = 1,
         statistics_compute_steps: int = 1,
-        block_size: int = 128,
+        block_size: int = 256,
         no_preconditioning_for_layers_with_dim_gt: int = 8192,
         shape_interpretation: bool = True,
         graft_type: int = LayerWiseGrafting.SGD,
         pre_conditioner_type: int = PreConditionerType.ALL,
         nesterov: bool = True,
         diagonal_eps: float = 1e-10,
         matrix_eps: float = 1e-6,
+        use_svd: bool = False,
     ):
         self.lr = lr
         self.betas = betas
@@ -87,6 +214,7 @@ def __init__(
         self.nesterov = nesterov
         self.diagonal_eps = diagonal_eps
         self.matrix_eps = matrix_eps
+        self.use_svd = use_svd
 
         self.validate_parameters()
 
@@ -109,7 +237,7 @@ def validate_parameters(self):
 
     @property
     def __str__(self) -> str:
-        return 'Shampoo'
+        return 'ScalableShampoo'
 
     @torch.no_grad()
     def reset(self):
@@ -128,6 +256,7 @@ def reset(self):
                     self.shape_interpretation,
                     self.matrix_eps,
                     self.pre_conditioner_type,
+                    self.use_svd,
                 )
                 if self.graft_type == LayerWiseGrafting.ADAGRAD:
                     state['graft'] = AdaGradGraft(p, self.diagonal_eps)
@@ -140,6 +269,9 @@ def reset(self):
                 else:
                     state['graft'] = Graft(p)
 
+    def is_precondition_step(self, step: int) -> bool:
+        return step >= self.start_preconditioning_step
+
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
         loss: LOSS = None
@@ -170,6 +302,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         self.shape_interpretation,
                         self.matrix_eps,
                         self.pre_conditioner_type,
+                        self.use_svd,
                     )
                     if self.graft_type == LayerWiseGrafting.ADAGRAD:
                         state['graft'] = AdaGradGraft(p, self.diagonal_eps)
@@ -185,27 +318,26 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                 state['step'] += 1
                 pre_conditioner, graft = state['pre_conditioner'], state['graft']
 
-                # gather statistics, compute pre-conditioners
+                is_precondition_step: bool = self.is_precondition_step(state['step'])
+
                 graft.add_statistics(grad, beta2)
                 if state['step'] % self.statistics_compute_steps == 0:
                     pre_conditioner.add_statistics(grad)
                 if state['step'] % self.preconditioning_compute_steps == 0:
                     pre_conditioner.compute_pre_conditioners()
 
-                # pre-condition gradients
                 pre_conditioner_multiplier: float = group['lr'] if not self.decoupled_learning_rate else 1.0
                 graft_grad: torch.Tensor = graft.precondition_gradient(grad * pre_conditioner_multiplier)
                 shampoo_grad: torch.Tensor = grad
-                if state['step'] >= self.start_preconditioning_step:
+                if is_precondition_step:
                     shampoo_grad = pre_conditioner.preconditioned_grad(grad)
 
-                # grafting
-                graft_norm = torch.norm(graft_grad)
-                shampoo_norm = torch.norm(shampoo_grad)
                 if self.graft_type != LayerWiseGrafting.NONE:
+                    graft_norm = torch.norm(graft_grad)
+                    shampoo_norm = torch.norm(shampoo_grad)
+
                     shampoo_grad.mul_(graft_norm / (shampoo_norm + 1e-16))
 
-                # apply weight decay (adam style)
                 if group['weight_decay'] > 0.0:
                     if not self.decoupled_weight_decay:
                         shampoo_grad.add_(p, alpha=group['weight_decay'])
@@ -214,11 +346,10 @@ def step(self, closure: CLOSURE = None) -> LOSS:
                         shampoo_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
                         graft_grad.mul_(1.0 - group['lr'] * group['weight_decay'])
 
-                # Momentum and Nesterov momentum, if needed
                 state['momentum'].mul_(beta1).add_(shampoo_grad)
                 graft_momentum = graft.update_momentum(grad, beta1)
 
-                if state['step'] >= self.start_preconditioning_step:
+                if is_precondition_step:
                     momentum_update = state['momentum']
                     wd_update = shampoo_grad
                 else: