Commit f3509f4

Merge branch 'main' into skyw/kl_shampoo_dev
2 parents 2549b46 + cf9909b commit f3509f4

11 files changed: +618 -133 lines changed

docker/Dockerfile.ci

Lines changed: 6 additions & 3 deletions
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-FROM nvcr.io/nvidia/pytorch:25.06-py3
+FROM nvcr.io/nvidia/pytorch:25.09-py3

 ENV PIP_CONSTRAINT=""

 # Install uv and python
-ARG UV_VERSION=0.7.2
+ARG UV_VERSION=0.9.3
 ENV PATH="/root/.local/bin:$PATH"
 RUN curl -LsSf https://astral.sh/uv/${UV_VERSION}/install.sh | sh

@@ -34,6 +34,7 @@ RUN --mount=type=bind,source=pyproject.toml,target=/workspace/pyproject.toml \
     uv sync --link-mode symlink --locked --all-groups \
         --no-install-package absl-py \
         --no-install-package torch \
+        --no-install-package triton \
         --no-install-package nvidia-cublas-cu12 \
         --no-install-package nvidia-cuda-cupti-cu12 \
         --no-install-package nvidia-cuda-nvrtc-cu12 \
@@ -45,5 +46,7 @@ RUN --mount=type=bind,source=pyproject.toml,target=/workspace/pyproject.toml \
         --no-install-package nvidia-cusolver-cu12 \
         --no-install-package nvidia-cusparse-cu12 \
         --no-install-package nvidia-cusparselt-cu12 \
-        --no-install-package nvidia-nccl-cu12
+        --no-install-package nvidia-nccl-cu12 \
+        --no-install-package nvidia-nvjitlink-cu12 \
+        --no-install-package nvidia-nvtx-cu12
 EOF

emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 15 additions & 22 deletions
@@ -12,8 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
-from typing import Callable

 import torch
 from absl import logging
@@ -66,12 +64,9 @@ def __init__(
         params: ParamsT,
         lr: float = 3e-4,
         momentum_beta: float = 0.95,
-        use_nesterov: bool = True,
+        use_nesterov: bool = False,
         weight_decay: float = 0.01,
         use_decoupled_weight_decay: bool = True,
-        split_qkv: bool = False,
-        is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
-        qkv_split_shapes: tuple[int, int, int] | None = None,
         fp32_matmul_prec: str = "medium",
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,
@@ -95,10 +90,15 @@ def __init__(
                 f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
             )
             use_syrk = False
-        orthogonalize_fn = partial(
-            newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk
-        )
-        scale_factor_fn = partial(get_muon_scale_factor, mode=scale_mode, extra_scale_factor=extra_scale_factor)
+
+        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
+            logging.debug(
+                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
+                f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
+            )
+            orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
+            scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
+            return orth_grad * scale_factor * extra_scale_factor

         super().__init__(
             params,
@@ -107,21 +107,15 @@ def __init__(
             use_nesterov,
             weight_decay,
             use_decoupled_weight_decay,
-            split_qkv,
-            is_qkv_fn,
-            qkv_split_shapes,
             fp32_matmul_prec,
-            orthogonalize_fn,
-            scale_factor_fn,
+            scaled_orthogonalize_fn,
         )


 Muon.__doc__ = Muon.__doc__.format(_args_doc=_args_doc)  # type: ignore[union-attr]


-def get_muon_scale_factor(
-    size_out: int, size_in: int, mode: str = "spectral", extra_scale_factor: float = 1.0
-) -> float:
+def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") -> float:
     """Get the scale for the update.

     Default mode is "spectral", which is the mode that allows for learning rate transferability from AdamW.
@@ -133,19 +127,18 @@ def get_muon_scale_factor(
         size_out: The size of the output tensor.
         size_in: The size of the input tensor.
         mode: The mode to use for the scale.
-        extra_scale_factor: The additional scale factor to use for the update.
     Returns:
         The scale factor for the update.
     """
     if mode == "shape_scaling":
         # Suggested by Muon (https://kellerjordan.github.io/posts/muon/)
-        return extra_scale_factor * max(1, size_out / size_in) ** 0.5
+        return max(1, size_out / size_in) ** 0.5
     elif mode == "spectral":
         # Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982)
-        return extra_scale_factor * max(size_out, size_in) ** 0.5
+        return max(size_out, size_in) ** 0.5
     elif mode == "unit_rms_norm":
         # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al.
         # (https://jeremybernste.in/writing/deriving-muon)
-        return extra_scale_factor * (size_out / size_in) ** 0.5
+        return (size_out / size_in) ** 0.5
     else:
         raise ValueError(f"Invalid mode for Muon update scale factor: {mode}")

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 44 additions & 59 deletions
@@ -36,10 +36,6 @@
     weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
         See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
-    split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False.
-    is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
-    qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers
-        representing the sizes of Q, K, V components along the first dimension.
     fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
 """

@@ -48,7 +44,8 @@ class OrthogonalizedOptimizer(optim.Optimizer):
     """Base class for orthogonalized optimizers.

     This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
-    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers:
+    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
+    following papers:

     - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
       In International Conference on Artificial Intelligence and Statistics (2015a).
@@ -62,15 +59,33 @@ class OrthogonalizedOptimizer(optim.Optimizer):
       arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

     Note:
-        Orthogonalizing QKV sperately when they are fused is supported but with limitations. User must provide
-        a function to check if a weight tensor is fused attention parameters (QKV, GQA, etc.) as well as the
-        leading dimension of Q, K, V components. Only one split size is supported, i.e. all attention layers across
-        the network must have the same size.
+        OrthogonalizedOptimizer as base class doesn't directly support orthogonalizing fused parameters separately.
+        Subclass can override the orthogonalize function to support this, see example below.
+
+        .. code-block:: python
+            :caption: Split QKV example
+
+            class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
+                def __init__(..., split_qkv_shapes):
+                    super().__init__(...)
+                    self.qkv_split_shapes = split_qkv_shapes
+
+                def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+
+                    # Alternative is passing "is_qkv" to scaled_orthogonalize_fn and split inside the
+                    # scaled_orthogonalize_fn.
+                    if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
+                        qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
+                        qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
+                        grad = torch.cat([orthogonalized for orthogonalized in qkv_orthogonalized])
+                    else:
+                        grad = self.scaled_orthogonalize_fn(grad)
+
+                    return grad

     Args:
         {_args_doc}
-        orthogonalize_fn: Function to orthogonalize the updates.
-        scale_factor_fn: Function to compute the scale factor for the update.
+        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
         **kwargs: Arguments passed through to the base optimizer.

     Note:
@@ -85,40 +100,13 @@ def __init__(
         use_nesterov: bool,
         weight_decay: float,
         use_decoupled_weight_decay: bool,
-        split_qkv: bool,
-        is_qkv_fn: Callable[[torch.Tensor], bool] | None,
-        qkv_split_shapes: tuple[int, int, int] | None,
         fp32_matmul_prec: str,
-        orthogonalize_fn: Callable | None = None,
-        scale_factor_fn: Callable | None = None,
+        scaled_orthogonalize_fn: Callable | None = None,
         **kwargs: Any,
     ):
-        if orthogonalize_fn is None:
-            logging.warning("orthogonalize_fn not provided. Using noop")
-            orthogonalize_fn = torch.nn.Identity()
-
-        if scale_factor_fn is None:
-            logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.")
-
-            def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
-                return 1.0
-
-            scale_factor_fn = return_one
-
-        if split_qkv:
-            assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True"
-            assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True"
-            if len(qkv_split_shapes) != 3:
-                raise ValueError(
-                    f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements"
-                )
-            if not all(isinstance(s, int) for s in qkv_split_shapes):
-                raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}")
-            if any(s <= 0 for s in qkv_split_shapes):
-                raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}")
-        self.split_qkv = split_qkv
-        self.is_qkv_fn = is_qkv_fn
-        self.qkv_split_shapes = qkv_split_shapes
+        if scaled_orthogonalize_fn is None:
+            logging.warning("scaled_orthogonalize_fn not provided. Using noop")
+            scaled_orthogonalize_fn = torch.nn.Identity()

         self.fp32_matmul_prec = fp32_matmul_prec
         default_args_dict = dict(
@@ -131,8 +119,7 @@ def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
         )

         super().__init__(params, default_args_dict)
-        self.orthogonalize_fn = orthogonalize_fn
-        self.scale_factor_fn = scale_factor_fn
+        self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

     @torch.no_grad()  # type: ignore[misc]
     @override
@@ -182,36 +169,34 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                     grad = exp_avg

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad = self.orthogonalize(p, grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    grad = self.orthogonalize(p, grad, **group_kwargs)

                 # perform weight update
                 # scale is applied to have update RMS == 1
                 p.add_(grad, alpha=-group["lr"])

         return loss

-    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Orthogonalize the momentum.

+        The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can
+        override this function to implement different orthogonalization logic as well as split fused parameters.
+        For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if
+        the parameter is a fused parameter and should be split for preconditioning.
+
         Args:
-            p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
-                information is only available in the param tensor, attributes for example.
+            p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of
+                information is only available in the param tensor, attributes for example. Although not used in
+                this default orthogonalize function.
             grad: The momentum tensor.
+            **kwargs: keyword arguments of the param_group that p was belonged to.

         Returns:
             The orthogonalized gradient tensor.
         """
-        if self.split_qkv and self.is_qkv_fn(p):  # type: ignore[misc]
-            logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1)
-            # split grouped attention parameters (e.g., QKV, GQA, etc.)
-            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
-            # Apply Newton-Schulz to each component
-            qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads]
-            qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads]
-            # Apply individual scales to each component and concatenate
-            grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)])
-        else:
-            grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1))
+        grad = self.scaled_orthogonalize_fn(grad)
         return grad
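
To illustrate the new hook in step() above (every non-"params" entry of a param group is now forwarded to orthogonalize as a keyword argument), here is a hedged sketch of a subclass that reads a custom per-group flag. The class name GroupAwareOptimizer and the skip_orth key are hypothetical and not part of the repository; constructor arguments follow the base signature shown in the diff.

    from typing import Any

    import torch

    # Assumed import path based on the file location in this commit.
    from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


    class GroupAwareOptimizer(OrthogonalizedOptimizer):
        """Hypothetical subclass: leaves updates untouched for param groups tagged skip_orth=True."""

        def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
            # step() passes every non-"params" key of the param group as a keyword argument,
            # so per-group settings are visible here without extra bookkeeping.
            if kwargs.get("skip_orth", False):
                return grad
            return self.scaled_orthogonalize_fn(grad)

A param group such as {"params": embedding_params, "skip_orth": True} would then bypass orthogonalization, while other groups keep the default behavior.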

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from emerging_optimizers.psgd.psgd import *
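
For reference, a small sketch of what this wildcard re-export surfaces at the package level, assuming the new file is the psgd package __init__ (the filename is not shown in this view); no specific symbol names from psgd.psgd are assumed.

    # List the public names the re-export makes available on the package itself.
    import emerging_optimizers.psgd as psgd

    public_names = [name for name in dir(psgd) if not name.startswith("_")]
    print(public_names)  # whatever emerging_optimizers.psgd.psgd exports (its __all__, if it defines one)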
