Merged
35 changes: 14 additions & 21 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable

import torch
from absl import logging
@@ -69,9 +67,6 @@ def __init__(
use_nesterov: bool = True,
weight_decay: float = 0.01,
use_decoupled_weight_decay: bool = True,
split_qkv: bool = False,
is_qkv_fn: Callable[[torch.Tensor], bool] | None = None,
qkv_split_shapes: tuple[int, int, int] | None = None,
fp32_matmul_prec: str = "medium",
coefficient_type: str = "quintic",
num_ns_steps: int = 5,
@@ -95,10 +90,15 @@ def __init__(
f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False."
)
use_syrk = False
orthogonalize_fn = partial(
newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk
)
scale_factor_fn = partial(get_muon_scale_factor, mode=scale_mode, extra_scale_factor=extra_scale_factor)

def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
logging.debug(
f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}"
)
orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk)
scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode)
return orth_grad * scale_factor * extra_scale_factor

super().__init__(
params,
@@ -107,21 +107,15 @@ def __init__(
use_nesterov,
weight_decay,
use_decoupled_weight_decay,
split_qkv,
is_qkv_fn,
qkv_split_shapes,
fp32_matmul_prec,
orthogonalize_fn,
scale_factor_fn,
scaled_orthogonalize_fn,
)


Muon.__doc__ = Muon.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]


def get_muon_scale_factor(
size_out: int, size_in: int, mode: str = "spectral", extra_scale_factor: float = 1.0
) -> float:
def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") -> float:
"""Get the scale for the update.

Default mode is "spectral", which is the mode that allows for learning rate transferability from AdamW.
@@ -133,19 +127,18 @@ def get_muon_scale_factor(
size_out: The size of the output tensor.
size_in: The size of the input tensor.
mode: The mode to use for the scale.
extra_scale_factor: The additional scale factor to use for the update.
Returns:
The scale factor for the update.
"""
if mode == "shape_scaling":
# Suggested by Muon (https://kellerjordan.github.io/posts/muon/)
return extra_scale_factor * max(1, size_out / size_in) ** 0.5
return max(1, size_out / size_in) ** 0.5
elif mode == "spectral":
# Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982)
return extra_scale_factor * max(size_out, size_in) ** 0.5
return max(size_out, size_in) ** 0.5
elif mode == "unit_rms_norm":
# Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al.
# (https://jeremybernste.in/writing/deriving-muon)
return extra_scale_factor * (size_out / size_in) ** 0.5
return (size_out / size_in) ** 0.5
else:
raise ValueError(f"Invalid mode for Muon update scale factor: {mode}")
@@ -36,10 +36,6 @@
weight_decay: The weight decay used by the optimizer; defaults to decoupled weight decay.
See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
use_decoupled_weight_decay: Whether to use decoupled weight decay; defaults to True.
split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False.
is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers
representing the sizes of Q, K, V components along the first dimension.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
"""

@@ -48,7 +44,8 @@ class OrthogonalizedOptimizer(optim.Optimizer):
"""Base class for orthogonalized optimizers.

This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers:
The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
following papers:

- Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
In International Conference on Artificial Intelligence and Statistics (2015a).
@@ -62,15 +59,33 @@ class OrthogonalizedOptimizer(optim.Optimizer):
arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

Note:
Orthogonalizing QKV sperately when they are fused is supported but with limitations. User must provide
a function to check if a weight tensor is fused attention parameters (QKV, GQA, etc.) as well as the
leading dimension of Q, K, V components. Only one split size is supported, i.e. all attention layers across
the network must have the same size.
As a base class, OrthogonalizedOptimizer doesn't directly support orthogonalizing fused parameters separately.
A subclass can override the orthogonalize function to support this; see the example below.

.. code-block:: python
:caption: Split QKV example

    class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
        def __init__(..., split_qkv_shapes):
            super().__init__(...)
            self.qkv_split_shapes = split_qkv_shapes

        def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
            # An alternative is to pass "is_qkv" through kwargs to scaled_orthogonalize_fn
            # and perform the split inside scaled_orthogonalize_fn itself.
            if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
                qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
                qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
                grad = torch.cat(qkv_orthogonalized, dim=0)
            else:
                grad = self.scaled_orthogonalize_fn(grad)

            return grad

Args:
{_args_doc}
orthogonalize_fn: Function to orthogonalize the updates.
scale_factor_fn: Function to compute the scale factor for the update.
scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
**kwargs: Arguments passed through to the base optimizer.

Note:
@@ -85,40 +100,13 @@ def __init__(
use_nesterov: bool,
weight_decay: float,
use_decoupled_weight_decay: bool,
split_qkv: bool,
is_qkv_fn: Callable[[torch.Tensor], bool] | None,
qkv_split_shapes: tuple[int, int, int] | None,
fp32_matmul_prec: str,
orthogonalize_fn: Callable | None = None,
scale_factor_fn: Callable | None = None,
scaled_orthogonalize_fn: Callable | None = None,
**kwargs: Any,
):
if orthogonalize_fn is None:
logging.warning("orthogonalize_fn not provided. Using noop")
orthogonalize_fn = torch.nn.Identity()

if scale_factor_fn is None:
logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.")

def return_one(*args, **kwargs): # type: ignore[no-untyped-def]
return 1.0

scale_factor_fn = return_one

if split_qkv:
assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True"
assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True"
if len(qkv_split_shapes) != 3:
raise ValueError(
f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements"
)
if not all(isinstance(s, int) for s in qkv_split_shapes):
raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}")
if any(s <= 0 for s in qkv_split_shapes):
raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}")
self.split_qkv = split_qkv
self.is_qkv_fn = is_qkv_fn
self.qkv_split_shapes = qkv_split_shapes
if scaled_orthogonalize_fn is None:
logging.warning("scaled_orthogonalize_fn not provided. Using noop")
scaled_orthogonalize_fn = torch.nn.Identity()

self.fp32_matmul_prec = fp32_matmul_prec
default_args_dict = dict(
@@ -131,8 +119,7 @@ def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
)

super().__init__(params, default_args_dict)
self.orthogonalize_fn = orthogonalize_fn
self.scale_factor_fn = scale_factor_fn
self.scaled_orthogonalize_fn = scaled_orthogonalize_fn

@torch.no_grad() # type: ignore[misc]
@override
@@ -182,36 +169,34 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
grad = exp_avg

with utils.fp32_matmul_precision(self.fp32_matmul_prec):
grad = self.orthogonalize(p, grad)
group_kwargs = {k: v for k, v in group.items() if k != "params"}
grad = self.orthogonalize(p, grad, **group_kwargs)

# perform weight update
# scale is applied to have update RMS == 1
p.add_(grad, alpha=-group["lr"])

return loss

def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Orthogonalize the momentum.

The default orthogonalize function simply calls scaled_orthogonalize_fn on the gradient. Subclasses can
override this function to implement different orthogonalization logic, as well as to split fused parameters.
For example, an overriding orthogonalize function can read attributes from p or from kwargs to determine
whether the parameter is a fused parameter that should be split before preconditioning.

Args:
p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
information is only available in the param tensor, attributes for example.
p: The parameter tensor. It is necessary to pass the param tensor in addition to the momentum because
some information, such as tensor attributes, is only available on the param tensor, even though p is
not used in this default orthogonalize function.
grad: The momentum tensor.
**kwargs: Keyword arguments of the param_group that p belongs to.

Returns:
The orthogonalized gradient tensor.
"""
if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc]
logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1)
# split grouped attention parameters (e.g., QKV, GQA, etc.)
qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
# Apply Newton-Schulz to each component
qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads]
qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads]
# Apply individual scales to each component and concatenate
grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)])
else:
grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1))
grad = self.scaled_orthogonalize_fn(grad)
return grad


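
Because step() now forwards every non-"params" entry of the param group to orthogonalize, a subclass can also key the split off a per-group flag rather than a tensor attribute. The following is a minimal sketch, not part of the PR: the import path is assumed from the package layout, and the class name and "is_qkv" group key are hypothetical.

.. code-block:: python
    :caption: Per-param-group flag sketch (assumed import path, hypothetical names)

    from typing import Any

    import torch

    # Assumed import path; the diff above does not show this file's module header.
    from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


    class GroupFlaggedOrthogonalizedOptimizer(OrthogonalizedOptimizer):
        def __init__(self, params, qkv_split_shapes, **kwargs: Any):
            super().__init__(params, **kwargs)
            self.qkv_split_shapes = qkv_split_shapes

        def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
            # step() passes along the param group's non-"params" keys, so a group created as
            # {"params": qkv_weights, "is_qkv": True} surfaces here as kwargs["is_qkv"].
            if kwargs.get("is_qkv", False):
                blocks = torch.split(grad, self.qkv_split_shapes, dim=0)
                return torch.cat([self.scaled_orthogonalize_fn(g) for g in blocks], dim=0)
            return self.scaled_orthogonalize_fn(grad)

A fused attention group would then be registered as {"params": qkv_weights, "is_qkv": True} alongside ordinary groups, with the usual base-optimizer keyword arguments (lr, momentum_beta, use_nesterov, weight_decay, use_decoupled_weight_decay, fp32_matmul_prec, scaled_orthogonalize_fn) passed through to the constructor.
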
19 changes: 0 additions & 19 deletions tests/test_muon_utils.py
@@ -191,25 +191,6 @@ def test_get_scale_factor(self, size_pairs, mode):
else:
raise ValueError(f"Invalid mode: {mode}")

def test_qkv_split_shapes_validation(self):
"""Test validation of qkv_split_shapes parameter"""
dummy_param = torch.nn.Parameter(torch.randn(4, 4))
dummy_args = dict(split_qkv=True, is_qkv_fn=lambda x: True)
# Test non-integer values
with self.assertRaises(ValueError) as cm:
muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512.5, 256, 256))
self.assertIn("must be integers", str(cm.exception))

# Test negative values
with self.assertRaises(ValueError) as cm:
muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, -256, 256))
self.assertIn("must be positive", str(cm.exception))

# Test wrong number of elements
with self.assertRaises(ValueError) as cm:
muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, 256))
self.assertIn("tuple of 3 integers", str(cm.exception))


@absltest.skipIf(
_SM_VERSION not in ((8, 0), (9, 0), (10, 0), (10, 3)),
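
The deleted test above covered validation of the old qkv_split_shapes constructor arguments. Under the new API, the same split-QKV behavior can also be recovered without subclassing by folding the split into the scaled_orthogonalize_fn itself. The sketch below is illustrative only: the newton_schulz import location and its default arguments are assumptions, and the fused_dim0 shape check merely stands in for the removed is_qkv_fn.

.. code-block:: python
    :caption: Split-QKV via a custom scaled_orthogonalize_fn (illustrative sketch)

    import torch

    # newton_schulz location assumed; muon.py above calls it but its import is not shown.
    from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor, newton_schulz


    def make_split_qkv_orthogonalize_fn(qkv_split_shapes, fused_dim0, num_ns_steps=5, scale_mode="spectral"):
        """Build a scaled_orthogonalize_fn that orthogonalizes Q, K, V blocks independently.

        fused_dim0 plays the role of the removed is_qkv_fn: any gradient whose leading
        dimension matches it is treated as a fused QKV parameter and split.
        """

        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
            blocks = torch.split(grad, qkv_split_shapes, dim=0) if grad.size(0) == fused_dim0 else (grad,)
            out = []
            for g in blocks:
                orth = newton_schulz(g, steps=num_ns_steps, coefficient_type="quintic", use_syrk=False)
                out.append(orth * get_muon_scale_factor(g.size(-2), g.size(-1), mode=scale_mode))
            return torch.cat(out, dim=0)

        return scaled_orthogonalize_fn

Such a closure can be handed to OrthogonalizedOptimizer as scaled_orthogonalize_fn, in the same way the test below passes dummy_interleaved_split_orth_fn.
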
46 changes: 21 additions & 25 deletions tests/test_orthogonalized_optimizer.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from absl.testing import absltest, parameterized
@@ -42,9 +43,6 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None:
use_nesterov=False,
weight_decay=0.5,
use_decoupled_weight_decay=True,
split_qkv=False,
is_qkv_fn=None,
qkv_split_shapes=None,
fp32_matmul_prec="highest",
)

@@ -86,9 +84,6 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) ->
use_nesterov=False,
weight_decay=0.0,
use_decoupled_weight_decay=False,
split_qkv=False,
is_qkv_fn=None,
qkv_split_shapes=None,
fp32_matmul_prec="highest",
)

@@ -114,40 +109,41 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) ->
rtol=0,
)

def test_split_qkv_matches_ref(self) -> None:
test_param = torch.randint(-5, 5, (6, 7), dtype=torch.float32, device="cuda")
test_param.grad = torch.randint_like(test_param, -5, 5)
split_shapes = (1, 2, 3)
lr = 2.0
def test_split_fn_interleaved(self) -> None:
"""Test a three way interleaved split function.

def is_qkv_fn(x: torch.Tensor) -> bool:
return x.shape == torch.Size([6, 7])
With zero weights and lr = -1, the returned param should match the orthogonalized grads.
"""
test_param = torch.zeros((6, 7), dtype=torch.float32, device="cuda")
test_param.grad = torch.empty_like(test_param.data)

def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor:
return x * x
for i in range(test_param.shape[0]):
test_param.grad[i] = i + 1

ref_orth_grads = []
for g in torch.split(test_param.grad, split_shapes, dim=0):
ref_orth_grads.append(dummy_orth_fn(g))
ref_out = test_param - torch.cat(ref_orth_grads, dim=0) * lr
def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor:
out_list = [[], [], []]
for i in range(x.shape[0]):
out_list[i % 3].append(x[i : i + 1])
orth_grad_list = [torch.cat(t, dim=0) for t in out_list]
return torch.cat([torch.empty_like(x).fill_(x.max()) for x in orth_grad_list], dim=0)

orthogonalized_opt = OrthogonalizedOptimizer(
[test_param],
lr=lr,
lr=-1,
momentum_beta=0,
use_nesterov=False,
weight_decay=0.0,
use_decoupled_weight_decay=False,
split_qkv=True,
is_qkv_fn=is_qkv_fn,
qkv_split_shapes=(1, 2, 3),
fp32_matmul_prec="highest",
orthogonalize_fn=dummy_orth_fn,
scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn,
)
orthogonalized_opt.step()

assert not torch.allclose(test_param, test_param.grad)

ref_out = dummy_interleaved_split_orth_fn(test_param.grad)
torch.testing.assert_close(
test_param.data,
test_param,
ref_out,
atol=0,
rtol=0,