Commit 9139d55

Scion optimizer (#70)
* Added Scion's Frank-Wolfe parametrization to OrthogonalizedOptimizer. Signed-off-by: mikail <[email protected]>
1 parent 16f8399 commit 9139d55

File tree

  • emerging_optimizers/orthogonalized_optimizers

1 file changed: 102 additions, 0 deletions

@@ -0,0 +1,102 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


class Scion(OrthogonalizedOptimizer):
    """Scion: Stochastic CondItional descent with Operator Norms

    Scion runs standard SGD-momentum and then performs an orthogonalization post-processing step,
    in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To
    efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the
    advantage that it can be run stably on GPU tensor cores.

    This implementation incorporates the `step_size` and `spectral_radius` scaling from Scion,
    which views weight decay as constrained optimization via Frank-Wolfe.
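
    Concretely, the constructor below fixes the base class's weight decay arguments so that each
    step applies `w = (1 - lr) * w - lr * update`, where `update` is the scaled, orthogonalized
    momentum: a Frank-Wolfe style convex combination of the current iterate and the LMO output
    over the spectral-norm ball. (This reading assumes the base class applies decoupled, lr-scaled
    weight decay as `w *= 1 - lr * wd`.)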

    References:
        - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025).
          [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - It should not be used for the embedding layer, the final fully connected layer, or any
          1-D parameters; those should all be optimized by the appropriate LMO for that layer. For
          example, updates for 1-D params are scaled by the `ell_inf` radius.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: The learning rate used by the internal SGD.
        momentum_beta: The momentum used by the internal SGD.
        fp32_matmul_prec: Precision of the matmul operations in the optimizer state GEMM operations.
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration. Can
            be one of ["simple", "quintic", "polar_express"].
        num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
        spectral_radius: The spectral radius to use for the update; the LMO output is scaled by
            this value.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum_beta: float = 0.95,
        fp32_matmul_prec: str = "medium",
        coefficient_type: str = "quintic",
        num_ns_steps: int = 5,
        spectral_radius: float = 1.0,
    ) -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

        # Fix the weight decay arguments so the base class performs the Frank-Wolfe step.
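        # With weight_decay=1 applied as decoupled weight decay that is scaled by lr (not
        # independent of it), the base class's step reduces to w = (1 - lr) * w - lr * update,
        # i.e. the Frank-Wolfe convex combination described in the class docstring. (Assumption:
        # the base class implements decoupled, lr-scaled weight decay as w *= 1 - lr * wd.)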
        logging.info("Scion does not use weight decay. Setting weight_decay to 1.")
        weight_decay = 1

        logging.info("Scion does not use weight decay. Setting use_decoupled_wd to True to allow Frank-Wolfe.")
        use_decoupled_wd = True

        logging.info("Scion does not use weight decay. Setting use_independent_wd to False to allow Frank-Wolfe.")
        use_independent_wd = False

        logging.info("Scion does not use Nesterov momentum. Setting use_nesterov to False.")
        use_nesterov = False

        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
            logging.debug(
                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, spectral_radius={spectral_radius}"
            )
            # Orthogonalize the momentum update via Newton-Schulz iteration.
            orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=False)
            # Shape-dependent factor so the orthogonalized update has unit RMS norm.
            width_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode="unit_rms_norm")
            return orth_grad * width_factor * spectral_radius

        super().__init__(
            params,
            lr,
            momentum_beta,
            use_nesterov,
            weight_decay,
            use_decoupled_wd,
            use_independent_wd,
            fp32_matmul_prec,
            scaled_orthogonalize_fn,
        )
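
For context, a minimal usage sketch follows. The module path and the toy model are illustrative assumptions (this diff only shows the directory emerging_optimizers/orthogonalized_optimizers, not the file name):

    # Hypothetical usage; the import path is an assumption, not confirmed by this diff.
    import torch
    from emerging_optimizers.orthogonalized_optimizers.scion import Scion

    model = torch.nn.Sequential(torch.nn.Linear(512, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 512))
    # Scion requires 2D parameters, so pass only the weight matrices; 1-D params (e.g. biases)
    # should go to a different optimizer with the appropriate LMO.
    matrix_params = [p for p in model.parameters() if p.ndim == 2]
    opt = Scion(matrix_params, lr=3e-4, momentum_beta=0.95, num_ns_steps=5, spectral_radius=1.0)

    x = torch.randn(8, 512)
    loss = model(x).square().mean()
    loss.backward()
    opt.step()  # SGD-momentum, Newton-Schulz orthogonalization, then Frank-Wolfe mixing
    opt.zero_grad()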
