|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +import torch |
| 16 | + |
| 17 | +from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz |
| 18 | + |
| 19 | + |
| 20 | +__all__ = ["spectral_hardcap", "spectral_clip"] |
| 21 | + |
| 22 | + |
| 23 | +def spectral_clip(X: torch.Tensor, sigma_min: float = -1.0, sigma_max: float = 1.0) -> torch.Tensor: |
| 24 | + r"""Applies spectral clipping to the input tensor. |
| 25 | +
|
| 26 | + From the idea that clipping can be written using the sign function. This idea can be extended to singular values of matrices |
| 27 | + using the matrix sign function, computed using Newton-Schulz iteration for efficiency. |
| 28 | +
|
| 29 | + Based on https://leloykun.github.io/ponder/spectral-clipping/. |
| 30 | +
|
| 31 | + Args: |
| 32 | + X: The input tensor. |
| 33 | + sigma_min: The minimum singular value. |
| 34 | + sigma_max: The maximum singular value. |
| 35 | +
|
| 36 | + Returns: |
| 37 | + The spectral clipped tensor. |
| 38 | + """ |
| 39 | + if needs_transpose := X.shape[0] > X.shape[1]: |
| 40 | + X = X.T |
| 41 | + OX = newton_schulz(X, steps=8, coefficient_type="polar_express") |
| 42 | + result = (sigma_min + sigma_max) * OX |
| 43 | + identity_matrix = torch.eye(X.shape[0], device=X.device, dtype=X.dtype) |
| 44 | + for s, sign in zip([sigma_min, sigma_max], [1, -1]): |
| 45 | + A = torch.addmm(s * identity_matrix, OX, X.T, beta=1.0, alpha=-1.0) |
| 46 | + B = torch.add(s * OX, X, alpha=-1) |
| 47 | + result = torch.addmm(result, newton_schulz(A, steps=8, coefficient_type="polar_express"), B, alpha=sign) |
| 48 | + result = result * 0.5 |
| 49 | + |
| 50 | + if needs_transpose: |
| 51 | + result = result.T |
| 52 | + return result |
| 53 | + |
| 54 | + |
| 55 | +def spectral_hardcap(X: torch.Tensor, beta: float = 1.0) -> torch.Tensor: |
| 56 | + r"""Spectral hardcap function clips singular values from above to be less than beta. |
| 57 | +
|
| 58 | + Simplifies the spectral clipping function to just an upper bound, resulting in a hardcap. |
| 59 | + Based on https://leloykun.github.io/ponder/spectral-clipping/. |
| 60 | +
|
| 61 | + Args: |
| 62 | + X: The input tensor. |
| 63 | + beta: The upper bound on the singular values. |
| 64 | +
|
| 65 | + Returns: |
| 66 | + The spectral hardcapped tensor. |
| 67 | +
|
| 68 | + """ |
| 69 | + if needs_transpose := X.shape[0] > X.shape[1]: |
| 70 | + X = X.T |
| 71 | + OX = newton_schulz(X, steps=8, coefficient_type="polar_express") |
| 72 | + aX = torch.add(beta * OX, X, alpha=-1) |
| 73 | + result = torch.add(beta * OX, X) |
| 74 | + result = torch.addmm( |
| 75 | + result, aX, torch.mm(newton_schulz(aX, steps=8, coefficient_type="polar_express").T, OX), alpha=-1 |
| 76 | + ) |
| 77 | + result = result * 0.5 |
| 78 | + if needs_transpose: |
| 79 | + result = result.T |
| 80 | + return result |
| 81 | + |
| 82 | + |
| 83 | +def spectral_clipped_weight_decay(X: torch.Tensor, beta: float = 1.0, c: float = 0.5) -> torch.Tensor: |
| 84 | + r"""Applies weight decay to the input tensor while applying spectral hardcapping. |
| 85 | +
|
| 86 | + This is the spectral version of Euclidean decoupled weight decay (Hanson & Pratt, 1988). |
| 87 | +
|
| 88 | + Based on https://leloykun.github.io/ponder/spectral-clipping/. |
| 89 | +
|
| 90 | + Args: |
| 91 | + X: The input tensor. |
| 92 | + beta: The upper bound on the singular values. |
| 93 | + c: The coefficient parameter. |
| 94 | +
|
| 95 | + Returns: |
| 96 | + The spectral clipped weight decay tensor. |
| 97 | + """ |
| 98 | + return torch.add((1 - c) * X, spectral_hardcap(X, beta), alpha=c) |
0 commit comments