@@ -61,7 +61,6 @@ class SOAP(optim.Optimizer):
         precondition_warmup_steps: How many steps to warm up the preconditioner (i.e. update every step)
         adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
         precondition_1d: Whether to precondition 1D gradients (like biases).
-        max_precond_dim: Maximum dimension of the preconditioner matrices. Skips preconditioning if any tensor dimension exceeds.
         trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
         normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer
         correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA
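
With `max_precond_dim` removed, any call site still passing it will fail with a `TypeError`. A minimal sketch of constructing the optimizer after this patch; `lr` is an assumed standard optimizer argument that does not appear in this hunk:

```python
import torch

model = torch.nn.Linear(128, 64)
optimizer = SOAP(
    model.parameters(),
    lr=3e-3,                          # assumed; not visible in this diff
    precondition_warmup_steps=0,
    adam_warmup_steps=1,
    precondition_1d=False,            # biases fall back to plain AdamW-style updates
    trace_normalization=False,
    normalize_preconditioned_grads=False,
    correct_bias=True,
    # max_precond_dim=8192,           # removed by this patch; passing it now raises TypeError
)
```
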
@@ -91,7 +90,6 @@ def __init__(
         precondition_warmup_steps: int = 0,
         adam_warmup_steps: int = 1,
         precondition_1d: bool = False,
-        max_precond_dim: int = 8192,
         trace_normalization: bool = False,
         normalize_preconditioned_grads: bool = False,
         correct_bias: bool = True,
@@ -141,7 +139,6 @@ def __init__(
141139 "precondition_warmup_steps" : precondition_warmup_steps ,
142140 "adam_warmup_steps" : adam_warmup_steps ,
143141 "precondition_1d" : precondition_1d ,
144- "max_precond_dim" : max_precond_dim ,
145142 "trace_normalization" : trace_normalization ,
146143 "normalize_preconditioned_grads" : normalize_preconditioned_grads ,
147144 "use_nesterov" : use_nesterov ,
@@ -194,7 +191,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 state["GG"] = init_kronecker_factors(
                     grad,
                     precondition_1d=group["precondition_1d"],
-                    max_precond_dim=group["max_precond_dim"],
                 )
 
                 # Update preconditioner matrices with gradient statistics, do not use shampoo_beta for EMA at first step
@@ -204,7 +200,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     grad=grad,
                     shampoo_beta=0.0,
                     precondition_1d=group["precondition_1d"],
-                    max_precond_dim=group["max_precond_dim"],
                 )
 
                 # Increment step counter
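
On the `shampoo_beta=0.0` call above: `Tensor.lerp_(end, weight)` computes `self + weight * (end - self)`, so the weight `1 - shampoo_beta = 1.0` makes the first update overwrite the zero-initialized factor with the gradient's outer product rather than averaging into it. A standalone sketch:

```python
import torch

factor = torch.zeros(3, 3)  # freshly initialized Kronecker factor
outer = torch.eye(3)        # stand-in for a gradient outer product

# First step: shampoo_beta = 0.0 -> lerp weight 1.0 replaces the factor outright.
factor.lerp_(outer, 1 - 0.0)
assert torch.equal(factor, outer)

# Subsequent steps with shampoo_beta in (0, 1) give the usual EMA:
# factor <- shampoo_beta * factor + (1 - shampoo_beta) * outer
factor.lerp_(2 * torch.eye(3), 1 - 0.95)
```
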
@@ -284,7 +279,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     grad=grad,
                     shampoo_beta=shampoo_beta,
                     precondition_1d=group["precondition_1d"],
-                    max_precond_dim=group["max_precond_dim"],
                 )
                 torch.cuda.nvtx.range_pop()
 
@@ -330,7 +324,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
 def init_kronecker_factors(
     grad: torch.Tensor,
     precondition_1d: bool = False,
-    max_precond_dim: int = 8192,
 ) -> List[torch.Tensor]:
     """Initializes the kronecker factor matrices for the SOAP optimizer.
 
@@ -354,8 +347,6 @@ def init_kronecker_factors(
             The shape of this tensor determines the size of the kronecker factor matrices.
         precondition_1d: Whether to create kronecker factor matrices for 1D tensors
             (like biases). If False, 1D tensors will skip preconditioning.
-        max_precond_dim: Maximum dimension of the preconditioner matrices.
-            Skips preconditioning if any tensor dimension exceeds.
 
     Returns:
         List[torch.Tensor]: List of kronecker factor matrices (L and R in paper).
@@ -387,21 +378,11 @@ def init_kronecker_factors(
         else:
             # Create a square preconditioner matrix for 1D tensors
             size = grad.shape[0]
-            if size > max_precond_dim:
-                # if tensor dimension is larger than max_precond_dim, skip preconditioning this dimension
-                # append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail
-                kronecker_factor_list.append(torch.empty(0, device=grad.device))
-            else:
-                kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
+            kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
     else:
         # Create a square kronecker factor matrix for each dimension
         for size in grad.shape:
-            if size > max_precond_dim:
-                # append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail
-                # skip preconditioning this dimension
-                kronecker_factor_list.append(torch.empty(0, device=grad.device))
-            else:
-                kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
+            kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
 
     return kronecker_factor_list
 
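
With the dimension cap gone, `init_kronecker_factors` always allocates a dense `size x size` zero matrix per gradient dimension, so very wide layers now get correspondingly large factors. A quick illustration of the returned shapes:

```python
import torch

# 2D gradient -> one square factor per dimension (L and R in the paper).
factors = init_kronecker_factors(torch.randn(10, 20))
print([f.shape for f in factors])  # [torch.Size([10, 10]), torch.Size([20, 20])]

# 1D gradient with precondition_1d=True -> a single (5, 5) factor.
bias_factors = init_kronecker_factors(torch.randn(5), precondition_1d=True)
print(bias_factors[-1].shape)  # torch.Size([5, 5])
```
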
@@ -412,7 +393,6 @@ def update_kronecker_factors(
     grad: torch.Tensor,
     shampoo_beta: float,
     precondition_1d: bool = False,
-    max_precond_dim: int = 8192,
 ) -> None:
     """Updates the preconditioner matrices using gradient outer products.
 
@@ -429,8 +409,6 @@ def update_kronecker_factors(
             Controls how much weight to give to new vs old gradient statistics.
         precondition_1d: Whether to apply preconditioning to 1D tensors (like biases).
             If False, 1D tensors will skip preconditioning.
-        max_precond_dim: Maximum dimension of the preconditioner matrices.
-            Skips preconditioning if any tensor dimension exceeds.
 
     Example:
         >>> grad = torch.randn(10, 20)
@@ -446,20 +424,22 @@ def update_kronecker_factors(
             kronecker_factor_list[0].lerp_(outer_product, 1 - shampoo_beta)
         else:
             # For 1D tensors, skip preconditioning
+            logging.error(
+                "1D tensor is passed to update_kronecker_factors, but precondition_1d is not set to True, skipping preconditioning."
+            )
             return
     else:
         # For higher dimensional tensors, compute outer products for each dimension
         for idx, dim_size in enumerate(grad.shape):
-            if dim_size <= max_precond_dim:
-                # Compute outer product by contracting all dimensions except idx
-                contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
-                outer_product = torch.tensordot(
-                    grad,
-                    grad,
-                    dims=[contract_dims] * 2,
-                )
-                # Update the corresponding Kronecker factor
-                kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)
+            # Compute outer product by contracting all dimensions except idx
+            contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
+            outer_product = torch.tensordot(
+                grad,
+                grad,
+                dims=[contract_dims] * 2,
+            )
+            # Update the corresponding Kronecker factor
+            kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)
 
 
 @torch.no_grad()  # type: ignore[misc]
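
The `tensordot` call in the new loop contracts every axis except `idx`; for a 2D gradient this reduces to the familiar Shampoo statistics `G @ G.T` and `G.T @ G`. A standalone check of that equivalence, mirroring the patched loop without the removed cap:

```python
from itertools import chain

import torch

G = torch.randn(10, 20)

for idx in range(G.dim()):
    # Contract all dimensions except idx, exactly as in update_kronecker_factors.
    contract_dims = [*chain(range(idx), range(idx + 1, G.dim()))]
    outer = torch.tensordot(G, G, dims=[contract_dims] * 2)
    # For 2D: idx=0 -> G @ G.T (10x10), idx=1 -> G.T @ G (20x20).
    expected = G @ G.T if idx == 0 else G.T @ G
    assert torch.allclose(outer, expected, atol=1e-5)
```
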