|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import torch |
| 17 | +from absl import logging |
| 18 | +from torch.optim.optimizer import ParamsT |
| 19 | + |
| 20 | +from emerging_optimizers import triton_kernels |
| 21 | +from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz |
| 22 | +from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc |
| 23 | + |
| 24 | + |
| 25 | +class Scion(OrthogonalizedOptimizer): |
| 26 | + """Muon: MomentUm Orthogonalized by Newton-schulz |
| 27 | +
|
| 28 | + Scion runs standard SGD-momentum and then performs an orthogonalization |
| 29 | + post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal |
| 30 | + matrix. To efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the |
| 31 | + advantage that it may be stably run on tensor cores on GPUs. |
| 32 | +
|
| 33 | + This implementation incorporates `step_size` and `spectral_radius` refer to Scion which views weight decay as constrained |
| 34 | + optimization via Frank-Wolfe. |
| 35 | +
|
| 36 | + References: |
| 37 | + - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). |
| 38 | + [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_] |
| 39 | +
|
| 40 | + Warning: |
| 41 | + - This optimizer requires that all parameters passed in are 2D. |
| 42 | + - It should not be used for the embedding layer, the final fully connected layer, or any 1-D |
| 43 | + parameters; those should all be optimized by a standard method (e.g., AdamW). |
| 44 | +
|
| 45 | + Args: |
| 46 | + {_args_doc} |
| 47 | + coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of |
| 48 | + ["simple", "quintic", "polar_express"]. |
| 49 | + num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. |
| 50 | + scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling. |
| 51 | + spectral_radius: The spectral radius to use for the update, we are scaling the LMO by this spectral radius. |
| 52 | + use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. |
| 53 | + """ |
| 54 | + |
| 55 | + def __init__( |
| 56 | + self, |
| 57 | + params: ParamsT, |
| 58 | + lr: float = 3e-4, |
| 59 | + momentum_beta: float = 0.95, |
| 60 | + use_nesterov: bool = False, |
| 61 | + weight_decay: float = 1, |
| 62 | + use_decoupled_wd: bool = True, |
| 63 | + use_independent_wd: bool = False, |
| 64 | + fp32_matmul_prec: str = "medium", |
| 65 | + coefficient_type: str = "quintic", |
| 66 | + num_ns_steps: int = 5, |
| 67 | + scale_mode: str = "spectral", |
| 68 | + spectral_radius: float = 1.0, |
| 69 | + use_syrk: bool = False, |
| 70 | + ) -> None: |
| 71 | + if num_ns_steps < 1: |
| 72 | + raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") |
| 73 | + |
| 74 | + if use_syrk: |
| 75 | + if torch.cuda.is_available(): |
| 76 | + sm_version = torch.cuda.get_device_capability() |
| 77 | + else: |
| 78 | + sm_version = (0, 0) |
| 79 | + if not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined] |
| 80 | + logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.") |
| 81 | + use_syrk = False |
| 82 | + elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)): |
| 83 | + logging.error( |
| 84 | + f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False." |
| 85 | + ) |
| 86 | + use_syrk = False |
| 87 | + |
| 88 | + # Add checks for weight decay arguments to enable Franke-Wolfe step. |
| 89 | + if weight_decay != 1: |
| 90 | + logging.warning("Scion does not use weight decay. Setting weight_decay to 1.") |
| 91 | + weight_decay = 1 |
| 92 | + |
| 93 | + if not use_decoupled_wd: |
| 94 | + logging.warning("Scion does not use weight decay. Setting use_decoupled_wd to True to allow Franke-Wolfe.") |
| 95 | + use_decoupled_wd = True |
| 96 | + |
| 97 | + if use_independent_wd: |
| 98 | + logging.warning( |
| 99 | + "Scion does not use weight decay. Setting use_independent_wd to False to allow Franke-Wolfe." |
| 100 | + ) |
| 101 | + use_independent_wd = False |
| 102 | + |
| 103 | + def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: |
| 104 | + logging.debug( |
| 105 | + f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, spectral_radius={spectral_radius}" |
| 106 | + ) |
| 107 | + orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk) |
| 108 | + width_factor = (grad.size(-2) / grad.size(-1)) ** 0.5 |
| 109 | + return orth_grad * width_factor * spectral_radius |
| 110 | + |
| 111 | + super().__init__( |
| 112 | + params, |
| 113 | + lr, |
| 114 | + momentum_beta, |
| 115 | + use_nesterov, |
| 116 | + weight_decay, |
| 117 | + use_decoupled_wd, |
| 118 | + use_independent_wd, |
| 119 | + fp32_matmul_prec, |
| 120 | + scaled_orthogonalize_fn, |
| 121 | + spectral_radius=spectral_radius, |
| 122 | + ) |
| 123 | + |
| 124 | + |
| 125 | +Scion.__doc__ = Scion.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr] |
0 commit comments