Commit 0408134

Merge branch 'main' into main
2 parents e5a78dd + 06d0bca

33 files changed, +3120 -356 lines changed

.github/workflows/cicd-main.yml

Lines changed: 60 additions & 5 deletions

@@ -29,10 +29,16 @@ permissions:
   contents: read

 jobs:
+  pre-flight:
+    uses: NVIDIA-NeMo/FW-CI-templates/.github/workflows/_cicd_preflight.yml@main

   cicd-wait-in-queue:
+    needs: [pre-flight]
     runs-on: ubuntu-latest
     environment: test
+    if: |
+      needs.pre-flight.outputs.is_ci_workload == 'false'
+      && needs.pre-flight.outputs.docs_only == 'false'
     steps:
       - name: Running CI tests
         run: |
@@ -62,10 +68,17 @@ jobs:
           runner: linux-amd64-cpu16
           timeout: 30
           cpu-only: true
-    needs: [cicd-container-build]
+    needs: [pre-flight, cicd-container-build]
     runs-on: ${{ matrix.runner }}
     name: ${{ matrix.script }}
     environment: nemo-ci
+    if: |
+      (
+        success()
+        || needs.pre-flight.outputs.is_ci_workload == 'true'
+        || needs.pre-flight.outputs.force_run_all == 'true'
+      )
+      && !cancelled()
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -85,9 +98,17 @@ jobs:

   Nemo_CICD_Test:
     needs:
+      - pre-flight
       - cicd-container-build
       - cicd-unit-tests
-    if: always()
+    if: |
+      (
+        needs.pre-flight.outputs.docs_only == 'true'
+        || needs.pre-flight.outputs.is_deployment_workflow == 'true'
+        || needs.pre-flight.outputs.is_ci_workload == 'true'
+        || always()
+      )
+      && !cancelled()
     runs-on: ubuntu-latest
     permissions: write-all
     steps:
@@ -99,13 +120,15 @@ jobs:
         env:
           GH_TOKEN: ${{ github.token }}
           RUN_ID: ${{ github.run_id }}
+          DOCS_ONLY: ${{ needs.pre-flight.outputs.docs_only }}
+          IS_DEPLOYMENT: ${{ needs.pre-flight.outputs.is_deployment_workflow }}
+          IS_CI_WORKLOAD: ${{ needs.pre-flight.outputs.is_ci_workload }}
         run: |
           # Get workflow run details and check job conclusions
-          LATEST_ATTEMPT=$(gh run view $RUN_ID --json jobs -q '[.jobs[] | select(.conclusion != null) | .conclusion] | last')
           NUM_FAILED=$(gh run view $RUN_ID --json jobs -q '[.jobs[] | select(.conclusion == "failure") | .name] | length')
           NUM_CANCELLED=$(gh run view $RUN_ID --json jobs -q '[.jobs[] | select(.conclusion == "cancelled") | .name] | length')

-          if [[ $NUM_FAILED -eq 0 && $NUM_CANCELLED -eq 0 ]]; then
+          if [[ ($NUM_FAILED -eq 0 && $NUM_CANCELLED -eq 0) || $DOCS_ONLY == 'true' || $IS_DEPLOYMENT == 'true' || $IS_CI_WORKLOAD == 'true' ]]; then
            RESULT="success"
           elif [[ $NUM_CANCELLED -gt 0 ]]; then
            RESULT="cancelled"
@@ -180,9 +203,41 @@ jobs:
            exit 1
           fi

+  Coverage_Fake:
+    runs-on: ubuntu-latest
+    needs: [Nemo_CICD_Test, pre-flight]
+    if: |
+      (
+        needs.pre-flight.outputs.docs_only == 'true'
+        || needs.pre-flight.outputs.is_deployment_workflow == 'true'
+      )
+      && needs.pre-flight.outputs.is_ci_workload == 'false'
+      && !cancelled()
+    environment: nemo-ci
+    steps:
+      - name: Generate fake coverage report
+        uses: actions/github-script@v6
+        with:
+          github-token: ${{ secrets.PAT }}
+          script: |
+            await github.rest.repos.createCommitStatus({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              sha: context.sha,
+              state: 'success',
+              description: 'No code changes - coverage check skipped',
+              context: 'codecov/patch'
+            });

   Coverage:
     runs-on: ubuntu-latest
-    needs: [Nemo_CICD_Test]
+    needs: [pre-flight, Nemo_CICD_Test]
+    if: |
+      (
+        (needs.pre-flight.outputs.is_ci_workload == 'true' && !failure())
+        || success()
+      )
+      && !cancelled()
     strategy:
       matrix:
         flag: [unit-test]
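
As a reading aid (not part of the commit): the pre-flight job exposes three outputs (docs_only, is_deployment_workflow, is_ci_workload) that feed both the job-level `if:` expressions and the status-classification script in Nemo_CICD_Test. A minimal Python sketch of that classification logic follows; the function name is invented, and the final "failure" fallback is an assumption, since the hunk only shows the "success" and "cancelled" branches:

def classify_run(num_failed: int, num_cancelled: int,
                 docs_only: bool, is_deployment: bool, is_ci_workload: bool) -> str:
    # Docs-only, deployment, and CI-workload runs report success regardless of
    # individual job conclusions, mirroring the updated bash condition above.
    if (num_failed == 0 and num_cancelled == 0) or docs_only or is_deployment or is_ci_workload:
        return "success"
    if num_cancelled > 0:
        return "cancelled"
    return "failure"  # assumed fallback; not shown in the hunk

# A docs-only run with failed jobs still reports success:
assert classify_run(2, 0, docs_only=True, is_deployment=False, is_ci_workload=False) == "success"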

README.md

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ Emerging optimizers have demonstrated significant practical impact in large-scal

 ### Prerequisites

-- Python 3.12 or higher
+- Python 3.10 or higher, 3.12 is recommended
 - PyTorch 2.0 or higher

 ### Install from Source

docs/apidocs/soap.md

Lines changed: 6 additions & 0 deletions

@@ -21,4 +21,10 @@ emerging_optimizers.soap
 .. autofunction:: update_kronecker_factors

 .. autofunction:: update_eigenbasis_and_momentum
+
+emerging_optimizers.soap.soap_utils
+=====================================
+
+.. automodule:: emerging_optimizers.soap.soap_utils
+   :members:
 ```

docs/conf.py

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@
     "numpy": ("https://numpy.org/doc/stable", None),
     "torch": ("https://pytorch.org/docs/2.5", None),
 }
+autodoc_typehints = "description"


 def linkcode_resolve(domain, info):

emerging_optimizers/orthogonalized_optimizers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -14,3 +14,4 @@
 # limitations under the License.
 from emerging_optimizers.orthogonalized_optimizers.muon import *
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import *
+from emerging_optimizers.orthogonalized_optimizers.spectral_clipping_utils import *

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 8 additions & 4 deletions

@@ -37,9 +37,12 @@ class Muon(OrthogonalizedOptimizer):
     optimization via Frank-Wolfe.

     References:
-        - Jordan, K. *Muon Optimizer Implementation.* [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
-        - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024). [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]
-        - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]
+        - Jordan, K. *Muon Optimizer Implementation.*
+          [`GitHub <https://github.com/KellerJordan/Muon/blob/master/muon.py>`_]
+        - *Modular Duality in Deep Learning.* arXiv:2410.21265 (2024).
+          [`arXiv:2410.21265 <https://arxiv.org/abs/2410.21265>`_]
+        - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025).
+          [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]

     Warning:
         - This optimizer requires that all parameters passed in are 2D.
@@ -122,7 +125,8 @@ def get_muon_scale_factor(
         # Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982)
         return extra_scale_factor * max(size_out, size_in) ** 0.5
     elif mode == "unit_rms_norm":
-        # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
+        # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al.
+        # (https://jeremybernste.in/writing/deriving-muon)
         return extra_scale_factor * (size_out / size_in) ** 0.5
     else:
         raise ValueError(f"Invalid mode for Muon update scale factor: {mode}")
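
For reference, the two branches around the re-wrapped comment differ only in their scale formulas. A tiny illustration with hypothetical layer sizes, mirroring the returned expressions rather than calling the library:

size_out, size_in = 4096, 1024
jordan_kimi_scale = max(size_out, size_in) ** 0.5  # first branch above: sqrt(4096) = 64.0
unit_rms_scale = (size_out / size_in) ** 0.5       # "unit_rms_norm" branch: sqrt(4) = 2.0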

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 12 additions & 3 deletions

@@ -12,7 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, override
+from typing import Any, Callable
+
+
+# TODO(@boxiangw): remove this once bump to python 3.12
+try:
+    from typing import override
+except ImportError:
+    from typing_extensions import override

 import torch
 import torch.optim as optim
@@ -45,9 +52,11 @@ class OrthogonalizedOptimizer(optim.Optimizer):

     - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
       In International Conference on Artificial Intelligence and Statistics (2015a).
-    - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. *Stochastic Spectral Descent for Discrete Graphical Models.*
+    - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V.
+      *Stochastic Spectral Descent for Discrete Graphical Models.*
       In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).
-    - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. *Preconditioned spectral descent for deep learning.*
+    - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V.
+      *Preconditioned spectral descent for deep learning.*
       In Neural Information Processing Systems (2015b).
     - Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural networks.*
       arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]
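
The try/except import added above is the standard backport pattern: typing.override exists only on Python 3.12+, so older interpreters fall back to typing_extensions. A minimal sketch of how the decorator is then used (illustrative class names, not from this repository):

try:
    from typing import override  # Python >= 3.12
except ImportError:
    from typing_extensions import override  # backport for older interpreters


class Base:
    def step(self) -> None: ...


class Child(Base):
    @override  # a type checker flags this if Base.step is renamed or removed
    def step(self) -> None: ...
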
emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py (new file; the filename is inferred from the import added to __init__.py above)

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
+
+
+__all__ = ["spectral_hardcap", "spectral_clip"]
+
+
+def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor:
+    r"""Applies spectral clipping to the input tensor.
+
+    Builds on the idea that clipping can be written using the sign function, extended to singular values
+    of matrices via the matrix sign function, computed using Newton-Schulz iteration for efficiency.
+
+    Based on https://leloykun.github.io/ponder/spectral-clipping/.
+
+    Args:
+        X: The input tensor.
+        sigma_min: The minimum singular value.
+        sigma_max: The maximum singular value.
+
+    Returns:
+        The spectrally clipped tensor.
+    """
+    if needs_transpose := X.shape[0] > X.shape[1]:
+        X = X.T
+    OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
+    result = (sigma_min + sigma_max) * OX
+    identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype)
+    for s, sign in zip([sigma_min, sigma_max], [1, -1]):
+        A = torch.addmm(s * identity_matrix, OX, X.T, beta=1.0, alpha=-1.0)
+        B = torch.add(s * OX, X, alpha=-1)
+        result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign)
+    result = result * 0.5
+
+    if needs_transpose:
+        result = result.T
+    return result
+
+
+def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
+    r"""Spectral hardcap function clips singular values from above to be less than beta.
+
+    Simplifies the spectral clipping function to just an upper bound, resulting in a hardcap.
+    Based on https://leloykun.github.io/ponder/spectral-clipping/.
+
+    Args:
+        X: The input tensor.
+        beta: The upper bound on the singular values.
+
+    Returns:
+        The spectrally hardcapped tensor.
+
+    """
+    if needs_transpose := X.shape[0] > X.shape[1]:
+        X = X.T
+    OX = newton_schulz(X, steps=8, coefficient_type="polar_express")
+    aX = torch.add(beta * OX, X, alpha=-1)
+    result = torch.add(beta * OX, X)
+    result = torch.addmm(
+        result, aX, torch.mm(newton_schulz(aX, steps=8, coefficient_type="polar_express").T, OX), alpha=-1
+    )
+    result = result * 0.5
+    if needs_transpose:
+        result = result.T
+    return result
+
+
+def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor:
+    r"""Applies weight decay to the input tensor while applying spectral hardcapping.
+
+    This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988).
+
+    Based on https://leloykun.github.io/ponder/spectral-clipping/.
+
+    Args:
+        X: The input tensor.
+        beta: The upper bound on the singular values.
+        c: The coefficient parameter.
+
+    Returns:
+        The spectrally clipped, weight-decayed tensor.
+    """
+    return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c)
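
A quick sanity check of the new utilities (a hedged sketch, assuming the package and its newton_schulz dependency are installed; the star import added to __init__.py above makes spectral_hardcap importable from the subpackage):

import torch

from emerging_optimizers.orthogonalized_optimizers import spectral_hardcap

X = 3.0 * torch.randn(256, 512)
Y = spectral_hardcap(X, beta=1.0)
print(torch.linalg.svdvals(X).max())  # typically far above 1
print(torch.linalg.svdvals(Y).max())  # approximately 1 or below; Newton-Schulz is approximate
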
New file (its path, somewhere under the emerging_optimizers.psgd module, is not shown in this view)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+import emerging_optimizers.utils as utils
+from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew
+
+
+__all__ = [
+    "procrustes_step",
+]
+
+
+@torch.compile  # type: ignore[misc]
+def procrustes_step(Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8) -> torch.Tensor:
+    r"""One step of an online solver for the orthogonal Procrustes problem.
+
+    The orthogonal Procrustes problem :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I` is solved
+    approximately by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.
+
+    `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)` to its 2nd order term.
+
+    This method is a second order expansion of a Lie algebra parametrized rotation that
+    uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.
+
+    Args:
+        Q: Tensor of shape (n, n), general square matrix to orthogonalize.
+        max_step_size: Maximum step size for the line search. Default is 1/8 (0.125).
+        eps: Small number for numerical stability.
+    """
+    # Note: this function is written in fp32 to avoid numerical instability while computing the Taylor expansion of the exponential map
+    with utils.fp32_matmul_precision("highest"):
+        R = Q.T - Q
+        R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
+        RQ = R @ Q
+        # trace of RQ is always non-negative, since
+        # tr(RQ) = tr((Q^T - Q) Q) = ||Q||_F^2 - tr(Q^2), and tr(Q^2) = <Q^T, Q>_F <= ||Q||_F^2 by Cauchy-Schwarz
+        tr_RQ = torch.trace(RQ)
+        RRQ = R @ RQ
+        tr_RRQ = torch.trace(RRQ)
+        # clip step size to max_step_size, based on a 2nd order expansion.
+        _step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
+        # If tr_RRQ >= 0, the quadratic approximation is not concave; fall back to max_step_size.
+        step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
+        # rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search;
+        # for the 2nd order expansion, only expand exp(a R) to its 2nd term.
+        # Equivalent to: Q += step_size * (RQ + 0.5 * step_size * RRQ)
+        Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
+
+    return Q
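
Since :math:`R = Q^H - Q` vanishes exactly when Q is symmetric, repeated calls should rotate Q toward its symmetric polar factor while approximately preserving its singular values (the rotation is truncated at 2nd order). A hedged usage sketch; the import path is an assumption, as the new file's location is not shown in this view:

import torch

from emerging_optimizers.psgd import procrustes_step  # assumed import path

Q = torch.randn(32, 32)
for _ in range(200):
    Q = procrustes_step(Q)
print(torch.linalg.norm(Q.T - Q))  # the skew part should shrink toward 0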
