
Commit d6b8fa6

Generalize fused weight split support (#57)
* Provide an interface for the user to split fused weights instead of trying to generalize for everything.

Signed-off-by: Hao Wu <[email protected]>
1 parent 9d93954 commit d6b8fa6

File tree: 4 files changed (+79 -124 lines)


emerging_optimizers/orthogonalized_optimizers/muon.py

Lines changed: 14 additions & 21 deletions
@@ -12,8 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
-from typing import Callable

 import torch
 from absl import logging
@@ -69,9 +67,6 @@ def __init__(
         use_nesterov: bool = True,
         weight_decay: float = 0.01,
         use_decoupled_weight_decay: bool = True,
-        split_qkv: bool = False,
-        is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
-        qkv_split_shapes: tuple[int, int, int] | None = None,
         fp32_matmul_prec: str = "medium",
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,
@@ -95,10 +90,15 @@ def __init__(
                 f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
             )
             use_syrk = False
-        orthogonalize_fn = partial(
-            newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk
-        )
-        scale_factor_fn = partial(get_muon_scale_factor, mode=scale_mode, extra_scale_factor=extra_scale_factor)
+
+        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
+            logging.debug(
+                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
+                f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
+            )
+            orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
+            scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
+            return orth_grad * scale_factor * extra_scale_factor

         super().__init__(
             params,
@@ -107,21 +107,15 @@ def __init__(
             use_nesterov,
             weight_decay,
             use_decoupled_weight_decay,
-            split_qkv,
-            is_qkv_fn,
-            qkv_split_shapes,
             fp32_matmul_prec,
-            orthogonalize_fn,
-            scale_factor_fn,
+            scaled_orthogonalize_fn,
         )


 Muon.__doc__ = Muon.__doc__.format(_args_doc=_args_doc)  # type: ignore[union-attr]


-def get_muon_scale_factor(
-    size_out: int, size_in: int, mode: str = "spectral", extra_scale_factor: float = 1.0
-) -> float:
+def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") -> float:
     """Get the scale for the update.

     Default mode is "spectral", which is the mode that allows for learning rate transferability from AdamW.
@@ -133,19 +127,18 @@ def get_muon_scale_factor(
         size_out: The size of the output tensor.
         size_in: The size of the input tensor.
         mode: The mode to use for the scale.
-        extra_scale_factor: The additional scale factor to use for the update.
     Returns:
         The scale factor for the update.
     """
     if mode == "shape_scaling":
         # Suggested by Muon (https://kellerjordan.github.io/posts/muon/)
-        return extra_scale_factor * max(1, size_out / size_in) ** 0.5
+        return max(1, size_out / size_in) ** 0.5
     elif mode == "spectral":
         # Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982)
-        return extra_scale_factor * max(size_out, size_in) ** 0.5
+        return max(size_out, size_in) ** 0.5
     elif mode == "unit_rms_norm":
         # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al.
         # (https://jeremybernste.in/writing/deriving-muon)
-        return extra_scale_factor * (size_out / size_in) ** 0.5
+        return (size_out / size_in) ** 0.5
     else:
         raise ValueError(f"Invalid mode for Muon update scale factor: {mode}")

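For orientation, here is a small illustrative sketch (not part of the commit) of what the refactored scale factor computes after this change. The mode formulas are taken from the diff above, the sizes are made up, and extra_scale_factor is now applied by the caller inside scaled_orthogonalize_fn rather than inside get_muon_scale_factor:

# Illustrative only: mirrors the three modes kept in get_muon_scale_factor.
size_out, size_in = 4096, 1024

shape_scaling = max(1, size_out / size_in) ** 0.5  # 2.0
spectral = max(size_out, size_in) ** 0.5           # 64.0
unit_rms_norm = (size_out / size_in) ** 0.5        # 2.0

# Muon's scaled_orthogonalize_fn then returns:
#     orth_grad * scale_factor * extra_scale_factor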
emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py

Lines changed: 44 additions & 59 deletions
@@ -36,10 +36,6 @@
     weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
         See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
     use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
-    split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False.
-    is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
-    qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers
-        representing the sizes of Q, K, V components along the first dimension.
     fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
 """

@@ -48,7 +44,8 @@ class OrthogonalizedOptimizer(optim.Optimizer):
     """Base class for orthogonalized optimizers.

     This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
-    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers:
+    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
+    following papers:

     - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
       In International Conference on Artificial Intelligence and Statistics (2015a).
@@ -62,15 +59,33 @@ class OrthogonalizedOptimizer(optim.Optimizer):
       arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

     Note:
-        Orthogonalizing QKV sperately when they are fused is supported but with limitations. User must provide
-        a function to check if a weight tensor is fused attention parameters (QKV, GQA, etc.) as well as the
-        leading dimension of Q, K, V components. Only one split size is supported, i.e. all attention layers across
-        the network must have the same size.
+        OrthogonalizedOptimizer as base class doesn't directly support orthogonalizing fused parameters separately.
+        Subclass can override the orthogonalize function to support this, see example below.
+
+        .. code-block:: python
+            :caption: Split QKV example
+
+            class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
+                def __init__(..., split_qkv_shapes):
+                    super().__init__(...)
+                    self.qkv_split_shapes = split_qkv_shapes
+
+                def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+
+                    # Alternative is passing "is_qkv" to scaled_orthogonalize_fn and split inside the
+                    # scaled_orthogonalize_fn.
+                    if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
+                        qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
+                        qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
+                        grad = torch.cat([orthogonalized for orthogonalized in qkv_orthogonalized])
+                    else:
+                        grad = self.scaled_orthogonalize_fn(grad)
+
+                    return grad

     Args:
         {_args_doc}
-        orthogonalize_fn: Function to orthogonalize the updates.
-        scale_factor_fn: Function to compute the scale factor for the update.
+        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
         **kwargs: Arguments passed through to the base optimizer.

     Note:
@@ -85,40 +100,13 @@ def __init__(
         use_nesterov: bool,
         weight_decay: float,
         use_decoupled_weight_decay: bool,
-        split_qkv: bool,
-        is_qkv_fn: Callable[[torch.Tensor], bool] | None,
-        qkv_split_shapes: tuple[int, int, int] | None,
         fp32_matmul_prec: str,
-        orthogonalize_fn: Callable | None = None,
-        scale_factor_fn: Callable | None = None,
+        scaled_orthogonalize_fn: Callable | None = None,
         **kwargs: Any,
     ):
-        if orthogonalize_fn is None:
-            logging.warning("orthogonalize_fn not provided. Using noop")
-            orthogonalize_fn = torch.nn.Identity()
-
-        if scale_factor_fn is None:
-            logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.")
-
-            def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
-                return 1.0
-
-            scale_factor_fn = return_one
-
-        if split_qkv:
-            assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True"
-            assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True"
-            if len(qkv_split_shapes) != 3:
-                raise ValueError(
-                    f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements"
-                )
-            if not all(isinstance(s, int) for s in qkv_split_shapes):
-                raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}")
-            if any(s <= 0 for s in qkv_split_shapes):
-                raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}")
-        self.split_qkv = split_qkv
-        self.is_qkv_fn = is_qkv_fn
-        self.qkv_split_shapes = qkv_split_shapes
+        if scaled_orthogonalize_fn is None:
+            logging.warning("scaled_orthogonalize_fn not provided. Using noop")
+            scaled_orthogonalize_fn = torch.nn.Identity()

         self.fp32_matmul_prec = fp32_matmul_prec
         default_args_dict = dict(
@@ -131,8 +119,7 @@ def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
         )

         super().__init__(params, default_args_dict)
-        self.orthogonalize_fn = orthogonalize_fn
-        self.scale_factor_fn = scale_factor_fn
+        self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

     @torch.no_grad()  # type: ignore[misc]
     @override
@@ -182,36 +169,34 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 grad = exp_avg

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad = self.orthogonalize(p, grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    grad = self.orthogonalize(p, grad, **group_kwargs)

                 # perform weight update
                 # scale is applied to have update RMS == 1
                 p.add_(grad, alpha=-group["lr"])

         return loss

-    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Orthogonalize the momentum.

+        The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can
+        override this function to implement different orthogonalization logic as well as split fused parameters.
+        For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if
+        the parameter is a fused parameter and should be split for preconditioning.
+
         Args:
-            p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
-                information is only available in the param tensor, attributes for example.
+            p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of
+                information is only available in the param tensor, attributes for example. Although not used in
+                this default orthogonalize function.
             grad: The momentum tensor.
+            **kwargs: keyword arguments of the param_group that p was belonged to.

         Returns:
             The orthogonalized gradient tensor.
         """
-        if self.split_qkv and self.is_qkv_fn(p):  # type: ignore[misc]
-            logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1)
-            # split grouped attention parameters (e.g., QKV, GQA, etc.)
-            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
-            # Apply Newton-Schulz to each component
-            qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads]
-            qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads]
-            # Apply individual scales to each component and concatenate
-            grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)])
-        else:
-            grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1))
+        grad = self.scaled_orthogonalize_fn(grad)
         return grad

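To make the docstring example above concrete, a minimal self-contained subclass might look like the sketch below. This is hypothetical and not part of the commit: the class name, the qkv_split_shapes argument, and the import path are assumptions based on the file layout in this diff; only the orthogonalize override pattern and the scaled_orthogonalize_fn hook come from the code above.

# Hypothetical sketch; assumes the module path matches the file path shown in this diff.
from typing import Any

import torch

from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
    """Splits a fused QKV gradient along dim 0 and orthogonalizes each block separately."""

    def __init__(self, params, qkv_split_shapes, **kwargs):
        super().__init__(params, **kwargs)
        self.qkv_split_shapes = qkv_split_shapes  # e.g. (q_rows, k_rows, v_rows), illustrative

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # "is_qkv" can be a tensor attribute or a param_group entry; after this change,
        # step() forwards the param_group entries as keyword arguments.
        if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
            return torch.cat([self.scaled_orthogonalize_fn(g) for g in qkv_grads], dim=0)
        return self.scaled_orthogonalize_fn(grad)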
tests/test_muon_utils.py

Lines changed: 0 additions & 19 deletions
@@ -191,25 +191,6 @@ def test_get_scale_factor(self, size_pairs, mode):
         else:
             raise ValueError(f"Invalid mode: {mode}")

-    def test_qkv_split_shapes_validation(self):
-        """Test validation of qkv_split_shapes parameter"""
-        dummy_param = torch.nn.Parameter(torch.randn(4, 4))
-        dummy_args = dict(split_qkv=True, is_qkv_fn=lambda x: True)
-        # Test non-integer values
-        with self.assertRaises(ValueError) as cm:
-            muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512.5, 256, 256))
-        self.assertIn("must be integers", str(cm.exception))
-
-        # Test negative values
-        with self.assertRaises(ValueError) as cm:
-            muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, -256, 256))
-        self.assertIn("must be positive", str(cm.exception))
-
-        # Test wrong number of elements
-        with self.assertRaises(ValueError) as cm:
-            muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, 256))
-        self.assertIn("tuple of 3 integers", str(cm.exception))
-

 @absltest.skipIf(
     _SM_VERSION not in ((8, 0), (9, 0), (10, 0), (10, 3)),

tests/test_orthogonalized_optimizer.py

Lines changed: 21 additions & 25 deletions
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import torch
 import torch.nn as nn
 from absl.testing import absltest, parameterized
@@ -42,9 +43,6 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None:
             use_nesterov=False,
             weight_decay=0.5,
             use_decoupled_weight_decay=True,
-            split_qkv=False,
-            is_qkv_fn=None,
-            qkv_split_shapes=None,
             fp32_matmul_prec="highest",
         )

@@ -86,9 +84,6 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) ->
             use_nesterov=False,
             weight_decay=0.0,
             use_decoupled_weight_decay=False,
-            split_qkv=False,
-            is_qkv_fn=None,
-            qkv_split_shapes=None,
             fp32_matmul_prec="highest",
         )

@@ -114,40 +109,41 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) ->
             rtol=0,
         )

-    def test_split_qkv_matches_ref(self) -> None:
-        test_param = torch.randint(-5, 5, (6, 7), dtype=torch.float32, device="cuda")
-        test_param.grad = torch.randint_like(test_param, -5, 5)
-        split_shapes = (1, 2, 3)
-        lr = 2.0
+    def test_split_fn_interleaved(self) -> None:
+        """Test a three way interleaved split function.

-        def is_qkv_fn(x: torch.Tensor) -> bool:
-            return x.shape == torch.Size([6, 7])
+        With 0 weights and lr -1, returned param should match orthogonalized grads.
+        """
+        test_param = torch.zeros((6, 7), dtype=torch.float32, device="cuda")
+        test_param.grad = torch.empty_like(test_param.data)

-        def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor:
-            return x * x
+        for i in range(test_param.shape[0]):
+            test_param.grad[i] = i + 1

-        ref_orth_grads = []
-        for g in torch.split(test_param.grad, split_shapes, dim=0):
-            ref_orth_grads.append(dummy_orth_fn(g))
-        ref_out = test_param - torch.cat(ref_orth_grads, dim=0) * lr
+        def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor:
+            out_list = [[], [], []]
+            for i in range(x.shape[0]):
+                out_list[i % 3].append(x[i : i + 1])
+            orth_grad_list = [torch.cat(t, dim=0) for t in out_list]
+            return torch.cat([torch.empty_like(x).fill_(x.max()) for x in orth_grad_list], dim=0)

         orthogonalized_opt = OrthogonalizedOptimizer(
             [test_param],
-            lr=lr,
+            lr=-1,
             momentum_beta=0,
             use_nesterov=False,
             weight_decay=0.0,
             use_decoupled_weight_decay=False,
-            split_qkv=True,
-            is_qkv_fn=is_qkv_fn,
-            qkv_split_shapes=(1, 2, 3),
             fp32_matmul_prec="highest",
-            orthogonalize_fn=dummy_orth_fn,
+            scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn,
         )
         orthogonalized_opt.step()

+        assert not torch.allclose(test_param, test_param.grad)
+
+        ref_out = dummy_interleaved_split_orth_fn(test_param.grad)
         torch.testing.assert_close(
-            test_param.data,
+            test_param,
             ref_out,
             atol=0,
             rtol=0,

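Finally, a hypothetical usage snippet (not from the commit) showing the calling convention the updated test exercises: the QKV-specific constructor arguments are gone, and any splitting or scaling lives entirely in the callable the user supplies. The normalization inside the callable is a stand-in for newton_schulz plus scaling, chosen only to keep the example self-contained, and the import path is assumed from the file layout in this diff.

# Hypothetical usage sketch; shapes and hyperparameters are illustrative.
import torch

from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


def my_scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
    # Stand-in for orthogonalization + scaling.
    return grad / (grad.norm() + 1e-8)


param = torch.nn.Parameter(torch.randn(8, 4))
param.grad = torch.randn_like(param)

opt = OrthogonalizedOptimizer(
    [param],
    lr=0.1,
    momentum_beta=0.9,
    use_nesterov=False,
    weight_decay=0.0,
    use_decoupled_weight_decay=False,
    fp32_matmul_prec="highest",
    scaled_orthogonalize_fn=my_scaled_orthogonalize_fn,
)
opt.step()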