2 changes: 1 addition & 1 deletion .github/workflows/cherry-pick-release-commit.yml
@@ -20,7 +20,7 @@ on:

jobs:
cherry-pick:
uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_cherry_pick.yml@v0.57.0
uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_cherry_pick.yml@v0.63.0
secrets:
PAT: ${{ secrets.PAT }}
SLACK_WEBHOOK_ADMIN: ${{ secrets.SLACK_WEBHOOK_ADMIN }}
6 changes: 6 additions & 0 deletions docs/apidocs/soap.md
@@ -21,4 +21,10 @@ emerging_optimizers.soap
.. autofunction:: update_kronecker_factors

.. autofunction:: update_eigenbasis_and_momentum

emerging_optimizers.soap.soap_utils
=====================================

.. automodule:: emerging_optimizers.soap.soap_utils
:members:
```
1 change: 1 addition & 0 deletions docs/conf.py
@@ -72,6 +72,7 @@
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/2.5", None),
}
autodoc_typehints = "description"


def linkcode_resolve(domain, info):
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
@@ -14,3 +14,4 @@
# limitations under the License.
from emerging_optimizers.orthogonalized_optimizers.muon import *
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
12 changes: 8 additions & 4 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
@@ -37,9 +37,12 @@ class Muon(OrthogonalizedOptimizer):
optimization via Frank-Wolfe.

References:
- Jordan, K. *Muon Optimizer Implementation.* [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
- *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024). [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]
- *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]
- Jordan, K. *Muon Optimizer Implementation.*
[`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
- *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
[`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]
- *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025).
[`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]

Warning:
- This optimizer requires that all parameters passed in are 2D.
@@ -122,7 +125,8 @@ def get_muon_scale_factor(
# Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982)
return extra_scale_factor * max(size_out, size_in) ** 0.5
elif mode == "unit_rms_norm":
# Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
# Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al.
# (https://jeremybernste.in/writing/deriving-muon)
return extra_scale_factor * (size_out / size_in) ** 0.5
else:
raise ValueError(f"Invalid mode for Muon update scale factor: {mode}")
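For context, a minimal standalone sketch (not part of this diff) of the two scale-factor conventions visible in this hunk. The first mode's name is not shown in the diff, so `"max_dim_sqrt"` below is a hypothetical stand-in; `"unit_rms_norm"` is taken from the change itself:

```python
def illustrative_muon_scale(size_out: int, size_in: int, mode: str, extra_scale_factor: float = 1.0) -> float:
    """Re-derivation of the scale factors shown above, for illustration only."""
    if mode == "max_dim_sqrt":  # hypothetical name for the K. Jordan / Kimi (arXiv:2502.16982) suggestion
        return extra_scale_factor * max(size_out, size_in) ** 0.5
    if mode == "unit_rms_norm":  # Scion (arXiv:2502.07529) / Bernstein et al. convention
        return extra_scale_factor * (size_out / size_in) ** 0.5
    raise ValueError(f"Invalid mode: {mode}")


# For a 4096x1024 weight: 64.0 under the max-dim rule, 2.0 under the unit-RMS-norm rule.
print(illustrative_muon_scale(4096, 1024, "max_dim_sqrt"), illustrative_muon_scale(4096, 1024, "unit_rms_norm"))
```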
16 changes: 8 additions & 8 deletions emerging_optimizers/orthogonalized_optimizers/muon_utils.py
@@ -36,14 +36,14 @@
],
"polar_express": [
# Polar Express iteration from: https://arxiv.org/abs/2505.16932
(7.2086, -15.5131, 9.0178),
(3.9623, -2.5813, 0.4542),
(3.9466, -2.5765, 0.4544),
(3.8991, -2.5671, 0.4566),
(3.7186, -2.5308, 0.4653),
(3.1390, -2.3073, 0.4733),
(2.1715, -1.5246, 0.3885),
(1.8648, -1.2224, 0.3577),
(8.205160, -22.90193, 16.46073),
(4.066395, -2.861154, 0.5183995),
(3.909595, -2.823351, 0.5250370),
(3.285564, -2.415302, 0.4852941),
(2.277873, -1.619822, 0.3984808),
(1.872576, -1.230704, 0.3585162),
(1.856437, -1.213239, 0.3567998),
(1.875, -1.25, 0.375),
],
"aol": [
# from https://github.com/thib-s/flash-newton-schulz/blob/main/newton_schulz_triton.py#L511
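The updated `polar_express` entries above are per-iteration (a, b, c) coefficients for the quintic Newton-Schulz polynomial. The `newton_schulz` routine that consumes them is not shown in this diff, so the loop below is only an assumed sketch of how such coefficient schedules are typically applied:

```python
import torch


def ns_orthogonalize(G: torch.Tensor, coeffs: list[tuple[float, float, float]]) -> torch.Tensor:
    """Illustrative Newton-Schulz orthogonalization driven by per-step (a, b, c) triples."""
    # The Frobenius norm upper-bounds the spectral norm, keeping singular values in the convergent range.
    X = G / (G.norm() + 1e-7)
    if transposed := X.shape[0] > X.shape[1]:
        X = X.T
    for a, b, c in coeffs:
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X  # X <- a*X + b*(XX^T)X + c*(XX^T)^2 X
    return X.T if transposed else X
```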
emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
@@ -45,9 +45,11 @@ class OrthogonalizedOptimizer(optim.Optimizer):

- Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
In International Conference on Artificial Intelligence and Statistics (2015a).
- Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. *Stochastic Spectral Descent for Discrete Graphical Models.*
- Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V.
*Stochastic Spectral Descent for Discrete Graphical Models.*
In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).
- Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. *Preconditioned spectral descent for deep learning.*
- Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V.
*Preconditioned spectral descent for deep learning.*
In Neural Information Processing Systems (2015b).
- Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.*
arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]
emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz


__all__ = ["spectral_hardcap", "spectral_clip"]


def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor:
r"""Applies spectral clipping to the input tensor.

Builds on the idea that scalar clipping can be written with the sign function; this extends to the singular values
of a matrix via the matrix sign function, which is computed here with Newton-Schulz iteration for efficiency.

Based on https://leloykun.github.io/ponder/spectral-clipping/.

Args:
X: The input tensor.
sigma_min: The minimum singular value.
sigma_max: The maximum singular value.

Returns:
The spectral clipped tensor.
"""
if needs_transpose := X.shape[0] > X.shape[1]:
X = X.T
OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
result = (sigma_min + sigma_max) * OX
identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
for s, sign in zip([sigma_min, sigma_max], [1, -1]):
A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1)
B = torch.add(s * OX, X, alpha=-1)
result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B)
result = result * 0.5

if needs_transpose:
result = result.T
return result


def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
r"""Spectral hardcap function clips singular values from above to be less than beta.

Simplifies the spectral clipping function to just an upper bound, resulting in a hardcap.
Based on https://leloykun.github.io/ponder/spectral-clipping/.

Args:
X: The input tensor.
beta: The upper bound on the singular values.

Returns:
The spectral hardcapped tensor.

"""
if needs_transpose := X.shape[0] > X.shape[1]:
X = X.T
OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
aX = torch.add(beta * OX, X, alpha=-1)
result = torch.add(beta * OX, X)
result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1)
result = result * 0.5
if needs_transpose:
result = result.T
return result


def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor:
r"""Applies weight decay to the input tensor while applying spectral hardcapping.

This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988).

Based on https://leloykun.github.io/ponder/spectral-clipping/.

Args:
X: The input tensor.
beta: The upper bound on the singular values.
c: Interpolation coefficient for the weight decay; 0 returns X unchanged, 1 returns the fully hardcapped tensor.

Returns:
The spectral clipped weight decay tensor.
"""
return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c)
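A short usage sketch for the two exported helpers (assuming the package is installed and re-exported as in the updated `__init__.py` above); the comment about tolerance reflects that both functions use an 8-step Newton-Schulz approximation rather than an exact SVD:

```python
import torch

from emerging_optimizers.orthogonalized_optimizers import spectral_clip, spectral_hardcap

W = 0.05 * torch.randn(1024, 4096)

# Clamp singular values into [sigma_min, sigma_max], or just cap them from above.
W_clipped = spectral_clip(W, sigma_min=-1.0, sigma_max=1.0)
W_capped = spectral_hardcap(W, beta=0.5)

# Largest singular value should now be roughly <= 0.5 (up to Newton-Schulz approximation error).
print(torch.linalg.svdvals(W_capped).max())
```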
61 changes: 23 additions & 38 deletions emerging_optimizers/soap/soap.py
@@ -61,7 +61,6 @@ class SOAP(optim.Optimizer):
precondition_warmup_steps: How many steps to warm up the preconditioner (i.e. update every step)
adam_warmup_steps: How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)
precondition_1d: Whether to precondition 1D gradients (like biases).
max_precond_dim: Maximum dimension of the preconditioner matrices. Skips preconditioning if any tensor dimension exceeds.
trace_normalization: Whether to normalize update by the trace of the kronecker factor matrix
normalize_preconditioned_grads: Whether to normalize preconditioned gradients per layer
correct_bias: Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA
@@ -91,7 +90,6 @@ def __init__(
precondition_warmup_steps: int = 0,
adam_warmup_steps: int = 1,
precondition_1d: bool = False,
max_precond_dim: int = 8192,
trace_normalization: bool = False,
normalize_preconditioned_grads: bool = False,
correct_bias: bool = True,
@@ -127,7 +125,8 @@ def __init__(
original_adam_warmup_steps = adam_warmup_steps
adam_warmup_steps = max(1, precondition_warmup_steps - 1)
logging.info(
f"adam_warmup_steps ({original_adam_warmup_steps}) should be less than precondition_warmup_steps ({precondition_warmup_steps}). "
f"adam_warmup_steps ({original_adam_warmup_steps}) should be less "
f"than precondition_warmup_steps ({precondition_warmup_steps}). "
f"Setting adam_warmup_steps to {adam_warmup_steps} by default."
)

@@ -141,7 +140,6 @@ def __init__(
"precondition_warmup_steps": precondition_warmup_steps,
"adam_warmup_steps": adam_warmup_steps,
"precondition_1d": precondition_1d,
"max_precond_dim": max_precond_dim,
"trace_normalization": trace_normalization,
"normalize_preconditioned_grads": normalize_preconditioned_grads,
"use_nesterov": use_nesterov,
@@ -194,17 +192,16 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
state["GG"] = init_kronecker_factors(
grad,
precondition_1d=group["precondition_1d"],
max_precond_dim=group["max_precond_dim"],
)

# Update preconditioner matrices with gradient statistics, do not use shampoo_beta for EMA at first step
# Update preconditioner matrices with gradient statistics,
# do not use shampoo_beta for EMA at first step
with utils.fp32_matmul_precision(group["fp32_matmul_prec"]):
update_kronecker_factors(
kronecker_factor_list=state["GG"],
grad=grad,
shampoo_beta=0.0,
precondition_1d=group["precondition_1d"],
max_precond_dim=group["max_precond_dim"],
)

# Increment step counter
@@ -284,11 +281,11 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
grad=grad,
shampoo_beta=shampoo_beta,
precondition_1d=group["precondition_1d"],
max_precond_dim=group["max_precond_dim"],
)
torch.cuda.nvtx.range_pop()

# If current step is the last step to skip preconditioning, initialize eigenbases and end first order warmup
# If current step is the last step to skip preconditioning, initialize eigenbases and
# end first order warmup
if state["step"] == group["adam_warmup_steps"]:
# Obtain kronecker factor eigenbases from kronecker factor matrices using eigendecomposition
state["Q"] = get_eigenbasis_eigh(state["GG"])
Expand Down Expand Up @@ -330,7 +327,6 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
def init_kronecker_factors(
grad: torch.Tensor,
precondition_1d: bool = False,
max_precond_dim: int = 8192,
) -> List[torch.Tensor]:
"""Initializes the kronecker factor matrices for the SOAP optimizer.

@@ -354,8 +350,6 @@ def init_kronecker_factors(
The shape of this tensor determines the size of the kronecker factor matrices.
precondition_1d: Whether to create kronecker factor matrices for 1D tensors
(like biases). If False, 1D tensors will skip preconditioning.
max_precond_dim: Maximum dimension of the preconditioner matrices.
Skips preconditioning if any tensor dimension exceeds.

Returns:
List[torch.Tensor]: List of kronecker factor matrices (L and R in paper).
@@ -387,21 +381,11 @@ def init_kronecker_factors(
else:
# Create a square preconditioner matrix for 1D tensors
size = grad.shape[0]
if size > max_precond_dim:
# if tensor dimension is larger than max_precond_dim, skip preconditioning this dimension
# append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail
kronecker_factor_list.append(torch.empty(0, device=grad.device))
else:
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
else:
# Create a square kronecker factor matrix for each dimension
for size in grad.shape:
if size > max_precond_dim:
# append empty tensor to kronecker_factor_list so that subsequent check that use numel() to check if preconditioner is initialized will not fail
# skip preconditioning this dimension
kronecker_factor_list.append(torch.empty(0, device=grad.device))
else:
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))
kronecker_factor_list.append(torch.zeros(size, size, device=grad.device))

return kronecker_factor_list

@@ -412,7 +396,6 @@ def update_kronecker_factors(
grad: torch.Tensor,
shampoo_beta: float,
precondition_1d: bool = False,
max_precond_dim: int = 8192,
) -> None:
"""Updates the preconditioner matrices using gradient outer products.

@@ -429,8 +412,6 @@ def update_kronecker_factors(
Controls how much weight to give to new vs old gradient statistics.
precondition_1d: Whether to apply preconditioning to 1D tensors (like biases).
If False, 1D tensors will skip preconditioning.
max_precond_dim: Maximum dimension of the preconditioner matrices.
Skips preconditioning if any tensor dimension exceeds.

Example:
>>> grad = torch.randn(10, 20)
@@ -446,20 +427,23 @@ def update_kronecker_factors(
kronecker_factor_list[0].lerp_(outer_product, 1 - shampoo_beta)
else:
# For 1D tensors, skip preconditioning
logging.error(
"1D tensor is passed to update_kronecker_factors, "
"but precondition_1d is not set to True, skipping preconditioning."
)
return
else:
# For higher dimensional tensors, compute outer products for each dimension
for idx, dim_size in enumerate(grad.shape):
if dim_size <= max_precond_dim:
# Compute outer product by contracting all dimensions except idx
contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
outer_product = torch.tensordot(
grad,
grad,
dims=[contract_dims] * 2,
)
# Update the corresponding Kronecker factor
kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)
# Compute outer product by contracting all dimensions except idx
contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
outer_product = torch.tensordot(
grad,
grad,
dims=[contract_dims] * 2,
)
# Update the corresponding Kronecker factor
kronecker_factor_list[idx].lerp_(outer_product, 1 - shampoo_beta)


@torch.no_grad() # type: ignore[misc]
@@ -606,7 +590,8 @@ def precondition(
)
else:
# Permute gradient dimensions to process the next dimension in the following iteration
# when preconditioning for the current dimension is skipped (Q is empty), in the case of one-sided preconditioning.
# when preconditioning for the current dimension is skipped (Q is empty), in the case of
# one-sided preconditioning.
permute_order = list(range(1, grad.dim())) + [0]
grad = grad.permute(permute_order)

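To make the removal of `max_precond_dim` concrete, here is a small self-contained sketch (assumed behavior, mirroring only the hunks visible above rather than the full `soap.py`) of the Kronecker-factor EMA update that now runs for every dimension of a >1D gradient:

```python
from itertools import chain

import torch


def kronecker_factor_ema_step(factors: list[torch.Tensor], grad: torch.Tensor, shampoo_beta: float) -> None:
    """Illustrative outer-product EMA matching the update_kronecker_factors hunk above."""
    for idx in range(grad.dim()):
        # Contract all dimensions except idx, producing a (d_idx x d_idx) gradient second-moment matrix.
        contract_dims = [*chain(range(idx), range(idx + 1, grad.dim()))]
        outer_product = torch.tensordot(grad, grad, dims=[contract_dims] * 2)
        # factor <- shampoo_beta * factor + (1 - shampoo_beta) * outer_product
        factors[idx].lerp_(outer_product, 1 - shampoo_beta)


grad = torch.randn(10, 20)
factors = [torch.zeros(10, 10), torch.zeros(20, 20)]  # what init_kronecker_factors now always creates
kronecker_factor_ema_step(factors, grad, shampoo_beta=0.95)
```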