
Commit beacaab

Add "spectral clipped" weight decay to be used together with matrix-based preconditioning optimizers (#38)
* added spectral clipping from previous PR

Signed-off-by: mikail <[email protected]>
1 parent 9b433f7 commit beacaab

5 files changed, +227 −0 lines changed

emerging_optimizers/orthogonalized_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@
 # limitations under the License.
 from emerging_optimizers.orthogonalized_optimizers.muon import *
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
+from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *
emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz


__all__ = ["spectral_hardcap", "spectral_clip"]


def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor:
    r"""Clips the singular values of the input tensor to the interval [sigma_min, sigma_max].

    Scalar clipping can be written with the sign function:
    clip(x, a, b) = ((a + b) + |x - a| - |x - b|) / 2, where |t| = t * sign(t).
    The same identity extends to the singular values of a matrix via the matrix
    sign function, computed here with Newton-Schulz iterations for efficiency.

    Based on https://leloykun.github.io/ponder/spectral-clipping/.

    Args:
        X: The input tensor.
        sigma_min: The minimum singular value.
        sigma_max: The maximum singular value.

    Returns:
        The spectrally clipped tensor.
    """
    if needs_transpose := X.shape[0] > X.shape[1]:
        X = X.T
    # Orthogonal polar factor U @ V^T of X = U @ S @ V^T.
    OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
    result = (sigma_min + sigma_max) * OX
    identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
    for s, sign in zip([sigma_min, sigma_max], [1, -1]):
        # A = s * I - OX @ X^T = U @ (s * I - S) @ U^T and B = s * OX - X = U @ (s * I - S) @ V^T,
        # so msign(A) @ B = U @ |s * I - S| @ V^T, giving the |x - a| and |x - b| terms above.
        A = torch.addmm(s * identity_matrix, OX, X.T, beta=1.0, alpha=-1.0)
        B = torch.add(s * OX, X, alpha=-1)
        result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign)
    result = result * 0.5

    if needs_transpose:
        result = result.T
    return result


def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    r"""Clips the singular values of the input tensor from above so that none exceeds beta.

    Simplifies spectral clipping to just the upper bound: each singular value s
    is mapped to min(s, beta) = (beta + s - |beta - s|) / 2.
    Based on https://leloykun.github.io/ponder/spectral-clipping/.

    Args:
        X: The input tensor.
        beta: The upper bound on the singular values.

    Returns:
        The spectrally hardcapped tensor.
    """
    if needs_transpose := X.shape[0] > X.shape[1]:
        X = X.T
    # Orthogonal polar factor U @ V^T of X = U @ S @ V^T.
    OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
    # aX = beta * OX - X = U @ (beta * I - S) @ V^T, and
    # aX @ msign(aX)^T @ OX = U @ |beta * I - S| @ V^T.
    aX = torch.add(beta * OX, X, alpha=-1)
    result = torch.add(beta * OX, X)
    result = torch.addmm(
        result, aX, torch.mm(newton_schulz(aX, steps=8, coefficient_type="polar_express").T, OX), alpha=-1
    )
    result = result * 0.5
    if needs_transpose:
        result = result.T
    return result


def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor:
    r"""Applies weight decay to the input tensor via spectral hardcapping.

    This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988):
    instead of shrinking the weights toward zero, it interpolates them toward their
    spectrally hardcapped version, X <- (1 - c) * X + c * spectral_hardcap(X, beta).

    Based on https://leloykun.github.io/ponder/spectral-clipping/.

    Args:
        X: The input tensor.
        beta: The upper bound on the singular values.
        c: The interpolation coefficient between X and its hardcapped version.

    Returns:
        The decayed tensor.
    """
    return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c)
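
For a quick sanity check of what these utilities do to a matrix's spectrum, the following minimal sketch (not part of the commit, assuming only that the package is importable) inspects the singular values before and after clipping with torch.linalg.svdvals:

import torch

from emerging_optimizers.orthogonalized_optimizers import spectral_clip, spectral_hardcap

torch.manual_seed(0)
X = torch.randn(128, 256)

# Singular values before and after clipping to [0.2, 0.8].
before = torch.linalg.svdvals(X)
after = torch.linalg.svdvals(spectral_clip(X, sigma_min=0.2, sigma_max=0.8))
print(f"before: [{before.min():.3f}, {before.max():.3f}]")  # well above 0.8 for a Gaussian matrix
print(f"after:  [{after.min():.3f}, {after.max():.3f}]")    # approximately inside [0.2, 0.8]

# Hardcapping only bounds the spectrum from above; singular values below beta pass through.
capped = torch.linalg.svdvals(spectral_hardcap(X, beta=1.0))
print(f"capped: max {capped.max():.3f}")  # approximately <= 1.0

As in the tests below, the bounds hold only up to the tolerance of the Newton-Schulz approximation, so small overshoots are expected.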

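A hedged sketch of how spectral_clipped_weight_decay might slot in as the decoupled weight-decay step of a matrix-preconditioned optimizer; the helper name, learning rate, and c = 0.1 are illustrative assumptions, not part of this commit (the function is imported from the submodule since it is not listed in __all__):

import torch

from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import spectral_clipped_weight_decay


def apply_update_with_spectral_decay(param, update, lr=1e-2, beta=1.0, c=0.1):
    # Hypothetical decoupled step: apply the preconditioned update first, then
    # decay the weights toward their spectrally hardcapped version,
    # X <- (1 - c) * X + c * spectral_hardcap(X, beta), instead of toward zero.
    with torch.no_grad():
        param.add_(update, alpha=-lr)
        param.copy_(spectral_clipped_weight_decay(param, beta=beta, c=c))

With c = 0 this reduces to a plain update, and with c = 1 the weights are hardcapped outright; intermediate values interpolate, mirroring how Euclidean decoupled weight decay interpolates toward the origin.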
tests/ci/L0_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions
@@ -21,3 +21,4 @@ coverage run -p --source=emerging_optimizers tests/test_soap_utils.py
 coverage run -p --source=emerging_optimizers tests/soap_smoke_test.py
 coverage run -p --source=emerging_optimizers tests/soap_mnist_test.py
 coverage run -p --source=emerging_optimizers tests/test_scalar_optimizers.py
+coverage run -p --source=emerging_optimizers tests/test_spectral_clipping_utils.py

tests/ci/L1_Tests_GPU.sh

Lines changed: 1 addition & 0 deletions
@@ -19,3 +19,4 @@ python tests/test_soap_functions.py
 python tests/test_soap_utils.py
 python tests/soap_smoke_test.py
 python tests/test_scalar_optimizers.py
+python tests/test_spectral_clipping_utils.py
tests/test_spectral_clipping_utils.py

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from absl import logging
from absl.testing import absltest, parameterized

import emerging_optimizers.orthogonalized_optimizers as orthogonalized_optimizers


class TestSpectralClipping(parameterized.TestCase):
    def setUp(self):
        self.prev_precision = torch.get_float32_matmul_precision()
        torch.set_float32_matmul_precision("highest")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logging.info(f"Using device: {self.device}")
        torch.manual_seed(1234)

    def tearDown(self):
        torch.set_float32_matmul_precision(self.prev_precision)

    @parameterized.product(
        dims=[(256, 128), (128, 256), (512, 512), (2048, 2048)],
        sigma_range=[(0.2, 0.8), (0.1, 20)],
    )
    def test_spectral_clipping(self, dims, sigma_range):
        """Test that spectral clipping properly clips singular values to the specified range."""

        sigma_min, sigma_max = sigma_range
        x = torch.randn(dims, device=self.device, dtype=torch.float32)

        _, original_singular_values, _ = torch.linalg.svd(x, full_matrices=False)
        original_min_sv = original_singular_values.min().item()
        original_max_sv = original_singular_values.max().item()

        clipped_x = orthogonalized_optimizers.spectral_clip(x, sigma_min=sigma_min, sigma_max=sigma_max)

        _, singular_values, _ = torch.linalg.svd(clipped_x, full_matrices=False)

        min_sv = singular_values.min().item()
        max_sv = singular_values.max().item()

        logging.debug(f"Original matrix shape: {x.shape}")
        logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")
        logging.debug(f"Clipped singular values range: [{min_sv:.6f}, {max_sv:.6f}]")
        logging.debug(f"Target range: [{sigma_min:.6f}, {sigma_max:.6f}]")
        logging.debug(f"Shape preservation: input {x.shape} -> output {clipped_x.shape}")

        # Use a higher tolerance for the smaller singular values;
        # the algorithm typically introduces more error there.
        tolerance_upper = 1e-1
        tolerance_lower = 5e-1
        self.assertGreaterEqual(
            min_sv + tolerance_lower,
            sigma_min,
        )
        self.assertLessEqual(
            max_sv - tolerance_upper,
            sigma_max,
        )

        self.assertEqual(clipped_x.shape, x.shape)

    @parameterized.product(
        dims=[(256, 128), (128, 256), (512, 512), (100, 200)],
        beta=[0.5, 1.0, 0.8, 2.0],
    )
    def test_spectral_hardcap(self, dims, beta):
        """Test that spectral hardcap properly clips singular values from above to be less than beta."""
        x = torch.randn(dims, device=self.device, dtype=torch.float32)

        U_orig, original_singular_values, Vt_orig = torch.linalg.svd(x, full_matrices=False)
        original_min_sv = original_singular_values.min().item()
        original_max_sv = original_singular_values.max().item()
        logging.debug(f"Original matrix shape: {x.shape}")
        logging.debug(f"Original singular values range: [{original_min_sv:.6f}, {original_max_sv:.6f}]")

        hardcapped_x = orthogonalized_optimizers.spectral_hardcap(x, beta=beta)

        U_hard, singular_values, Vt_hard = torch.linalg.svd(hardcapped_x, full_matrices=False)

        tolerance_upper = 1e-1

        max_sv = singular_values.max().item()

        logging.debug(f"Hardcapped max singular value: {max_sv:.6f}")
        logging.debug(f"Beta (upper bound): {beta:.6f}")
        logging.debug(f"Shape preservation: input {x.shape} -> output {hardcapped_x.shape}")

        self.assertLessEqual(
            max_sv - tolerance_upper,
            beta,
        )

        self.assertEqual(hardcapped_x.shape, x.shape)

        # Test that the singular vectors are preserved (the polar factor U @ V^T should be similar).
        polar_orig = U_orig @ Vt_orig
        polar_hard = U_hard @ Vt_hard

        # The polar factors should be very similar since hardcap only changes the
        # singular values; compare them via the relative Frobenius difference.
        relative_polar_frobenius_diff = torch.norm(polar_orig - polar_hard, "fro") / torch.norm(polar_orig, "fro")
        polar_tolerance = 1e-4

        logging.debug(f"Polar factor Frobenius norm difference: {relative_polar_frobenius_diff:.6f}")

        self.assertLessEqual(
            relative_polar_frobenius_diff,
            polar_tolerance,
        )


if __name__ == "__main__":
    absltest.main()
