From af561d898e76556c52eae9aa4ebccc711a3a88cf Mon Sep 17 00:00:00 2001
From: mikail
Date: Fri, 3 Oct 2025 15:35:30 -0700
Subject: [PATCH 1/4] added spectral clipping from previous PR

Signed-off-by: mikail
---
 .../orthogonalized_optimizers/__init__.py     |   1 +
 .../spectral_clipping_utils.py                |  96 +++++++++++++
 tests/ci/L0_Tests_GPU.sh                      |   1 +
 tests/ci/L1_Tests_GPU.sh                      |   1 +
 tests/test_spectral_clipping_utils.py         | 126 ++++++++++++++++++
 5 files changed, 225 insertions(+)
 create mode 100644 emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
 create mode 100644 tests/test_spectral_clipping_utils.py

diff --git a/emerging_optimizers/orthogonalized_optimizers/__init__.py b/emerging_optimizers/orthogonalized_optimizers/__init__.py
index 0afbeb2..8b8f9a4 100644
--- a/emerging_optimizers/orthogonalized_optimizers/__init__.py
+++ b/emerging_optimizers/orthogonalized_optimizers/__init__.py
@@ -14,3 +14,4 @@
 # limitations under the License.
 from emerging_optimizers.orthogonalized_optimizers.muon import *
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
+from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
diff --git a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
new file mode 100644
index 0000000..bebbfbe
--- /dev/null
+++ b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
+
+
+__all__ = ["spectral_hardcap", "spectral_clip"]
+
+
+def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor:
+    r"""Applies spectral clipping to the input tensor.
+
+    Scalar clipping can be written in terms of the sign function; this idea extends to the singular values of matrices
+    via the matrix sign function, computed here using Newton-Schulz iteration for efficiency.
+
+    Based on https://leloykun.github.io/ponder/spectral-clipping/.
+
+    Args:
+        X: The input tensor.
+        sigma_min: The minimum singular value.
+        sigma_max: The maximum singular value.
+
+    Returns:
+        The spectrally clipped tensor.
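+
+    Example (illustrative only; not from the upstream reference, and no exact
+    bound is implied since the sign function is approximated iteratively):
+        >>> W = torch.randn(256, 128)
+        >>> W_clipped = spectral_clip(W, sigma_min=0.1, sigma_max=1.0)
+        >>> # singular values of W_clipped now lie approximately in [0.1, 1.0]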
+    """
+    if needs_transpose := X.shape[0] > X.shape[1]:
+        X = X.T
+    OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
+    result = (sigma_min + sigma_max) * OX
+    identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
+    for s, sign in zip([sigma_min, sigma_max], [1, -1]):
+        A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1)
+        B = torch.add(s * OX, X, alpha=-1)
+        result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B)
+    result = result * 0.5
+
+    if needs_transpose:
+        result = result.T
+    return result
+
+
+def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
+    r"""Clips the singular values of the input tensor from above so they are at most beta.
+
+    This specializes spectral clipping to an upper bound only, resulting in a hardcap.
+    Based on https://leloykun.github.io/ponder/spectral-clipping/.
+
+    Args:
+        X: The input tensor.
+        beta: The upper bound on the singular values.
+
+    Returns:
+        The spectrally hardcapped tensor.
+
+    """
+    if needs_transpose := X.shape[0] > X.shape[1]:
+        X = X.T
+    OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
+    aX = torch.add(beta * OX, X, alpha=-1)
+    result = torch.add(beta * OX, X)
+    result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1)
+    result = result * 0.5
+    if needs_transpose:
+        result = result.T
+    return result
+
+
+def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor:
+    r"""Applies weight decay to the input tensor while applying spectral hardcapping.
+
+    This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988).
+
+    Based on https://leloykun.github.io/ponder/spectral-clipping/.
+
+    Args:
+        X: The input tensor.
+        beta: The upper bound on the singular values.
+        c: The decay coefficient, interpolating between X (c=0) and spectral_hardcap(X, beta) (c=1).
+
+    Returns:
+        The decayed, spectrally hardcapped tensor.
+    """
+    return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c)
diff --git a/tests/ci/L0_Tests_GPU.sh b/tests/ci/L0_Tests_GPU.sh
index 53b1fd1..bd1a818 100644
--- a/tests/ci/L0_Tests_GPU.sh
+++ b/tests/ci/L0_Tests_GPU.sh
@@ -21,3 +21,4 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
 coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py
+coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py
\ No newline at end of file
diff --git a/tests/ci/L1_Tests_GPU.sh b/tests/ci/L1_Tests_GPU.sh
index 0e7570a..7af079e 100644
--- a/tests/ci/L1_Tests_GPU.sh
+++ b/tests/ci/L1_Tests_GPU.sh
@@ -19,3 +19,4 @@ python tests/test_soap_functions.py
 python tests/test_soap_utils.py
 python tests/soap_smoke_test.py
 python tests/test_scalar_optimizers.py
+python tests/test_spectral_clipping_utils.py
\ No newline at end of file
diff --git a/tests/test_spectral_clipping_utils.py b/tests/test_spectral_clipping_utils.py
new file mode 100644
index 0000000..2be4557
--- /dev/null
+++ b/tests/test_spectral_clipping_utils.py
@@ -0,0 +1,126 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from absl import logging
+from absl.testing import absltest, parameterized
+
+import emerging_optimizers.orthogonalized_optimizers as orthogonalized_optimizers
+
+
+class TestSpectralClipping(parameterized.TestCase):
+    def setUp(self):
+        self.prev_precision = torch.get_float32_matmul_precision()
+        torch.set_float32_matmul_precision("highest")
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        logging.info(f"Using device: {self.device}")
+        torch.manual_seed(1234)
+
+    def tearDown(self):
+        torch.set_float32_matmul_precision(self.prev_precision)
+
+    @parameterized.product(
+        dims=[(256, 128), (128, 256), (512, 512), (2048, 2048)],
+        sigma_range=[(0.2, 0.8), (0.1, 20)],
+    )
+    def test_spectral_clipping(self, dims, sigma_range):
+        """Test that spectral clipping properly clips singular values to the specified range."""
+
+        sigma_min, sigma_max = sigma_range
+        x = torch.randn(dims, device=self.device, dtype=torch.float32)
+
+        _, original_singular_values, _ = torch.linalg.svd(x, full_matrices=False)
+        original_min_sv = original_singular_values.min().item()
+        original_max_sv = original_singular_values.max().item()
+
+        clipped_x = orthogonalized_optimizers.spectral_clip(x, sigma_min=sigma_min, sigma_max=sigma_max)
+
+        _, singular_values, _ = torch.linalg.svd(clipped_x, full_matrices=False)
+
+        min_sv = singular_values.min().item()
+        max_sv = singular_values.max().item()
+
+        logging.debug(f"Original matrix shape: {x.shape}")
+        logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
+        logging.debug(f"Clipped singular values range: [{min_sv:.6f}, {max_sv:.6f}]")
+        logging.debug(f"Target range: [{sigma_min:.6f}, {sigma_max:.6f}]")
+        logging.debug(f"Shape preservation: input {x.shape} -> output {clipped_x.shape}")
+
+        # use higher tolerance for lower singular values
+        # typically, this algorithm introduces more error for lower singular values
+        tolerance_upper = 1e-1
+        tolerance_lower = 5e-1
+        self.assertGreaterEqual(
+            min_sv + tolerance_lower,
+            sigma_min,
+        )
+        self.assertLessEqual(
+            max_sv - tolerance_upper,
+            sigma_max,
+        )
+
+        self.assertEqual(clipped_x.shape, x.shape)
+
+    @parameterized.product(
+        dims=[(256, 128), (128, 256), (512, 512), (100, 200)],
+        beta=[0.5, 1.0, 0.8, 2.0],
+    )
+    def test_spectral_hardcap(self, dims, beta):
+        """Test that spectral hardcap properly clips singular values from above to be at most beta."""
+        x = torch.randn(dims, device=self.device, dtype=torch.float32)
+
+        U_orig, original_singular_values, Vt_orig = torch.linalg.svd(x, full_matrices=False)
+        original_min_sv = original_singular_values.min().item()
+        original_max_sv = original_singular_values.max().item()
+        logging.debug(f"Original matrix shape: {x.shape}")
+        logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
+
+        hardcapped_x = orthogonalized_optimizers.spectral_hardcap(x, beta=beta)
+
+        U_hard, singular_values, Vt_hard = torch.linalg.svd(hardcapped_x, full_matrices=False)
+
+        tolerance_upper = 1e-1
+
+        max_sv = singular_values.max().item()
+
+        logging.debug(f"Hardcapped max singular value: {max_sv:.6f}")
+        logging.debug(f"Beta (upper bound): {beta:.6f}")
+        logging.debug(f"Shape preservation: input {x.shape} -> output {hardcapped_x.shape}")
+
+        self.assertLessEqual(
+            max_sv - tolerance_upper,
+            beta,
+        )
+
+        self.assertEqual(hardcapped_x.shape, x.shape)
+
+        # Test that singular vectors are preserved (polar factor UV^T should be similar)
+        polar_orig = U_orig @ Vt_orig
+        polar_hard = U_hard @ Vt_hard
+
+        # The polar factors should be very similar since hardcap only changes singular values.
+        # Compute the relative difference in Frobenius norm.
+        relative_polar_frobenius_diff = torch.norm(polar_orig - polar_hard, "fro") / torch.norm(polar_orig, "fro")
+        polar_tolerance = 1e-4
+
+        logging.debug(f"Polar factor Frobenius norm difference: {relative_polar_frobenius_diff:.6f}")
+
+        self.assertLessEqual(
+            relative_polar_frobenius_diff,
+            polar_tolerance,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()

From eace573fcc4858d621cdf5d8238ae63aae6bf9ba Mon Sep 17 00:00:00 2001
From: mikail
Date: Fri, 3 Oct 2025 15:40:13 -0700
Subject: [PATCH 2/4] missed addmm

Signed-off-by: mikail
---
 .../orthogonalized_optimizers/spectral_clipping_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
index bebbfbe..dd8e26a 100644
--- a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
+++ b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
@@ -44,7 +44,7 @@ def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1
     for s, sign in zip([sigma_min, sigma_max], [1, -1]):
         A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1)
         B = torch.add(s * OX, X, alpha=-1)
-        result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B)
+        result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign)
     result = result * 0.5
 
     if needs_transpose:

From e46e361661705dfb17eafd0e1f2dad133734e52e Mon Sep 17 00:00:00 2001
From: mikail
Date: Fri, 3 Oct 2025 15:43:57 -0700
Subject: [PATCH 3/4] added missing torch addmm with mm

Signed-off-by: mikail
---
 .../orthogonalized_optimizers/spectral_clipping_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
index dd8e26a..d406fbd 100644
--- a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
+++ b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
@@ -71,7 +71,9 @@ def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
     OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
     aX = torch.add(beta * OX, X, alpha=-1)
     result = torch.add(beta * OX, X)
-    result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1)
+    result = torch.addmm(
+        result, aX, torch.mm(newton_schulz(aX, steps=8, coefficient_type="polar_express").T, OX), alpha=-1
+    )
     result = result * 0.5
     if needs_transpose:
         result = result.T

From b6cc6b6e692ee082f75b634752b3e49182924751 Mon Sep 17 00:00:00 2001
From: mikail
Date: Fri, 3 Oct 2025 16:01:46 -0700
Subject: [PATCH 4/4] add one missing torch addmm

Signed-off-by: mikail
---
 .../orthogonalized_optimizers/spectral_clipping_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
index d406fbd..2f7519e 100644
--- a/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
+++ b/emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
@@ -42,7 +42,7 @@ def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1
     result = (sigma_min + sigma_max) * OX
     identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
     for s, sign in zip([sigma_min, sigma_max], [1, -1]):
-        A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1)
+        A = torch.addmm(s * identity_matrix, OX, X.T, beta=1.0, alpha=-1.0)
         B = torch.add(s * OX, X, alpha=-1)
         result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign)
     result = result * 0.5
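
Appendix: a minimal usage sketch of the utilities added in this series. This is
illustrative only and not part of the patch; it assumes the wildcard import added
to __init__.py in PATCH 1/4 re-exports spectral_clip and spectral_hardcap, and it
uses torch.linalg.svdvals purely to verify the result.

import torch

from emerging_optimizers.orthogonalized_optimizers import spectral_clip, spectral_hardcap

W = torch.randn(256, 128)

# Push the singular values of W into [0.1, 1.0]. The bound is approximate because
# the matrix sign function is computed with a fixed number of Newton-Schulz steps.
W_clipped = spectral_clip(W, sigma_min=0.1, sigma_max=1.0)

# Cap the singular values of W from above at 1.0; smaller singular values are left alone.
W_capped = spectral_hardcap(W, beta=1.0)

print(torch.linalg.svdvals(W_clipped).max())  # approximately <= 1.0
print(torch.linalg.svdvals(W_capped).max())   # approximately <= 1.0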