 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from itertools import chain
 from typing import Callable, Iterable, List, Optional, Tuple, Union

@@ -81,6 +82,7 @@ class SOAP(optim.Optimizer):
         power_iter_steps: Number of power iteration steps to perform before QR decomposition.
             More steps can lead to better convergence but increased computation time.
         max_update_rms: Clip the update RMS to this value (0 means no clipping).
+        use_kl_shampoo: Whether to use KL-Shampoo correction.
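+            Applies a Kullback-Leibler-minimization correction when accumulating the
+            Kronecker factors (supported for 2D parameters only).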
8486 """
8587
8688 def __init__ (
@@ -107,6 +109,7 @@ def __init__(
         adaptive_update_tolerance: Optional[float] = None,
         power_iter_steps: int = 1,
         max_update_rms: float = 0.0,
+        use_kl_shampoo: bool = False,
     ) -> None:
         # Check for betas.
         if betas is None:
@@ -159,6 +162,7 @@ def __init__(
159162 "adaptive_update_tolerance" : adaptive_update_tolerance ,
160163 "power_iter_steps" : power_iter_steps ,
161164 "max_update_rms" : max_update_rms ,
165+ "use_kl_shampoo" : use_kl_shampoo ,
162166 }
163167 super ().__init__ (params , defaults )
164168
@@ -194,6 +198,21 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(grad)

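+                # The eigenbases start as identity matrices, so projecting into the
+                # eigenbasis is a no-op until the first eigenbasis update.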
201+ if "Q" not in state :
202+ state ["Q" ] = [torch .eye (shape , device = grad .device ) for shape in grad .shape ]
203+
204+ # Define kronecker_factor_update_fn based on whether to use KL-Shampoo here
205+ # because it needs access to state and group
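+                # The KL-Shampoo variant also closes over the current eigenbases in
+                # state["Q"], which it uses to rescale the gradient before accumulation.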
+                kronecker_factor_update_fn = partial(
+                    update_kronecker_factors, precondition_1d=group["precondition_1d"]
+                )
+                if group["use_kl_shampoo"]:
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors_kl_shampoo,
+                        eigenbasis_list=state["Q"],
+                        eps=group["eps"],
+                    )
+
                 # Initialize kronecker factor matrices
                 if "GG" not in state:
                     state["GG"] = init_kronecker_factors(
@@ -204,11 +223,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     # Update preconditioner matrices with gradient statistics,
                     # do not use shampoo_beta for EMA at first step
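+                    # (a shampoo_beta of 0.0 makes lerp_ fully replace the
+                    # zero-initialized factors with the first gradient's statistics)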
                     with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
-                        update_kronecker_factors(
-                            kronecker_factor_list=state["GG"],
-                            grad=grad,
-                            shampoo_beta=0.0,
-                            precondition_1d=group["precondition_1d"],
+                        kronecker_factor_update_fn(
+                            kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=0.0
                         )

                 # Increment step counter
@@ -228,7 +244,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
                     grad_projected = precondition(
                         grad=grad,
-                        eigenbasis_list=state.get("Q"),
+                        eigenbasis_list=state["Q"],
                         dims=[[0], [0]],
                     )
                 torch.cuda.nvtx.range_pop()
@@ -255,7 +271,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
                     norm_precond_grad = precondition(
                         grad=adam_update,
-                        eigenbasis_list=state.get("Q"),
+                        eigenbasis_list=state["Q"],
                         dims=[[0], [1]],
                     )
                 torch.cuda.nvtx.range_pop()
@@ -283,11 +299,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 torch.cuda.nvtx.range_push("update_kronecker_factors")
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
-                    update_kronecker_factors(
+                    kronecker_factor_update_fn(
                         kronecker_factor_list=state["GG"],
                         grad=grad,
                         shampoo_beta=shampoo_beta,
-                        precondition_1d=group["precondition_1d"],
                     )
                 torch.cuda.nvtx.range_pop()

@@ -453,6 +468,48 @@ def update_kronecker_factors(
         kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)


+@torch.no_grad()  # type: ignore[misc]
+def update_kronecker_factors_kl_shampoo(
+    kronecker_factor_list: List[torch.Tensor],
+    grad: torch.Tensor,
+    shampoo_beta: float,
+    eigenbasis_list: List[torch.Tensor],
+    eps: float,
+    eigval_exp: float = -1.0,
+) -> None:
480+ """Updates the kronecker factor matrices in place using KL-Shampoo correction.
481+
482+ Implement Kullback–Leibler Minimization from https://arxiv.org/pdf/2509.03378
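+    Each Kronecker factor is accumulated from a gradient rescaled by the approximate
+    eigenvalues of the other factor (their inverses, with the default eigval_exp of
+    -1.0), estimated in that factor's current eigenbasis.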
+
+    Args:
+        kronecker_factor_list: List of preconditioner matrices (L and R) to update.
+        grad: Gradient tensor of the parameter being optimized.
+        shampoo_beta: Momentum coefficient for updating preconditioners.
+        eigenbasis_list: List of orthonormal eigenbases of the kronecker factor matrices.
+        eps: Small offset for numerical stability.
+        eigval_exp: Exponent of the eigenvalues.
+    """
+    assert grad.dim() == 2, "KL-Shampoo mathematical correction is only supported for 2D tensors"
+
+    # Scale the gradient matrix by the approximate eigenvalues and the eigenbasis:
+    # G@Q_R@λ_R^(−1)@Q_R.T@G.T/dim(GG.T) and G.T@Q_L@λ_L^(−1)@Q_L.T@G/dim(G.TG)
+    updates = []
+    for idx, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
+        approx_eigvals = utils.eig.conjugate(kronecker_factor, eigenbasis, diag=True)
+        scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp
+
+        logging.debug(f"scale_factor[{idx}]: {scale_factor}")
+
+        correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T
+
+        maybe_transpose_grad = grad.T if idx == 1 else grad
+        updates.append(utils.eig.conjugate(correction, maybe_transpose_grad))
+
+    # Note that the updates calculated in the previous loop are in reverse order of
+    # the kronecker factor list they apply to.
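+    # (conjugating the (m, m) correction with the (m, n) grad yields an (n, n)
+    # statistic, i.e. the update for the opposite factor, and vice versa)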
+    for kronecker_factor, update in zip(kronecker_factor_list, updates[::-1], strict=True):
+        kronecker_factor.lerp_(update, 1 - shampoo_beta)


 @torch.no_grad()  # type: ignore[misc]
 def update_eigenbasis_and_momentum(
     kronecker_factor_list: List[torch.Tensor],