Scion optimizer #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 9 commits
Commits (13)
c24b358 added Scion's franke-wolfe parametrization to orthogonalized Optimizer (mkhona-nvidia)
f5fc32e Revert "added Scion's franke-wolfe parametrization to orthogonalized … (mkhona-nvidia)
e6db31e added Scion's franke-wolfe parametrization to orthogonalized Optimizer (mkhona-nvidia)
a230420 removed choice of scale since we always use unit RMS scaling for muP (mkhona-nvidia)
617bf0f changed named to scion from muon (mkhona-nvidia)
d48d7a0 changed full form of name to match the paper https://arxiv.org/pdf/25… (mkhona-nvidia)
bf129c9 removed wd args from scion (mkhona-nvidia)
1a11302 removed nesterov as a choice in scion (mkhona-nvidia)
64f0c91 used muon scale factor with preset mode (mkhona-nvidia)
ae811e5 removed syrk from demonstration class (mkhona-nvidia)
c000e59 made logging level info (mkhona-nvidia)
d67475e changed args_doc (mkhona-nvidia)
5521b58 removed meanignless comment, args (mkhona-nvidia)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import torch | ||
| from absl import logging | ||
| from torch.optim.optimizer import ParamsT | ||
|
|
||
| from emerging_optimizers import triton_kernels | ||
| from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor | ||
| from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz | ||
| from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc | ||
|
|
||
|
|
||
| class Scion(OrthogonalizedOptimizer): | ||
| """Scion: Stochastic CondItional descent with Operator Norms | ||
|
|
||
| Scion runs standard SGD-momentum and then performs an orthogonalization | ||
| post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal | ||
| matrix. To efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the | ||
| advantage that it may be stably run on tensor cores on GPUs. | ||
|
|
||
| This implementation exposes a step size (`lr`) and a `spectral_radius`, following Scion, which treats | ||
| weight decay as constrained optimization via Frank-Wolfe. | ||
|
|
||
| References: | ||
| - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025). | ||
| [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_] | ||
|
|
||
| Warning: | ||
| - This optimizer requires that all parameters passed in are 2D. | ||
| - It should not be used for the embedding layer, the final fully connected layer, or any 1-D | ||
| parameters; each of those should be optimized with the LMO appropriate for that layer. For example, | ||
| updates to 1-D parameters are scaled by the `ell_inf` radius. | ||
|
|
||
| Args: | ||
| {_args_doc} | ||
| coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of | ||
| ["simple", "quintic", "polar_express"]. | ||
| num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. | ||
| spectral_radius: The spectral radius used to scale the LMO output (the orthogonalized update). | ||
| use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| params: ParamsT, | ||
| lr: float = 3e-4, | ||
| momentum_beta: float = 0.95, | ||
| fp32_matmul_prec: str = "medium", | ||
| coefficient_type: str = "quintic", | ||
| num_ns_steps: int = 5, | ||
| spectral_radius: float = 1.0, | ||
| use_syrk: bool = False, | ||
| ) -> None: | ||
| if num_ns_steps < 1: | ||
| raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") | ||
|
|
||
| if use_syrk: | ||
| if torch.cuda.is_available(): | ||
| sm_version = torch.cuda.get_device_capability() | ||
| else: | ||
| sm_version = (0, 0) | ||
| if not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined] | ||
| logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.") | ||
| use_syrk = False | ||
| elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)): | ||
| logging.error( | ||
| f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False." | ||
| ) | ||
| use_syrk = False | ||
|
|
||
| # Fix the weight decay arguments so that the update becomes a Frank-Wolfe step. | ||
| logging.warning("Scion does not expose weight decay. Setting weight_decay to 1 to allow Frank-Wolfe.") | ||
| weight_decay = 1 | ||
|
|
||
| logging.warning("Scion does not use weight decay. Setting use_decoupled_wd to True to allow Frank-Wolfe.") | ||
| use_decoupled_wd = True | ||
|
|
||
| logging.warning("Scion does not use weight decay. Setting use_independent_wd to False to allow Frank-Wolfe.") | ||
| use_independent_wd = False | ||
|
|
||
| # Scion does not use Nesterov momentum. | ||
| logging.warning("Scion does not use Nesterov momentum. Setting use_nesterov to False.") | ||
| use_nesterov = False | ||
|
|
||
| def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: | ||
| logging.debug( | ||
| f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, spectral_radius={spectral_radius}" | ||
| ) | ||
| orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk) | ||
| width_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode="unit_rms_norm") | ||
| return orth_grad * width_factor * spectral_radius | ||
|
|
||
| super().__init__( | ||
| params, | ||
| lr, | ||
| momentum_beta, | ||
| use_nesterov, | ||
| weight_decay, | ||
| use_decoupled_wd, | ||
| use_independent_wd, | ||
| fp32_matmul_prec, | ||
| scaled_orthogonalize_fn, | ||
| ) | ||
|
|
||
|
|
||
| Scion.__doc__ = Scion.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr] | ||
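
The `scaled_orthogonalize_fn` closure in the diff orthogonalizes the incoming gradient with Newton-Schulz and then multiplies it by a width factor and `spectral_radius`. As a rough illustration of that idea, the sketch below runs the classic cubic Newton-Schulz iteration on plain Python lists so it works without PyTorch. It is not the library's `newton_schulz`: the helper names (`matmul`, `transpose`, `newton_schulz_cubic`, `scaled_orthogonalize`), the cubic coefficients (1.5, -0.5), the step count, and the caller-supplied `width_factor` are all assumptions made for this example.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def newton_schulz_cubic(g, steps=12):
    """Approximate the nearest (semi-)orthogonal matrix to g.

    Uses the classic cubic iteration X <- 1.5 * X - 0.5 * X X^T X.
    Hypothetical helper, not the `newton_schulz` imported in the diff.
    """
    # Normalize by the Frobenius norm so every singular value is at most 1,
    # which keeps the iteration inside its convergence region.
    fro = math.sqrt(sum(v * v for row in g for v in row))
    fro = max(fro, 1e-12)  # guard against an all-zero gradient
    x = [[v / fro for v in row] for row in g]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [
            [1.5 * xv - 0.5 * yv for xv, yv in zip(x_row, y_row)]
            for x_row, y_row in zip(x, xxtx)
        ]
    return x


def scaled_orthogonalize(g, width_factor, spectral_radius, steps=12):
    """Orthogonalize g, then scale by width_factor * spectral_radius.

    width_factor is passed in by the caller here; the diff obtains it
    from get_muon_scale_factor(..., mode="unit_rms_norm").
    """
    scale = width_factor * spectral_radius
    return [[scale * v for v in row] for row in newton_schulz_cubic(g, steps)]
```

For a square input, `newton_schulz_cubic(g)` times its own transpose should be close to the identity, and the scaled variant should have every singular value close to `width_factor * spectral_radius`.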
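
The diff forces `weight_decay = 1`, `use_decoupled_wd = True`, and `use_independent_wd = False` "to allow Frank-Wolfe". A minimal scalar sketch of why that works is given below. It assumes the base class applies decoupled weight decay as `w <- w - lr * weight_decay * w` and then subtracts `lr * update`; the base class is not shown in this diff, so that update rule, along with the function names `decoupled_wd_step`, `frank_wolfe_step`, and `simulate`, is an assumption. Under that assumption, setting `weight_decay` to 1 turns the step into the convex combination `(1 - lr) * w + lr * s`, where `s` is a point of norm `radius`, so an iterate that starts inside the ball of that radius stays inside it.

```python
def decoupled_wd_step(w, update, lr, weight_decay):
    """Assumed base-class rule: decay the weight, then apply the update."""
    return w - lr * weight_decay * w - lr * update


def frank_wolfe_step(w, lmo_point, lr):
    """Frank-Wolfe step: convex combination of w and a boundary point."""
    return (1.0 - lr) * w + lr * lmo_point


def simulate(grads, w0=0.5, lr=0.1, radius=2.0):
    """Run both forms side by side on scalar 'weights'.

    For a scalar, the orthogonalized gradient is just its sign, so the
    scaled update is radius * sign(g) and the LMO point is its negation.
    Returns the final weight and the largest |w| seen along the way.
    """
    w = w0
    max_abs = abs(w)
    for g in grads:
        update = radius * (1.0 if g >= 0 else -1.0)
        w_wd = decoupled_wd_step(w, update, lr, weight_decay=1.0)
        w_fw = frank_wolfe_step(w, -update, lr)
        # With weight_decay == 1 the two forms agree up to rounding.
        assert abs(w_wd - w_fw) < 1e-12
        w = w_wd
        max_abs = max(max_abs, abs(w))
    return w, max_abs
```

Calling `simulate([1.0, -3.0, 0.2, 4.0, -0.7] * 20)` keeps `max_abs` at or below `radius`, which is the constrained-optimization reading of weight decay that the docstring refers to.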