
Commit bc13a6b

Add Kullback–Leibler shampoo support to SOAP (#64)
Signed-off-by: Hao Wu <[email protected]>
1 parent cf9909b · commit bc13a6b


5 files changed (+193 −128 lines)

emerging_optimizers/soap/soap.py

Lines changed: 67 additions & 10 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from itertools import chain
 from typing import Callable, Iterable, List, Optional, Tuple, Union

@@ -81,6 +82,7 @@ class SOAP(optim.Optimizer):
         power_iter_steps: Number of power iteration steps to perform before QR decomposition.
             More steps can lead to better convergence but increased computation time.
         max_update_rms: Clip the update RMS to this value (0 means no clipping).
+        use_kl_shampoo: Whether to use the KL-Shampoo correction.
     """

     def __init__(
@@ -107,6 +109,7 @@ def __init__(
         adaptive_update_tolerance: Optional[float] = None,
         power_iter_steps: int = 1,
         max_update_rms: float = 0.0,
+        use_kl_shampoo: bool = False,
     ) -> None:
         # Check for betas.
         if betas is None:
@@ -159,6 +162,7 @@ def __init__(
             "adaptive_update_tolerance": adaptive_update_tolerance,
             "power_iter_steps": power_iter_steps,
             "max_update_rms": max_update_rms,
+            "use_kl_shampoo": use_kl_shampoo,
         }
         super().__init__(params, defaults)
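For reference, a minimal usage sketch of the new flag. The import path is inferred from the file layout, and the model, learning rate, and loss are placeholders; any constructor argument not shown falls back to the defaults registered above:

    import torch

    from emerging_optimizers.soap.soap import SOAP

    model = torch.nn.Linear(128, 64)  # placeholder model
    optimizer = SOAP(model.parameters(), lr=3e-4, use_kl_shampoo=True)  # lr is illustrative

    loss = model(torch.randn(8, 128)).square().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()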

@@ -194,6 +198,21 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     # Exponential moving average of squared gradient values
                     state["exp_avg_sq"] = torch.zeros_like(grad)

+                if "Q" not in state:
+                    state["Q"] = [torch.eye(shape, device=grad.device) for shape in grad.shape]
+
+                # Define kronecker_factor_update_fn here, based on whether KL-Shampoo is used,
+                # because it needs access to state and group
+                kronecker_factor_update_fn = partial(
+                    update_kronecker_factors, precondition_1d=group["precondition_1d"]
+                )
+                if group["use_kl_shampoo"]:
+                    kronecker_factor_update_fn = partial(
+                        update_kronecker_factors_kl_shampoo,
+                        eigenbasis_list=state["Q"],
+                        eps=group["eps"],
+                    )
+
                 # Initialize kronecker factor matrices
                 if "GG" not in state:
                     state["GG"] = init_kronecker_factors(
@@ -204,11 +223,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 # Update preconditioner matrices with gradient statistics,
                 # do not use shampoo_beta for EMA at first step
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
-                    update_kronecker_factors(
-                        kronecker_factor_list=state["GG"],
-                        grad=grad,
-                        shampoo_beta=0.0,
-                        precondition_1d=group["precondition_1d"],
+                    kronecker_factor_update_fn(
+                        kronecker_factor_list=state["GG"], grad=grad, shampoo_beta=group["shampoo_beta"]
                     )

                 # Increment step counter
@@ -228,7 +244,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
                     grad_projected = precondition(
                         grad=grad,
-                        eigenbasis_list=state.get("Q"),
+                        eigenbasis_list=state["Q"],
                         dims=[[0], [0]],
                     )
                 torch.cuda.nvtx.range_pop()
@@ -255,7 +271,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
                     norm_precond_grad = precondition(
                         grad=adam_update,
-                        eigenbasis_list=state.get("Q"),
+                        eigenbasis_list=state["Q"],
                         dims=[[0], [1]],
                     )
                 torch.cuda.nvtx.range_pop()
@@ -283,11 +299,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

                 torch.cuda.nvtx.range_push("update_kronecker_factors")
                 with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
-                    update_kronecker_factors(
+                    kronecker_factor_update_fn(
                         kronecker_factor_list=state["GG"],
                         grad=grad,
-                        shampoo_beta=shampoo_beta,
-                        precondition_1d=group["precondition_1d"],
+                        shampoo_beta=0.0,
                     )
                 torch.cuda.nvtx.range_pop()

@@ -453,6 +468,48 @@ def update_kronecker_factors(
         kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)


+@torch.no_grad()  # type: ignore[misc]
+def update_kronecker_factors_kl_shampoo(
+    kronecker_factor_list: List[torch.Tensor],
+    grad: torch.Tensor,
+    shampoo_beta: float,
+    eigenbasis_list: List[torch.Tensor],
+    eps: float,
+    eigval_exp: float = -1.0,
+) -> None:
+    """Updates the Kronecker factor matrices in place using the KL-Shampoo correction.
+
+    Implements Kullback–Leibler minimization from https://arxiv.org/pdf/2509.03378.
+
+    Args:
+        kronecker_factor_list: List of preconditioner matrices (L and R) to update.
+        grad: Gradient tensor of the parameter being optimized.
+        shampoo_beta: Momentum coefficient for updating preconditioners.
+        eigenbasis_list: List of orthonormal eigenbases of the Kronecker factor matrices.
+        eps: Small offset for numerical stability.
+        eigval_exp: Exponent of the eigenvalues.
+    """
+    assert grad.dim() == 2, "KL-Shampoo mathematical correction is only supported for 2D tensors"
+
+    # Scale the gradient matrix by the approximate eigenvalues and the eigenbasis:
+    # G @ Q_R @ λ_R^(−1) @ Q_R.T @ G.T / dim(G @ G.T) and G.T @ Q_L @ λ_L^(−1) @ Q_L.T @ G / dim(G.T @ G)
+    updates = []
+    for idx, (kronecker_factor, eigenbasis) in enumerate(zip(kronecker_factor_list, eigenbasis_list, strict=True)):
+        approx_eigvals = utils.eig.conjugate(kronecker_factor, eigenbasis, diag=True)
+        scale_factor = 1 / grad.shape[idx] * approx_eigvals.clamp_min(eps) ** eigval_exp
+
+        logging.debug(f"scale_factor[{idx}]: {scale_factor}")
+
+        correction = (eigenbasis * scale_factor[None, :]) @ eigenbasis.T
+
+        maybe_transpose_grad = grad.T if idx == 1 else grad
+        updates.append(utils.eig.conjugate(correction, maybe_transpose_grad))
+
+    # Note that the updates calculated in the loop above are in reverse order of the Kronecker factors they apply to
+    for kronecker_factor, update in zip(kronecker_factor_list, updates[::-1], strict=True):
+        kronecker_factor.lerp_(update, 1 - shampoo_beta)
+
+
 @torch.no_grad()  # type: ignore[misc]
 def update_eigenbasis_and_momentum(
     kronecker_factor_list: List[torch.Tensor],
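To see the arithmetic of the new function without the library helpers, below is a self-contained sketch of the same update written with explicit tensor ops. It assumes utils.eig.conjugate(A, B) computes B.T @ A @ B and that the diag=True variant returns diag(Q.T @ A @ Q); the function name and the toy driver at the end are illustrative only:

    import torch

    def kl_shampoo_factor_update(GG, Q, grad, shampoo_beta, eps=1e-8, eigval_exp=-1.0):
        # GG = [L, R] with L: (m, m) and R: (n, n); Q holds their eigenbases; grad: (m, n).
        updates = []
        for idx, (factor, basis) in enumerate(zip(GG, Q)):
            # Approximate eigenvalues of the factor in its current eigenbasis,
            # i.e. diag(Q.T @ GG @ Q).
            approx_eigvals = torch.diagonal(basis.T @ factor @ basis)
            scale = approx_eigvals.clamp_min(eps) ** eigval_exp / grad.shape[idx]
            correction = (basis * scale[None, :]) @ basis.T
            g = grad.T if idx == 1 else grad
            # Conjugating L's correction by G yields an (n, n) update and vice
            # versa, hence the reversed zip below.
            updates.append(g.T @ correction @ g)
        for factor, update in zip(GG, updates[::-1]):
            factor.lerp_(update, 1 - shampoo_beta)

    m, n = 8, 4
    grad = torch.randn(m, n)
    GG = [torch.eye(m), torch.eye(n)]
    Q = [torch.eye(m), torch.eye(n)]
    kl_shampoo_factor_update(GG, Q, grad, shampoo_beta=0.95)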

tests/ci/L0_Tests_GPU.sh

Lines changed: 1 addition & 2 deletions
@@ -18,9 +18,8 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
 error=0
 coverage run -p --source=emerging_optimizers tests/test_muon_utils.py || error=1
 coverage run -p --source=emerging_optimizers tests/test_orthogonalized_optimizer.py || error=1
-coverage run -p --source=emerging_optimizers tests/test_soap_functions.py || error=1
 coverage run -p --source=emerging_optimizers tests/test_soap_utils.py || error=1
-coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py || error=1
+coverage run -p --source=emerging_optimizers tests/test_soap.py || error=1
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py || error=1
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py --device=cuda || error=1
 coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py || error=1

tests/ci/L1_Tests_GPU.sh

Lines changed: 1 addition & 2 deletions
@@ -17,9 +17,8 @@ export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
 error=0
 python tests/test_muon_utils.py || error=1
 python tests/test_orthogonalized_optimizer.py || error=1
-python tests/test_soap_functions.py || error=1
 python tests/test_soap_utils.py || error=1
-python tests/soap_smoke_test.py || error=1
+python tests/test_soap.py || error=1
 python tests/test_scalar_optimizers.py --device=cuda || error=1
 python tests/test_spectral_clipping_utils.py || error=1
 python tests/test_triton_kernels.py || error=1

tests/soap_smoke_test.py

Lines changed: 0 additions & 97 deletions
This file was deleted.
