Skip to content

Commit ff42aa8

Browse files
committed
added Scion's franke-wolfe parametrization to orthogonalized Optimizer
Signed-off-by: mikail <mkhona@nvidia.com>
1 parent f3b9d84 commit ff42aa8

File tree

1 file changed

+125
-0
lines changed
  • emerging_optimizers/orthogonalized_optimizers

1 file changed

+125
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import torch
17+
from absl import logging
18+
from torch.optim.optimizer import ParamsT
19+
20+
from emerging_optimizers import triton_kernels
21+
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
22+
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
23+
24+
25+
class Scion(OrthogonalizedOptimizer):
26+
"""Muon: MomentUm Orthogonalized by Newton-schulz
27+
28+
Scion runs standard SGD-momentum and then performs an orthogonalization
29+
post-processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
30+
matrix. To efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the
31+
advantage that it may be stably run on tensor cores on GPUs.
32+
33+
This implementation incorporates `step_size` and `spectral_radius` refer to Scion which views weight decay as constrained
34+
optimization via Frank-Wolfe.
35+
36+
References:
37+
- *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025).
38+
[`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]
39+
40+
Warning:
41+
- This optimizer requires that all parameters passed in are 2D.
42+
- It should not be used for the embedding layer, the final fully connected layer, or any 1-D
43+
parameters; those should all be optimized by a standard method (e.g., AdamW).
44+
45+
Args:
46+
{_args_doc}
47+
coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can be one of
48+
["simple", "quintic", "polar_express"].
49+
num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
50+
scale_mode: The type of scale factor to use for the update. Defaults to "spectral" style scaling.
51+
spectral_radius: The spectral radius to use for the update, we are scaling the LMO by this spectral radius.
52+
use_syrk: Whether to use the Triton kernel for the Newton-Schulz iteration.
53+
"""
54+
55+
def __init__(
56+
self,
57+
params: ParamsT,
58+
lr: float = 3e-4,
59+
momentum_beta: float = 0.95,
60+
use_nesterov: bool = False,
61+
weight_decay: float = 1,
62+
use_decoupled_wd: bool = True,
63+
use_independent_wd: bool = False,
64+
fp32_matmul_prec: str = "medium",
65+
coefficient_type: str = "quintic",
66+
num_ns_steps: int = 5,
67+
scale_mode: str = "spectral",
68+
spectral_radius: float = 1.0,
69+
use_syrk: bool = False,
70+
) -> None:
71+
if num_ns_steps < 1:
72+
raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
73+
74+
if use_syrk:
75+
if torch.cuda.is_available():
76+
sm_version = torch.cuda.get_device_capability()
77+
else:
78+
sm_version = (0, 0)
79+
if not triton_kernels.HAS_TRITON_340: # type: ignore[attr-defined]
80+
logging.error("Triton 3.4.0 or higher is required for use_syrk to be True.")
81+
use_syrk = False
82+
elif sm_version not in ((8, 0), (9, 0), (10, 0), (10, 3)):
83+
logging.error(
84+
f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
85+
)
86+
use_syrk = False
87+
88+
# Add checks for weight decay arguments to enable Franke-Wolfe step.
89+
if weight_decay != 1:
90+
logging.warning("Scion does not use weight decay. Setting weight_decay to 1.")
91+
weight_decay = 1
92+
93+
if not use_decoupled_wd:
94+
logging.warning("Scion does not use weight decay. Setting use_decoupled_wd to True to allow Franke-Wolfe.")
95+
use_decoupled_wd = True
96+
97+
if use_independent_wd:
98+
logging.warning(
99+
"Scion does not use weight decay. Setting use_independent_wd to False to allow Franke-Wolfe."
100+
)
101+
use_independent_wd = False
102+
103+
def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
104+
logging.debug(
105+
f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, spectral_radius={spectral_radius}"
106+
)
107+
orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
108+
width_factor = (grad.size(-2) / grad.size(-1)) ** 0.5
109+
return orth_grad * width_factor * spectral_radius
110+
111+
super().__init__(
112+
params,
113+
lr,
114+
momentum_beta,
115+
use_nesterov,
116+
weight_decay,
117+
use_decoupled_wd,
118+
use_independent_wd,
119+
fp32_matmul_prec,
120+
scaled_orthogonalize_fn,
121+
spectral_radius=spectral_radius,
122+
)
123+
124+
125+
Scion.__doc__ = Scion.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]

0 commit comments

Comments
 (0)