Adding "spectral clipped" weight decay to be used together with matrix-based preconditioning optimizers #34
Closed
mkhona-nvidia wants to merge 19 commits into NVIDIA-NeMo:main from mkhona-nvidia:mkhona/spectral_clipping_weight_decay
Commits
f9c52ac Update cherry-pick workflow to use v0.63.0 (#29) (pablo-garay)
8c10d97 updated polar express coeffs with stable values from paper (mkhona-nvidia)
ab759e6 formatting (mkhona-nvidia)
3f6ca7a save unnecessary matmul (#30) (skyw)
53b3bfb added license file to spectral clip utils (mkhona-nvidia)
e0abe21 truncate abc values to their lower representable fp32 format (mkhona-nvidia)
18097aa improved tests (mkhona-nvidia)
5c26b4b improved test by checking polar factor preservation (mkhona-nvidia)
e1a0d7f improved docstrings (mkhona-nvidia)
7e4a563 removed msg override (mkhona-nvidia)
319d2c6 improved memory pressure using torch.add (mkhona-nvidia)
6df541f improve test naming (mkhona-nvidia)
7c01266 moved spectral clips to be imported by module (mkhona-nvidia)
6e20002 import module (mkhona-nvidia)
4168b01 Added tests to ci (mkhona-nvidia)
373a545 made debug statements in logging (mkhona-nvidia)
4afbc39 address PR comments (mkhona-nvidia)
991f011 Improve document (#33) (skyw)
dc596fe Update eig utils (#35) (skyw)
96 changes: 96 additions & 0 deletions
emerging_optimizers/orthogonalized_optimizers/spectral_clipping_utils.py
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import torch | ||
| | ||
| from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz | ||
| | ||
| | ||
| __all__ = ["spectral_hardcap", "spectral_clip", "spectral_clipped_weight_decay"] | ||
| | ||
| | ||
| def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor: | ||
|     r"""Applies spectral clipping to the input tensor. | ||
| | ||
|     Scalar clipping can be written in terms of the sign function. The same identity extends to the singular values | ||
|     of a matrix by using the matrix sign function, which is computed with Newton-Schulz iteration for efficiency. | ||
| | ||
|     Based on https://leloykun.github.io/ponder/spectral-clipping/. | ||
| | ||
|     Args: | ||
|         X: The input tensor (a 2D matrix). | ||
|         sigma_min: Lower bound for the clipped singular values. | ||
|         sigma_max: Upper bound for the clipped singular values. | ||
| | ||
|     Returns: | ||
|         The spectrally clipped tensor, with the same shape as X. | ||
|     """ | ||
|     # Work on the wide orientation (rows <= columns); the identity below then has the smaller dimension. | ||
|     if needs_transpose := X.shape[0] > X.shape[1]: | ||
|         X = X.T | ||
|     # OX approximates the polar factor of X. | ||
|     OX = newton_schulz(X, steps=8, coefficient_type="polar_express") | ||
|     result = (sigma_min + sigma_max) * OX | ||
|     identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype) | ||
|     for s, sign in zip([sigma_min, sigma_max], [1, -1]): | ||
|         # A = s * I - OX @ X^T and B = s * OX - X; msign(A) @ B is the matrix analogue of abs(s - x). | ||
|         A = torch.add(s * identity_matrix, OX @ X.T, alpha=-1) | ||
|         B = torch.add(s * OX, X, alpha=-1) | ||
|         result = torch.add(result, sign * newton_schulz(A, steps=8, coefficient_type="polar_express") @ B) | ||
|     result = result * 0.5 | ||
| | ||
|     if needs_transpose: | ||
|         result = result.T | ||
|     return result | ||
| | ||
| | ||
| def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor: | ||
|     r"""Clips the singular values of the input tensor from above so that none exceeds beta. | ||
| | ||
|     Simplifies the spectral clipping function to just an upper bound, resulting in a hardcap. | ||
|     Based on https://leloykun.github.io/ponder/spectral-clipping/. | ||
| | ||
|     Args: | ||
|         X: The input tensor (a 2D matrix). | ||
|         beta: The upper bound on the singular values. | ||
| | ||
|     Returns: | ||
|         The spectrally hardcapped tensor, with the same shape as X. | ||
|     """ | ||
|     if needs_transpose := X.shape[0] > X.shape[1]: | ||
|         X = X.T | ||
|     OX = newton_schulz(X, steps=8, coefficient_type="polar_express") | ||
|     # aX = beta * OX - X | ||
|     aX = torch.add(beta * OX, X, alpha=-1) | ||
|     result = torch.add(beta * OX, X) | ||
|     # Matrix analogue of min(x, beta) = (beta + x - abs(beta - x)) / 2. | ||
|     result = torch.add(result, aX @ newton_schulz(aX, steps=8, coefficient_type="polar_express").T @ OX, alpha=-1) | ||
|     result = result * 0.5 | ||
|     if needs_transpose: | ||
|         result = result.T | ||
|     return result | ||
| | ||
| | ||
| def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor: | ||
|     r"""Applies weight decay to the input tensor while applying spectral hardcapping. | ||
| | ||
|     This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988). | ||
| | ||
|     Based on https://leloykun.github.io/ponder/spectral-clipping/. | ||
| | ||
|     Args: | ||
|         X: The input tensor (a 2D matrix). | ||
|         beta: The upper bound on the singular values. | ||
|         c: Interpolation weight between X (c = 0) and its hardcapped version (c = 1). | ||
| | ||
|     Returns: | ||
|         The tensor after spectrally clipped weight decay. | ||
|     """ | ||
|     # Computes (1 - c) * X + c * spectral_hardcap(X, beta). | ||
|     return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c) | ||
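
The docstrings above say that clipping can be written using the sign function and then lifted to singular values. As a rough illustration only, the sketch below writes out the scalar identities that the three functions appear to mirror, using plain Python floats in place of singular values. The helper names (`clip_via_abs`, `hardcap_via_abs`, `clipped_weight_decay_scalar`) are hypothetical and not part of this PR, and the correspondence to the matrix code assumes the Newton-Schulz iterations have converged to the exact matrix sign.

```python
# Scalar sketch of the identities behind the matrix functions in this diff.
# All names here are hypothetical helpers for illustration, not part of the PR.


def clip_via_abs(x: float, lo: float = -1.0, hi: float = 1.0) -> float:
    """Clip x to [lo, hi] using only absolute values.

    Mirrors spectral_clip: 0.5 * ((lo + hi) + abs(lo - x) - abs(hi - x)).
    """
    return 0.5 * ((lo + hi) + abs(lo - x) - abs(hi - x))


def hardcap_via_abs(x: float, beta: float = 1.0) -> float:
    """Compute min(x, beta) as 0.5 * ((beta + x) - abs(beta - x)).

    Mirrors spectral_hardcap, which only bounds singular values from above.
    """
    return 0.5 * ((beta + x) - abs(beta - x))


def clipped_weight_decay_scalar(x: float, beta: float = 1.0, c: float = 0.5) -> float:
    """Blend x with its hardcapped value: (1 - c) * x + c * min(x, beta)."""
    return (1.0 - c) * x + c * hardcap_via_abs(x, beta)


# Treat these as the singular values of some matrix.
singular_values = [0.25, 1.0, 3.0]
capped = [hardcap_via_abs(s) for s in singular_values]
decayed = [clipped_weight_decay_scalar(s) for s in singular_values]

print(capped)   # values above beta are pulled down to beta
print(decayed)  # values above beta move halfway towards beta when c = 0.5
```

Values at or below `beta` are left unchanged by both maps, so under this reading the decay only acts on the part of the spectrum that exceeds the cap.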