Merged
1 change: 1 addition & 0 deletions emerging_optimizers/orthogonalized_optimizers/__init__.py
@@ -14,3 +14,4 @@
# limitations under the License.
from emerging_optimizers.orthogonalized_optimizers.muon import *
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
98 changes: 98 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz


__all__ = ["spectral_hardcap", "spectral_clip", "spectral_clipped_weight_decay"]


def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor:
r"""Applies spectral clipping to the input tensor.

Scalar clipping can be written in terms of the sign (absolute value) function, and the same idea extends to the
singular values of a matrix via the matrix sign function, computed here with Newton-Schulz iteration for efficiency.

Based on https://leloykun.github.io/ponder/spectral-clipping/.

Args:
X: The input tensor.
sigma_min: Lower bound to clip the singular values to.
sigma_max: Upper bound to clip the singular values to.

Returns:
The spectral clipped tensor.
"""
if needs_transpose := X.shape[0] > X.shape[1]:
X = X.T
OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
result = (sigma_min + sigma_max) * OX
identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
for s, sign in zip([sigma_min, sigma_max], [1, -1]):
A = torch.addmm(s * identity_matrix, OX, X.T, beta=1.0, alpha=-1.0)
B = torch.add(s * OX, X, alpha=-1)
result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign)
result = result * 0.5

if needs_transpose:
result = result.T
return result
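
# The scalar identity behind spectral_clip is
#     clip(x, sigma_min, sigma_max) = ((sigma_min + sigma_max) + |x - sigma_min| - |x - sigma_max|) / 2,
# and the loop above applies it to the singular values, with |.| realized as a matrix absolute
# value via the Newton-Schulz matrix sign. A minimal SVD-based sketch of the same operation
# (illustrative only; this helper is not part of the module and is meant for checking the
# Newton-Schulz approximation on small float32 matrices):
def _spectral_clip_reference(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor:
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    # Clamp the singular values exactly, then rebuild the matrix.
    return U @ torch.diag(S.clamp(sigma_min, sigma_max)) @ Vh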


def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
r"""Spectral hardcap function clips singular values from above to be less than beta.

Simplifies the spectral clipping function to just an upper bound, resulting in a hardcap.
Based on https://leloykun.github.io/ponder/spectral-clipping/.

Args:
X: The input tensor.
beta: The upper bound on the singular values.

Returns:
The spectral hardcapped tensor.

"""
if needs_transpose := X.shape[0] > X.shape[1]:
X = X.T
OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
aX = torch.add(beta * OX, X, alpha=-1)
result = torch.add(beta * OX, X)
result = torch.addmm(
result, aX, torch.mm(newton_schulz(aX, steps=8, coefficient_type="polar_express").T, OX), alpha=-1
)
result = result * 0.5
if needs_transpose:
result = result.T
return result
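
# With X = U S V^T and OX = U V^T, the correction term aX @ msign(aX)^T @ OX equals
# U |beta*I - S| V^T, so the function returns U min(S, beta*I) V^T, i.e. the scalar rule
# min(sigma, beta) = (beta + sigma - |beta - sigma|) / 2 applied to each singular value.
# An exact SVD-based counterpart for comparison (hypothetical helper, not used elsewhere):
def _spectral_hardcap_reference(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    # Cap the singular values at beta, then rebuild the matrix.
    return U @ torch.diag(S.clamp(max=beta)) @ Vh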


def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor:
r"""Applies weight decay to the input tensor while applying spectral hardcapping.

This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988).

Based on https://leloykun.github.io/ponder/spectral-clipping/.

Args:
X: The input tensor.
beta: The upper bound on the singular values.
c: Interpolation coefficient; the fraction of the step taken toward the hardcapped tensor (analogous to the weight decay strength).

Returns:
The spectral clipped weight decay tensor.
"""
return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c)
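
# spectral_clipped_weight_decay is the decoupled update W <- (1 - c) * W + c * spectral_hardcap(W, beta),
# i.e. the weights are shrunk toward their hardcapped version rather than toward zero.
# A hypothetical way to apply it to the 2D weights of a model after an optimizer step
# (the helper name and the choice c = lr * weight_decay are illustrative assumptions):
def _apply_spectral_weight_decay(params, beta: float = 1.0, lr: float = 1e-3, weight_decay: float = 0.1) -> None:
    with torch.no_grad():
        for p in params:
            if p.ndim == 2:
                # Interpolate each matrix-shaped parameter toward its hardcapped version in place.
                p.copy_(spectral_clipped_weight_decay(p, beta=beta, c=lr * weight_decay))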
1 change: 1 addition & 0 deletions tests/ci/L0_Tests_GPU.sh
@@ -21,3 +21,4 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py
coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
1 change: 1 addition & 0 deletions tests/ci/L1_Tests_GPU.sh
@@ -19,3 +19,4 @@ python tests/test_soap_functions.py
python tests/test_soap_utils.py
python tests/soap_smoke_test.py
python tests/test_scalar_optimizers.py
python tests/test_spectral_clipping_utils.py
126 changes: 126 additions & 0 deletions tests/test_spectral_clipping_utils.py
@@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from absl import logging
from absl.testing import absltest, parameterized

import emerging_optimizers.orthogonalized_optimizers as orthogonalized_optimizers


class TestSpectralClipping(parameterized.TestCase):
def setUp(self):
self.prev_precision = torch.get_float32_matmul_precision()
torch.set_float32_matmul_precision("highest")
self.device = "cuda" if torch.cuda.is_available() else "cpu"
logging.info(f"Using device: {self.device}")
torch.manual_seed(1234)

def tearDown(self):
torch.set_float32_matmul_precision(self.prev_precision)

@parameterized.product(
dims=[(256, 128), (128, 256), (512, 512), (2048, 2048)],
sigma_range=[(0.2, 0.8), (0.1, 20)],
)
def test_spectral_clipping(self, dims, sigma_range):
"""Test that spectral clipping properly clips singular values to the specified range."""

sigma_min, sigma_max = sigma_range
x = torch.randn(dims, device=self.device, dtype=torch.float32)

_, original_singular_values, _ = torch.linalg.svd(x, full_matrices=False)
original_min_sv = original_singular_values.min().item()
original_max_sv = original_singular_values.max().item()

clipped_x = orthogonalized_optimizers.spectral_clip(x, sigma_min=sigma_min, sigma_max=sigma_max)

_, singular_values, _ = torch.linalg.svd(clipped_x, full_matrices=False)

min_sv = singular_values.min().item()
max_sv = singular_values.max().item()

logging.debug(f"Original matrix shape: {x.shape}")
logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
logging.debug(f"Clipped singular values range: [{min_sv:.6f}, {max_sv:.6f}]")
logging.debug(f"Target range: [{sigma_min:.6f}, {sigma_max:.6f}]")
logging.debug(f"Shape preservation: input {x.shape} -> output {clipped_x.shape}")

# use higher tolerance for lower singular values
# typically, this algorithm introduces more error for lower singular values
tolerance_upper = 1e-1
tolerance_lower = 5e-1
self.assertGreaterEqual(
min_sv + tolerance_lower,
sigma_min,
)
self.assertLessEqual(
max_sv - tolerance_upper,
sigma_max,
)

self.assertEqual(clipped_x.shape, x.shape)

@parameterized.product(
dims=[(256, 128), (128, 256), (512, 512), (100, 200)],
beta=[0.5, 1.0, 0.8, 2.0],
)
def test_spectral_hardcap(self, dims, beta):
"""Test that spectral hardcap properly clips singular values from above to be less than beta."""
x = torch.randn(dims, device=self.device, dtype=torch.float32)

U_orig, original_singular_values, Vt_orig = torch.linalg.svd(x, full_matrices=False)
original_min_sv = original_singular_values.min().item()
original_max_sv = original_singular_values.max().item()
logging.debug(f"Original matrix shape: {x.shape}")
logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")

hardcapped_x = orthogonalized_optimizers.spectral_hardcap(x, beta=beta)

U_hard, singular_values, Vt_hard = torch.linalg.svd(hardcapped_x, full_matrices=False)

tolerance_upper = 1e-1

max_sv = singular_values.max().item()

logging.debug(f"Hardcapped max singular value: {max_sv:.6f}")
logging.debug(f"Beta (upper bound): {beta:.6f}")
logging.debug(f"Shape preservation: input {x.shape} -> output {hardcapped_x.shape}")

self.assertLessEqual(
max_sv - tolerance_upper,
beta,
)

self.assertEqual(hardcapped_x.shape, x.shape)

# Test that singular vectors are preserved (polar factor UV^T should be similar)
polar_orig = U_orig @ Vt_orig
polar_hard = U_hard @ Vt_hard

# The polar factors should be very similar since hardcap only changes the singular values; compare them via the relative Frobenius difference.
relative_polar_frobenius_diff = torch.norm(polar_orig - polar_hard, "fro") / torch.norm(polar_orig, "fro")
polar_tolerance = 1e-4

logging.debug(f"Polar factor Frobenius norm difference: {relative_polar_frobenius_diff:.6f}")

self.assertLessEqual(
relative_polar_frobenius_diff,
polar_tolerance,
)


if __name__ == "__main__":
absltest.main()