@@ -69,8 +69,6 @@ class SOAP(optim.Optimizer):
            or a callable function that takes the current step as input and returns the frequency.
        adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
        precondition_1d: Whether to precondition 1D gradients (like biases).
-        trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
-        normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer
        correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA
        fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations
        use_eigh: Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis.
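Since precondition_frequency accepts either an int or a callable that maps the current step to a frequency, a schedule can be passed directly; a minimal sketch (the schedule, model, and lr values below are illustrative assumptions, not part of this diff):

def precondition_schedule(step: int) -> int:
    # Hypothetical schedule: refresh the preconditioner every step early on, then every 10 steps.
    return 1 if step < 1000 else 10

optimizer = SOAP(model.parameters(), lr=3e-3, precondition_frequency=precondition_schedule)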
@@ -83,23 +81,23 @@ class SOAP(optim.Optimizer):
            More steps can lead to better convergence but increased computation time.
        max_update_rms: Clip the update RMS to this value (0 means no clipping).
        use_kl_shampoo: Whether to use KL-Shampoo correction.
+        correct_shampoo_beta_bias: Whether to bias-correct the shampoo beta EMA of the Kronecker factors.
+            Decoupled from correct_bias for testability, because the reference implementation of SOAP does
+            not bias-correct the shampoo beta. Defaults to correct_bias when None.
    """

    def __init__(
        self,
        params: ParamsT,
-        lr: float = 3e-3,
-        betas: Tuple[float, float] = (0.95, 0.95),
+        lr: float,
+        betas: Tuple[float, float] = (0.9, 0.95),
        shampoo_beta: float = 0.95,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
        use_decoupled_wd: bool = True,
        use_nesterov: bool = False,
        precondition_frequency: Union[int, Callable[[int], int]] = 1,
-        adam_warmup_steps: int = 1,
+        adam_warmup_steps: int = 0,
        precondition_1d: bool = False,
-        trace_normalization: bool = False,
-        normalize_preconditioned_grads: bool = False,
        correct_bias: bool = True,
        fp32_matmul_prec: str = "high",
        use_eigh: bool = False,
@@ -109,12 +107,11 @@ def __init__(
        power_iter_steps: int = 1,
        max_update_rms: float = 0.0,
        use_kl_shampoo: bool = False,
+        correct_shampoo_beta_bias: bool | None = None,
    ) -> None:
        self.precondition_frequency = precondition_frequency
        self.adam_warmup_steps = adam_warmup_steps
        self.precondition_1d = precondition_1d
-        self.trace_normalization = trace_normalization
-        self.normalize_preconditioned_grads = normalize_preconditioned_grads
        self.use_nesterov = use_nesterov
        self.correct_bias = correct_bias
        self.use_decoupled_wd = use_decoupled_wd
@@ -126,6 +123,10 @@ def __init__(
        self.power_iter_steps = power_iter_steps
        self.max_update_rms = max_update_rms
        self.use_kl_shampoo = use_kl_shampoo
+        if correct_shampoo_beta_bias is not None:
+            self.correct_shampoo_beta_bias = correct_shampoo_beta_bias
+        else:
+            self.correct_shampoo_beta_bias = correct_bias

        defaults = {
            "lr": lr,
@@ -160,155 +161,132 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                if "step" not in state:
                    state["step"] = 0

-                # State initialization
-                # (TODO @mkhona): Better way to check state initialization - use state initializer?
-                if "exp_avg" not in state:
+                # NOTE: Upstream PyTorch implementations increment the step counter in the middle of the loop
+                # so that it can be used for bias correction. This is confusing and error prone if anything
+                # else needs the step counter, so we follow the Python/C convention and increment the step
+                # counter at the end of the loop. An explicitly named 1-based iteration counter is used for
+                # bias correction and any other term in the math that needs a 1-based iteration count.
+                curr_iter_1_based = state["step"] + 1
+
+                # TODO(mkhona): Improve initialization handling.
+                #  - More protective checks can be added to avoid potential issues with checkpointing.
+                #  - Initializing zero buffers can also be avoided.
+                if state["step"] == 0:
+                    assert all(key not in state for key in ["exp_avg", "exp_avg_sq", "GG"]), (
+                        "exp_avg, exp_avg_sq, and GG should not be initialized at step 0. "
+                        "A state mismatch was likely introduced during checkpointing."
+                    )
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(grad)
                    # Exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(grad)
-
-                if "Q" not in state:
-                    state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
+                    # Initialize kronecker factor matrices
+                    state["GG"] = init_kronecker_factors(
+                        grad,
+                        precondition_1d=self.precondition_1d,
+                    )

                # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here
                # because it needs access to state and group
-                kronecker_factor_update_fn = partial(update_kronecker_factors, precondition_1d=self.precondition_1d)
-                if self.use_kl_shampoo:
+                if not self.use_kl_shampoo:
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors,
+                        precondition_1d=self.precondition_1d,
+                    )
+                else:
+                    if "Q" not in state:
+                        assert state["step"] == 0, (
+                            f"Q should already be initialized at step {state['step']}. "
+                            "A state mismatch was likely introduced during checkpointing."
+                        )
+                        state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
                    kronecker_factor_update_fn = partial(
                        update_kronecker_factors_kl_shampoo,
                        eigenbasis_list=state["Q"],
                        eps=group["eps"],
                    )

-                # Initialize kronecker factor matrices
-                if "GG" not in state:
-                    state["GG"] = init_kronecker_factors(
-                        grad,
-                        precondition_1d=self.precondition_1d,
-                    )
+                shampoo_beta = group["shampoo_beta"]
+                if self.correct_shampoo_beta_bias:
+                    shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based)

-                # Update preconditioner matrices with gradient statistics,
-                # do not use shampoo_beta for EMA at first step
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    kronecker_factor_update_fn(
-                        kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
-                    )
+                torch.cuda.nvtx.range_push("update_kronecker_factors")
+                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                    kronecker_factor_update_fn(kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=shampoo_beta)
+                torch.cuda.nvtx.range_pop()

-                # Increment step counter
-                state["step"] += 1
+                # After the adam_warmup_steps are completed, update eigenbases at precondition_frequency steps
+                torch.cuda.nvtx.range_push("Update eigen basis")
+                if _is_eigenbasis_update_step(
+                    state["step"],
+                    self.adam_warmup_steps,
+                    self.precondition_frequency,
+                ):
+                    # Always use eigh for the first eigenbasis update
+                    use_eigh = self.use_eigh if state["step"] != self.adam_warmup_steps else True
+
+                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
+                        state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
+                            kronecker_factor_list=state["GG"],
+                            eigenbasis_list=state.get("Q", None),
+                            exp_avg_sq=state["exp_avg_sq"],
+                            momentum=state["exp_avg"],
+                            use_eigh=use_eigh,
+                            use_adaptive_criteria=self.use_adaptive_criteria,
+                            adaptive_update_tolerance=self.adaptive_update_tolerance,
+                            power_iter_steps=self.power_iter_steps,
+                        )
+                torch.cuda.nvtx.range_pop()

-                # Apply weight decay
                if group["weight_decay"] > 0.0:
                    if self.use_decoupled_wd:
-                        # Apply decoupled weight decay
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                    else:
-                        # add l2 regularization before preconditioning (i.e. like adding a squared loss term)
                        grad += group["weight_decay"] * p

-                # Projecting gradients to the eigenbases of Shampoo's preconditioner
+                grad_projected = grad
+                # Project gradients to the eigenbases of Shampoo's preconditioner
                torch.cuda.nvtx.range_push("precondition")
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad_projected = precondition(
-                        grad=grad,
-                        eigenbasis_list=state["Q"],
-                        dims=[[0], [0]],
-                    )
+                if state["step"] >= self.adam_warmup_steps:
+                    with utils.fp32_matmul_precision(self.fp32_matmul_prec):
+                        grad_projected = precondition(
+                            grad=grad,
+                            eigenbasis_list=state["Q"],
+                            dims=[[0], [0]],
+                        )
                torch.cuda.nvtx.range_pop()

-                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
-
                # Calculate the Adam update for the projected gradient tensor
-                torch.cuda.nvtx.range_push("calculate_adam_update")
                adam_update = calculate_adam_update(
                    grad_projected,
-                    exp_avg,
-                    exp_avg_sq,
+                    state["exp_avg"],
+                    state["exp_avg_sq"],
                    group["betas"],
                    self.correct_bias,
                    self.use_nesterov,
-                    state["step"],
+                    curr_iter_1_based,  # 1-based iteration index is used for bias correction
                    group["eps"],
                )
-                step_size = group["lr"]
-                torch.cuda.nvtx.range_pop()

                # Projecting back the preconditioned (by ADAM) exponential moving average of gradients
                torch.cuda.nvtx.range_push("precondition")
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    norm_precond_grad = precondition(
-                        grad=adam_update,
-                        eigenbasis_list=state["Q"],
-                        dims=[[0], [1]],
-                    )
-                torch.cuda.nvtx.range_pop()
-
-                if self.trace_normalization:
-                    if state["GG"][0].numel() > 0:
-                        trace_normalization = 1 / torch.sqrt(torch.trace(state["GG"][0]))
-                        norm_precond_grad = norm_precond_grad / trace_normalization
-
-                if self.normalize_preconditioned_grads:
-                    norm_precond_grad = norm_precond_grad / (1e-30 + torch.mean(norm_precond_grad**2) ** 0.5)
-
-                # Clip the update RMS to a maximum value
-                _clip_update_rms_in_place(norm_precond_grad, self.max_update_rms)
-
-                torch.cuda.nvtx.range_push("weight update")
-                p.add_(norm_precond_grad, alpha=-step_size)
-                torch.cuda.nvtx.range_pop()
-
-                # Update kronecker factor matrices with gradient statistics
-                shampoo_beta = group["shampoo_beta"] if group["shampoo_beta"] >= 0 else group["betas"][1]
-                if self.correct_bias:
-                    # step size correction for shampoo kronecker factors EMA
-                    shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta ** (state["step"] + 1))
-
-                torch.cuda.nvtx.range_push("update_kronecker_factors")
-                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    kronecker_factor_update_fn(
-                        kronecker_factor_list=state["GG"],
-                        grad=grad,
-                        shampoo_beta=0.0,
-                    )
-                torch.cuda.nvtx.range_pop()
-
-                # If current step is the last step to skip preconditioning, initialize eigenbases and
-                # end first order warmup
-                if state["step"] == self.adam_warmup_steps:
-                    # Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition
-                    state["Q"] = get_eigenbasis_eigh(state["GG"])
-                    # rotate momentum to the new eigenbasis
+                if state["step"] >= self.adam_warmup_steps:
                    with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                        state["exp_avg"] = precondition(
-                            grad=state["exp_avg"],
-                            eigenbasis_list=state["Q"],
-                            dims=[[0], [0]],
-                        )
-                    continue
-
-                # After the adam_warmup_steps are completed.
-                # Update eigenbases at precondition_frequency steps
-                torch.cuda.nvtx.range_push("Update eigen basis")
-                if _is_eigenbasis_update_step(
-                    state["step"],
-                    self.adam_warmup_steps,
-                    self.precondition_frequency,
-                ):
-                    with utils.fp32_matmul_precision(self.qr_fp32_matmul_prec):
-                        state["Q"], state["exp_avg"], state["exp_avg_sq"] = update_eigenbasis_and_momentum(
-                            kronecker_factor_list=state["GG"],
-                            eigenbasis_list=state["Q"],
-                            exp_avg_sq=state["exp_avg_sq"],
-                            momentum=state["exp_avg"],
-                            use_eigh=self.use_eigh,
-                            use_adaptive_criteria=self.use_adaptive_criteria,
-                            adaptive_update_tolerance=self.adaptive_update_tolerance,
-                            power_iter_steps=self.power_iter_steps,
+                        precond_update = precondition(
+                            grad=adam_update,
+                            eigenbasis_list=state.get("Q", None),
+                            dims=[[0], [1]],
                        )
+                else:
+                    precond_update = adam_update
                torch.cuda.nvtx.range_pop()

+                _clip_update_rms_in_place(precond_update, self.max_update_rms)
+                p.add_(precond_update, alpha=-group["lr"])
+
+                state["step"] += 1
+
        return loss

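The in-place shampoo_beta correction above, shampoo_beta = 1 - (1 - shampoo_beta) / (1 - shampoo_beta**curr_iter_1_based) with a 1-based counter, is algebraically equivalent to keeping a plain EMA of the Kronecker factors and dividing it by the standard (1 - beta**t) bias-correction factor. A small standalone check of that equivalence (illustrative only, not part of the patch):

import torch

beta = 0.95
torch.manual_seed(0)
samples = [torch.randn(4, 4) for _ in range(10)]

ema = torch.zeros(4, 4)        # plain EMA, bias-corrected after the fact
corrected = torch.zeros(4, 4)  # EMA updated with the effective beta_t used above

for t, a in enumerate(samples, start=1):  # 1-based iteration count, like curr_iter_1_based
    ema = beta * ema + (1 - beta) * a
    beta_t = 1 - (1 - beta) / (1 - beta**t)
    corrected = beta_t * corrected + (1 - beta_t) * a
    assert torch.allclose(ema / (1 - beta**t), corrected, atol=1e-6)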
@@ -581,7 +559,7 @@ def update_eigenbasis_and_momentum(
@torch.compile  # type: ignore[misc]
def precondition(
    grad: torch.Tensor,
-    eigenbasis_list: Optional[List[torch.Tensor]],
+    eigenbasis_list: Optional[List[torch.Tensor]] = None,
    dims: Optional[List[List[int]]] = None,
) -> torch.Tensor:
    """Projects the gradient to and from the eigenbases of the kronecker factor matrices.
@@ -607,7 +585,7 @@ def precondition(
        # Pick contraction dims to project to the eigenbasis
        dims = [[0], [0]]

-    if not eigenbasis_list:
+    if eigenbasis_list is None:
        # If eigenbases are not provided, return the gradient without any preconditioning
        return grad

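With eigenbasis_list now defaulting to None, precondition can be called with just the gradient and acts as a pass-through; only an explicit eigenbasis triggers a projection. A short usage sketch (the shapes and the QR-based eigenbases are illustrative assumptions):

g = torch.randn(8, 4)
assert torch.equal(precondition(grad=g), g)  # no eigenbasis supplied -> gradient is returned unchanged

# Assumed 2D case: one orthogonal eigenbasis per gradient dimension.
Q = [torch.linalg.qr(torch.randn(8, 8)).Q, torch.linalg.qr(torch.randn(4, 4)).Q]
g_rotated = precondition(grad=g, eigenbasis_list=Q, dims=[[0], [0]])  # project into the eigenbases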
@@ -653,7 +631,7 @@ def _is_eigenbasis_update_step(


@torch.compile  # type: ignore[misc]
-def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float = 1.0, eps: float = 1e-12) -> None:
+def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None:
    """Clip the update root mean square (RMS) to a maximum value, in place.

    Do not clip if max_rms is 0.
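The function body is outside this hunk; for orientation, here is a minimal sketch of RMS clipping consistent with this docstring and signature (an assumption for illustration, not the code in this repository):

def _clip_update_rms_sketch(u: torch.Tensor, max_rms: float, eps: float = 1e-7) -> None:
    # Scale u down in place so its root mean square does not exceed max_rms; max_rms == 0 disables clipping.
    if max_rms == 0:
        return
    rms = u.pow(2).mean().sqrt()
    if rms > max_rms:
        u.mul_(max_rms / (rms + eps))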