diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index 3d096e6..b17283c 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Callable import torch from absl import logging @@ -69,9 +67,6 @@ def __init__( use_nesterov: bool = True, weight_decay: float = 0.01, use_decoupled_weight_decay: bool = True, - split_qkv: bool = False, - is_qkv_fn: Callable[[torch.Tensor], bool] | None = None, - qkv_split_shapes: tuple[int, int, int] | None = None, fp32_matmul_prec: str = "medium", coefficient_type: str = "quintic", num_ns_steps: int = 5, @@ -95,10 +90,15 @@ def __init__( f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False." ) use_syrk = False - orthogonalize_fn = partial( - newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk - ) - scale_factor_fn = partial(get_muon_scale_factor, mode=scale_mode, extra_scale_factor=extra_scale_factor) + + def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: + logging.debug( + f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, " + f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}" + ) + orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk) + scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) + return orth_grad * scale_factor * extra_scale_factor super().__init__( params, @@ -107,21 +107,15 @@ def __init__( use_nesterov, weight_decay, use_decoupled_weight_decay, - split_qkv, - is_qkv_fn, - qkv_split_shapes, fp32_matmul_prec, - orthogonalize_fn, - scale_factor_fn, + scaled_orthogonalize_fn, ) Muon.__doc__ = Muon.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr] -def get_muon_scale_factor( - size_out: int, size_in: int, mode: str = "spectral", extra_scale_factor: float = 1.0 -) -> float: +def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") -> float: """Get the scale for the update. Default mode is "spectral", which is the mode that allows for learning rate transferability from AdamW. @@ -133,19 +127,18 @@ def get_muon_scale_factor( size_out: The size of the output tensor. size_in: The size of the input tensor. mode: The mode to use for the scale. - extra_scale_factor: The additional scale factor to use for the update. Returns: The scale factor for the update. """ if mode == "shape_scaling": # Suggested by Muon (https://kellerjordan.github.io/posts/muon/) - return extra_scale_factor * max(1, size_out / size_in) ** 0.5 + return max(1, size_out / size_in) ** 0.5 elif mode == "spectral": # Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982) - return extra_scale_factor * max(size_out, size_in) ** 0.5 + return max(size_out, size_in) ** 0.5 elif mode == "unit_rms_norm": # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. # (https://jeremybernste.in/writing/deriving-muon) - return extra_scale_factor * (size_out / size_in) ** 0.5 + return (size_out / size_in) ** 0.5 else: raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index bead1fc..7d49c75 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -36,10 +36,6 @@ weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True. - split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False. - is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.). - qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers - representing the sizes of Q, K, V components along the first dimension. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. """ @@ -48,7 +44,8 @@ class OrthogonalizedOptimizer(optim.Optimizer): """Base class for orthogonalized optimizers. This class is a wrapper around a base optimizer that performs orthogonalization on the updates. - The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers: + The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the + following papers: - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.* In International Conference on Artificial Intelligence and Statistics (2015a). @@ -62,15 +59,33 @@ class OrthogonalizedOptimizer(optim.Optimizer): arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 `_] Note: - Orthogonalizing QKV sperately when they are fused is supported but with limitations. User must provide - a function to check if a weight tensor is fused attention parameters (QKV, GQA, etc.) as well as the - leading dimension of Q, K, V components. Only one split size is supported, i.e. all attention layers across - the network must have the same size. + OrthogonalizedOptimizer as base class doesn't directly support orthogonalizing fused parameters separately. + Subclass can override the orthogonalize function to support this, see example below. + + .. code-block:: python + :caption: Split QKV example + + class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer): + def __init__(..., split_qkv_shapes): + super().__init__(...) + self.qkv_split_shapes = split_qkv_shapes + + def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: + + # Alternative is passing "is_qkv" to scaled_orthogonalize_fn and split inside the + # scaled_orthogonalize_fn. + if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False): + qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0) + qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads] + grad = torch.cat([orthogonalized for orthogonalized in qkv_orthogonalized]) + else: + grad = self.scaled_orthogonalize_fn(grad) + + return grad Args: {_args_doc} - orthogonalize_fn: Function to orthogonalize the updates. - scale_factor_fn: Function to compute the scale factor for the update. + scaled_orthogonalize_fn: Function to orthogonalize and scale the updates. **kwargs: Arguments passed through to the base optimizer. Note: @@ -85,40 +100,13 @@ def __init__( use_nesterov: bool, weight_decay: float, use_decoupled_weight_decay: bool, - split_qkv: bool, - is_qkv_fn: Callable[[torch.Tensor], bool] | None, - qkv_split_shapes: tuple[int, int, int] | None, fp32_matmul_prec: str, - orthogonalize_fn: Callable | None = None, - scale_factor_fn: Callable | None = None, + scaled_orthogonalize_fn: Callable | None = None, **kwargs: Any, ): - if orthogonalize_fn is None: - logging.warning("orthogonalize_fn not provided. Using noop") - orthogonalize_fn = torch.nn.Identity() - - if scale_factor_fn is None: - logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.") - - def return_one(*args, **kwargs): # type: ignore[no-untyped-def] - return 1.0 - - scale_factor_fn = return_one - - if split_qkv: - assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True" - assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True" - if len(qkv_split_shapes) != 3: - raise ValueError( - f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements" - ) - if not all(isinstance(s, int) for s in qkv_split_shapes): - raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}") - if any(s <= 0 for s in qkv_split_shapes): - raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}") - self.split_qkv = split_qkv - self.is_qkv_fn = is_qkv_fn - self.qkv_split_shapes = qkv_split_shapes + if scaled_orthogonalize_fn is None: + logging.warning("scaled_orthogonalize_fn not provided. Using noop") + scaled_orthogonalize_fn = torch.nn.Identity() self.fp32_matmul_prec = fp32_matmul_prec default_args_dict = dict( @@ -131,8 +119,7 @@ def return_one(*args, **kwargs): # type: ignore[no-untyped-def] ) super().__init__(params, default_args_dict) - self.orthogonalize_fn = orthogonalize_fn - self.scale_factor_fn = scale_factor_fn + self.scaled_orthogonalize_fn = scaled_orthogonalize_fn @torch.no_grad() # type: ignore[misc] @override @@ -182,7 +169,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: grad = exp_avg with utils.fp32_matmul_precision(self.fp32_matmul_prec): - grad = self.orthogonalize(p, grad) + group_kwargs = {k: v for k, v in group.items() if k != "params"} + grad = self.orthogonalize(p, grad, **group_kwargs) # perform weight update # scale is applied to have update RMS == 1 @@ -190,28 +178,25 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: return loss - def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor: + def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor: """Orthogonalize the momentum. + The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclass can + override this function to implement different orthogonalization logic as well as split fused parameters. + For example, a scaled_orthogonalize_fn function can get attributes from p or from kwargs to determine if + the parameter is a fused parameter and should be split for preconditioning. + Args: - p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of - information is only available in the param tensor, attributes for example. + p: The parameter tensor. It is necessary to pass param tensor in addition to momentum because a lot of + information is only available in the param tensor, attributes for example. Although not used in + this default orthogonalize function. grad: The momentum tensor. + **kwargs: keyword arguments of the param_group that p was belonged to. Returns: The orthogonalized gradient tensor. """ - if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc] - logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1) - # split grouped attention parameters (e.g., QKV, GQA, etc.) - qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0) - # Apply Newton-Schulz to each component - qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads] - qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads] - # Apply individual scales to each component and concatenate - grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)]) - else: - grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1)) + grad = self.scaled_orthogonalize_fn(grad) return grad diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index 4deb053..0ee1d4c 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -191,25 +191,6 @@ def test_get_scale_factor(self, size_pairs, mode): else: raise ValueError(f"Invalid mode: {mode}") - def test_qkv_split_shapes_validation(self): - """Test validation of qkv_split_shapes parameter""" - dummy_param = torch.nn.Parameter(torch.randn(4, 4)) - dummy_args = dict(split_qkv=True, is_qkv_fn=lambda x: True) - # Test non-integer values - with self.assertRaises(ValueError) as cm: - muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512.5, 256, 256)) - self.assertIn("must be integers", str(cm.exception)) - - # Test negative values - with self.assertRaises(ValueError) as cm: - muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, -256, 256)) - self.assertIn("must be positive", str(cm.exception)) - - # Test wrong number of elements - with self.assertRaises(ValueError) as cm: - muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, 256)) - self.assertIn("tuple of 3 integers", str(cm.exception)) - @absltest.skipIf( _SM_VERSION not in ((8, 0), (9, 0), (10, 0), (10, 3)), diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 264e2f8..6bdee12 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import torch import torch.nn as nn from absl.testing import absltest, parameterized @@ -42,9 +43,6 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None: use_nesterov=False, weight_decay=0.5, use_decoupled_weight_decay=True, - split_qkv=False, - is_qkv_fn=None, - qkv_split_shapes=None, fp32_matmul_prec="highest", ) @@ -86,9 +84,6 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> use_nesterov=False, weight_decay=0.0, use_decoupled_weight_decay=False, - split_qkv=False, - is_qkv_fn=None, - qkv_split_shapes=None, fp32_matmul_prec="highest", ) @@ -114,40 +109,41 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> rtol=0, ) - def test_split_qkv_matches_ref(self) -> None: - test_param = torch.randint(-5, 5, (6, 7), dtype=torch.float32, device="cuda") - test_param.grad = torch.randint_like(test_param, -5, 5) - split_shapes = (1, 2, 3) - lr = 2.0 + def test_split_fn_interleaved(self) -> None: + """Test a three way interleaved split function. - def is_qkv_fn(x: torch.Tensor) -> bool: - return x.shape == torch.Size([6, 7]) + With 0 weights and lr -1, returned param should match orthogonalized grads. + """ + test_param = torch.zeros((6, 7), dtype=torch.float32, device="cuda") + test_param.grad = torch.empty_like(test_param.data) - def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: - return x * x + for i in range(test_param.shape[0]): + test_param.grad[i] = i + 1 - ref_orth_grads = [] - for g in torch.split(test_param.grad, split_shapes, dim=0): - ref_orth_grads.append(dummy_orth_fn(g)) - ref_out = test_param - torch.cat(ref_orth_grads, dim=0) * lr + def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor: + out_list = [[], [], []] + for i in range(x.shape[0]): + out_list[i % 3].append(x[i : i + 1]) + orth_grad_list = [torch.cat(t, dim=0) for t in out_list] + return torch.cat([torch.empty_like(x).fill_(x.max()) for x in orth_grad_list], dim=0) orthogonalized_opt = OrthogonalizedOptimizer( [test_param], - lr=lr, + lr=-1, momentum_beta=0, use_nesterov=False, weight_decay=0.0, use_decoupled_weight_decay=False, - split_qkv=True, - is_qkv_fn=is_qkv_fn, - qkv_split_shapes=(1, 2, 3), fp32_matmul_prec="highest", - orthogonalize_fn=dummy_orth_fn, + scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn, ) orthogonalized_opt.step() + assert not torch.allclose(test_param, test_param.grad) + + ref_out = dummy_interleaved_split_orth_fn(test_param.grad) torch.testing.assert_close( - test_param.data, + test_param, ref_out, atol=0, rtol=0,