From a2156437486db490c3b7833ad620c3b37816774c Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Mon, 13 Oct 2025 15:00:02 -0700 Subject: [PATCH 1/5] update fused param handling Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/muon.py | 12 ++-- .../orthogonalized_optimizer.py | 64 ++++++++--------- tests/test_orthogonalized_optimizer.py | 70 ++++++++++++++++--- 3 files changed, 96 insertions(+), 50 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index 3d096e6..f0a7470 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -69,9 +69,9 @@ def __init__( use_nesterov: bool = True, weight_decay: float = 0.01, use_decoupled_weight_decay: bool = True, - split_qkv: bool = False, - is_qkv_fn: Callable[[torch.Tensor], bool] | None = None, - qkv_split_shapes: tuple[int, int, int] | None = None, + split_fused: bool = False, + is_fused_fn: Callable[[torch.Tensor], bool] | None = None, + split_fn: Callable | None = None, fp32_matmul_prec: str = "medium", coefficient_type: str = "quintic", num_ns_steps: int = 5, @@ -107,9 +107,9 @@ def __init__( use_nesterov, weight_decay, use_decoupled_weight_decay, - split_qkv, - is_qkv_fn, - qkv_split_shapes, + split_fused, + is_fused_fn, + split_fn, fp32_matmul_prec, orthogonalize_fn, scale_factor_fn, diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index bead1fc..8e47e50 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -36,10 +36,12 @@ weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True. - split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False. - is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.). - qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers - representing the sizes of Q, K, V components along the first dimension. + split_fused: Whether to split fused parameters (QKV, GQA, etc.) for preconditioning, default to be False. + is_fused_fn: Function to check if a parameter is fused parameters (QKV, GQA, etc.). + If multiple types of parameters are fused, the function should return True for all of which needs to be + split for preconditioning. + split_fn: Function to split the fused parameters (QKV, GQA, etc.) into a list of parameters. + It should support all the types of parameters that is_fused_fn returns True for. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. """ @@ -62,10 +64,9 @@ class OrthogonalizedOptimizer(optim.Optimizer): arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 `_] Note: - Orthogonalizing QKV sperately when they are fused is supported but with limitations. User must provide - a function to check if a weight tensor is fused attention parameters (QKV, GQA, etc.) as well as the - leading dimension of Q, K, V components. Only one split size is supported, i.e. all attention layers across - the network must have the same size. 
+ Orthogonalizing fused parameters separately is supported but with limitations. User must provide + a function to check if a weight tensor is fused parameters (QKV, GQA, etc.) as well as the + split function to split the fused parameters into a list of parameters. Args: {_args_doc} @@ -85,9 +86,9 @@ def __init__( use_nesterov: bool, weight_decay: float, use_decoupled_weight_decay: bool, - split_qkv: bool, - is_qkv_fn: Callable[[torch.Tensor], bool] | None, - qkv_split_shapes: tuple[int, int, int] | None, + split_fused: bool, + is_fused_fn: Callable[[torch.Tensor], bool] | None, + split_fn: Callable | None, fp32_matmul_prec: str, orthogonalize_fn: Callable | None = None, scale_factor_fn: Callable | None = None, @@ -105,20 +106,13 @@ def return_one(*args, **kwargs): # type: ignore[no-untyped-def] scale_factor_fn = return_one - if split_qkv: - assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True" - assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True" - if len(qkv_split_shapes) != 3: - raise ValueError( - f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements" - ) - if not all(isinstance(s, int) for s in qkv_split_shapes): - raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}") - if any(s <= 0 for s in qkv_split_shapes): - raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}") - self.split_qkv = split_qkv - self.is_qkv_fn = is_qkv_fn - self.qkv_split_shapes = qkv_split_shapes + if split_fused: + assert is_fused_fn is not None, "is_fused_fn must be provided when split_fused is True" + assert split_fn is not None, "split_fn must be provided when split_fused is True" + + self.split_fused = split_fused + self.is_fused_fn = is_fused_fn + self.split_fn = split_fn self.fp32_matmul_prec = fp32_matmul_prec default_args_dict = dict( @@ -201,15 +195,17 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor: Returns: The orthogonalized gradient tensor. """ - if self.split_qkv and self.is_qkv_fn(p): # type: ignore[misc] - logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1) - # split grouped attention parameters (e.g., QKV, GQA, etc.) - qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0) - # Apply Newton-Schulz to each component - qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads] - qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads] - # Apply individual scales to each component and concatenate - grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)]) + if self.split_fused and self.is_fused_fn(p): # type: ignore[misc] + logging.log_first_n(logging.INFO, f"split fused parameters with {p.shape} by {self.split_fn}", 1) + split_grads = self.split_fn(grad) # type: ignore[misc] + + assert sum([g.numel() for g in split_grads]) == grad.numel(), "Split grads do not sum to the original grad" + + split_grads_whitened = [self.orthogonalize_fn(g) for g in split_grads] + split_grad_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in split_grads] + + # TODO(skyw): Revisit whether there are cases that concatenating is not done along dim=0. 
+ grad = torch.cat([whitened * scale for whitened, scale in zip(split_grads_whitened, split_grad_scales)]) else: grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1)) return grad diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 264e2f8..df048a1 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -12,6 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial +from typing import List + import torch import torch.nn as nn from absl.testing import absltest, parameterized @@ -42,9 +45,9 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None: use_nesterov=False, weight_decay=0.5, use_decoupled_weight_decay=True, - split_qkv=False, - is_qkv_fn=None, - qkv_split_shapes=None, + split_fused=False, + is_fused_fn=None, + split_fn=None, fp32_matmul_prec="highest", ) @@ -86,9 +89,9 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> use_nesterov=False, weight_decay=0.0, use_decoupled_weight_decay=False, - split_qkv=False, - is_qkv_fn=None, - qkv_split_shapes=None, + split_fused=False, + is_fused_fn=None, + split_fn=None, fp32_matmul_prec="highest", ) @@ -114,7 +117,7 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> rtol=0, ) - def test_split_qkv_matches_ref(self) -> None: + def test_split_stacked_qkv_matches_ref(self) -> None: test_param = torch.randint(-5, 5, (6, 7), dtype=torch.float32, device="cuda") test_param.grad = torch.randint_like(test_param, -5, 5) split_shapes = (1, 2, 3) @@ -126,6 +129,8 @@ def is_qkv_fn(x: torch.Tensor) -> bool: def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: return x * x + split_fn = partial(torch.split, split_size_or_sections=split_shapes, dim=0) + ref_orth_grads = [] for g in torch.split(test_param.grad, split_shapes, dim=0): ref_orth_grads.append(dummy_orth_fn(g)) @@ -138,9 +143,9 @@ def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: use_nesterov=False, weight_decay=0.0, use_decoupled_weight_decay=False, - split_qkv=True, - is_qkv_fn=is_qkv_fn, - qkv_split_shapes=(1, 2, 3), + split_fused=True, + is_fused_fn=is_qkv_fn, + split_fn=split_fn, fp32_matmul_prec="highest", orthogonalize_fn=dummy_orth_fn, ) @@ -153,6 +158,51 @@ def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: rtol=0, ) + def test_split_fn_interleaved(self) -> None: + """Test a three way interleaved split function. + + With 0 weights and lr -1, returned param should match orthogonalized grads. 
+ """ + test_param = torch.zeros((6, 7), dtype=torch.float32, device="cuda") + test_param.grad = torch.empty_like(test_param.data) + + for i in range(test_param.shape[0]): + test_param.grad[i] = i + 1 + + def three_way_interleaved_split_fn(x: torch.Tensor) -> List[torch.Tensor]: + out_list = [[], [], []] + for i in range(x.shape[0]): + out_list[i % 3].append(x[i : i + 1]) + return [torch.cat(t, dim=0) for t in out_list] + + def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x).fill_(x.max()) + + orthogonalized_opt = OrthogonalizedOptimizer( + [test_param], + lr=-1, + momentum_beta=0, + use_nesterov=False, + weight_decay=0.0, + use_decoupled_weight_decay=False, + split_fused=True, + is_fused_fn=lambda x: True, + split_fn=three_way_interleaved_split_fn, + orthogonalize_fn=dummy_orth_fn, + fp32_matmul_prec="highest", + ) + orthogonalized_opt.step() + + assert not torch.allclose(test_param, test_param.grad) + + ref_out = torch.cat([dummy_orth_fn(g) for g in three_way_interleaved_split_fn(test_param.grad)], dim=0) + torch.testing.assert_close( + test_param, + ref_out, + atol=0, + rtol=0, + ) + class MuonTest(parameterized.TestCase): @parameterized.parameters( From c1787251494b043541236b9b305c5ce25cc644d8 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Mon, 13 Oct 2025 15:02:14 -0700 Subject: [PATCH 2/5] update callable hint Signed-off-by: Hao Wu --- emerging_optimizers/orthogonalized_optimizers/muon.py | 4 ++-- .../orthogonalized_optimizers/orthogonalized_optimizer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index f0a7470..17e0b73 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable +from typing import Callable, List import torch from absl import logging @@ -71,7 +71,7 @@ def __init__( use_decoupled_weight_decay: bool = True, split_fused: bool = False, is_fused_fn: Callable[[torch.Tensor], bool] | None = None, - split_fn: Callable | None = None, + split_fn: Callable[[torch.Tensor], List[torch.Tensor]] | None = None, fp32_matmul_prec: str = "medium", coefficient_type: str = "quintic", num_ns_steps: int = 5, diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index 8e47e50..b38b3e0 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Callable +from typing import Any, Callable, List # TODO(@boxiangw): remove this once bump to python 3.12 @@ -88,7 +88,7 @@ def __init__( use_decoupled_weight_decay: bool, split_fused: bool, is_fused_fn: Callable[[torch.Tensor], bool] | None, - split_fn: Callable | None, + split_fn: Callable[[torch.Tensor], List[torch.Tensor]] | None, fp32_matmul_prec: str, orthogonalize_fn: Callable | None = None, scale_factor_fn: Callable | None = None, From faee14cda25d2944e0e98114eed20c0ac1bccc2d Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 15 Oct 2025 13:33:23 -0700 Subject: [PATCH 3/5] get rid off split fn altogether Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/muon.py | 34 ++++------ .../orthogonalized_optimizer.py | 47 ++------------ tests/test_muon_utils.py | 19 ------ tests/test_orthogonalized_optimizer.py | 64 ++----------------- 4 files changed, 25 insertions(+), 139 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index 17e0b73..1d96967 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial -from typing import Callable, List import torch from absl import logging @@ -69,9 +67,6 @@ def __init__( use_nesterov: bool = True, weight_decay: float = 0.01, use_decoupled_weight_decay: bool = True, - split_fused: bool = False, - is_fused_fn: Callable[[torch.Tensor], bool] | None = None, - split_fn: Callable[[torch.Tensor], List[torch.Tensor]] | None = None, fp32_matmul_prec: str = "medium", coefficient_type: str = "quintic", num_ns_steps: int = 5, @@ -95,10 +90,14 @@ def __init__( f"Correctness of Triton kernel on SM {sm_version} cannot be guaranteed. Setting use_syrk to False." ) use_syrk = False - orthogonalize_fn = partial( - newton_schulz, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk - ) - scale_factor_fn = partial(get_muon_scale_factor, mode=scale_mode, extra_scale_factor=extra_scale_factor) + + def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: + logging.debug( + f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, {scale_mode} scale mode, extra_scale_factor={extra_scale_factor}" + ) + orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk) + scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) + return orth_grad * scale_factor * extra_scale_factor super().__init__( params, @@ -107,21 +106,15 @@ def __init__( use_nesterov, weight_decay, use_decoupled_weight_decay, - split_fused, - is_fused_fn, - split_fn, fp32_matmul_prec, - orthogonalize_fn, - scale_factor_fn, + scaled_orthogonalize_fn, ) Muon.__doc__ = Muon.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr] -def get_muon_scale_factor( - size_out: int, size_in: int, mode: str = "spectral", extra_scale_factor: float = 1.0 -) -> float: +def get_muon_scale_factor(size_out: int, size_in: int, mode: str = "spectral") -> float: """Get the scale for the update. Default mode is "spectral", which is the mode that allows for learning rate transferability from AdamW. @@ -133,19 +126,18 @@ def get_muon_scale_factor( size_out: The size of the output tensor. 
size_in: The size of the input tensor. mode: The mode to use for the scale. - extra_scale_factor: The additional scale factor to use for the update. Returns: The scale factor for the update. """ if mode == "shape_scaling": # Suggested by Muon (https://kellerjordan.github.io/posts/muon/) - return extra_scale_factor * max(1, size_out / size_in) ** 0.5 + return max(1, size_out / size_in) ** 0.5 elif mode == "spectral": # Suggested by K. Jordan and Kimi (https://arxiv.org/abs/2502.16982) - return extra_scale_factor * max(size_out, size_in) ** 0.5 + return max(size_out, size_in) ** 0.5 elif mode == "unit_rms_norm": # Suggested by Scion (https://arxiv.org/abs/2502.07529) and Bernstein et al. # (https://jeremybernste.in/writing/deriving-muon) - return extra_scale_factor * (size_out / size_in) ** 0.5 + return (size_out / size_in) ** 0.5 else: raise ValueError(f"Invalid mode for Muon update scale factor: {mode}") diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index b38b3e0..c36e954 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, List +from typing import Any, Callable # TODO(@boxiangw): remove this once bump to python 3.12 @@ -86,33 +86,13 @@ def __init__( use_nesterov: bool, weight_decay: float, use_decoupled_weight_decay: bool, - split_fused: bool, - is_fused_fn: Callable[[torch.Tensor], bool] | None, - split_fn: Callable[[torch.Tensor], List[torch.Tensor]] | None, fp32_matmul_prec: str, - orthogonalize_fn: Callable | None = None, - scale_factor_fn: Callable | None = None, + scaled_orthogonalize_fn: Callable | None = None, **kwargs: Any, ): - if orthogonalize_fn is None: - logging.warning("orthogonalize_fn not provided. Using noop") - orthogonalize_fn = torch.nn.Identity() - - if scale_factor_fn is None: - logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.") - - def return_one(*args, **kwargs): # type: ignore[no-untyped-def] - return 1.0 - - scale_factor_fn = return_one - - if split_fused: - assert is_fused_fn is not None, "is_fused_fn must be provided when split_fused is True" - assert split_fn is not None, "split_fn must be provided when split_fused is True" - - self.split_fused = split_fused - self.is_fused_fn = is_fused_fn - self.split_fn = split_fn + if scaled_orthogonalize_fn is None: + logging.warning("scaled_orthogonalize_fn not provided. Using noop") + scaled_orthogonalize_fn = torch.nn.Identity() self.fp32_matmul_prec = fp32_matmul_prec default_args_dict = dict( @@ -125,8 +105,7 @@ def return_one(*args, **kwargs): # type: ignore[no-untyped-def] ) super().__init__(params, default_args_dict) - self.orthogonalize_fn = orthogonalize_fn - self.scale_factor_fn = scale_factor_fn + self.scaled_orthogonalize_fn = scaled_orthogonalize_fn @torch.no_grad() # type: ignore[misc] @override @@ -195,19 +174,7 @@ def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor: Returns: The orthogonalized gradient tensor. 
""" - if self.split_fused and self.is_fused_fn(p): # type: ignore[misc] - logging.log_first_n(logging.INFO, f"split fused parameters with {p.shape} by {self.split_fn}", 1) - split_grads = self.split_fn(grad) # type: ignore[misc] - - assert sum([g.numel() for g in split_grads]) == grad.numel(), "Split grads do not sum to the original grad" - - split_grads_whitened = [self.orthogonalize_fn(g) for g in split_grads] - split_grad_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in split_grads] - - # TODO(skyw): Revisit whether there are cases that concatenating is not done along dim=0. - grad = torch.cat([whitened * scale for whitened, scale in zip(split_grads_whitened, split_grad_scales)]) - else: - grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1)) + grad = self.scaled_orthogonalize_fn(grad) return grad diff --git a/tests/test_muon_utils.py b/tests/test_muon_utils.py index 4deb053..0ee1d4c 100644 --- a/tests/test_muon_utils.py +++ b/tests/test_muon_utils.py @@ -191,25 +191,6 @@ def test_get_scale_factor(self, size_pairs, mode): else: raise ValueError(f"Invalid mode: {mode}") - def test_qkv_split_shapes_validation(self): - """Test validation of qkv_split_shapes parameter""" - dummy_param = torch.nn.Parameter(torch.randn(4, 4)) - dummy_args = dict(split_qkv=True, is_qkv_fn=lambda x: True) - # Test non-integer values - with self.assertRaises(ValueError) as cm: - muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512.5, 256, 256)) - self.assertIn("must be integers", str(cm.exception)) - - # Test negative values - with self.assertRaises(ValueError) as cm: - muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, -256, 256)) - self.assertIn("must be positive", str(cm.exception)) - - # Test wrong number of elements - with self.assertRaises(ValueError) as cm: - muon.Muon([dummy_param], **dummy_args, qkv_split_shapes=(512, 256)) - self.assertIn("tuple of 3 integers", str(cm.exception)) - @absltest.skipIf( _SM_VERSION not in ((8, 0), (9, 0), (10, 0), (10, 3)), diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index df048a1..6bdee12 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from functools import partial -from typing import List import torch import torch.nn as nn @@ -45,9 +43,6 @@ def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None: use_nesterov=False, weight_decay=0.5, use_decoupled_weight_decay=True, - split_fused=False, - is_fused_fn=None, - split_fn=None, fp32_matmul_prec="highest", ) @@ -89,9 +84,6 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> use_nesterov=False, weight_decay=0.0, use_decoupled_weight_decay=False, - split_fused=False, - is_fused_fn=None, - split_fn=None, fp32_matmul_prec="highest", ) @@ -117,47 +109,6 @@ def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> rtol=0, ) - def test_split_stacked_qkv_matches_ref(self) -> None: - test_param = torch.randint(-5, 5, (6, 7), dtype=torch.float32, device="cuda") - test_param.grad = torch.randint_like(test_param, -5, 5) - split_shapes = (1, 2, 3) - lr = 2.0 - - def is_qkv_fn(x: torch.Tensor) -> bool: - return x.shape == torch.Size([6, 7]) - - def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: - return x * x - - split_fn = partial(torch.split, split_size_or_sections=split_shapes, dim=0) - - ref_orth_grads = [] - for g in torch.split(test_param.grad, split_shapes, dim=0): - ref_orth_grads.append(dummy_orth_fn(g)) - ref_out = test_param - torch.cat(ref_orth_grads, dim=0) * lr - - orthogonalized_opt = OrthogonalizedOptimizer( - [test_param], - lr=lr, - momentum_beta=0, - use_nesterov=False, - weight_decay=0.0, - use_decoupled_weight_decay=False, - split_fused=True, - is_fused_fn=is_qkv_fn, - split_fn=split_fn, - fp32_matmul_prec="highest", - orthogonalize_fn=dummy_orth_fn, - ) - orthogonalized_opt.step() - - torch.testing.assert_close( - test_param.data, - ref_out, - atol=0, - rtol=0, - ) - def test_split_fn_interleaved(self) -> None: """Test a three way interleaved split function. 
@@ -169,14 +120,12 @@ def test_split_fn_interleaved(self) -> None: for i in range(test_param.shape[0]): test_param.grad[i] = i + 1 - def three_way_interleaved_split_fn(x: torch.Tensor) -> List[torch.Tensor]: + def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor: out_list = [[], [], []] for i in range(x.shape[0]): out_list[i % 3].append(x[i : i + 1]) - return [torch.cat(t, dim=0) for t in out_list] - - def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: - return torch.empty_like(x).fill_(x.max()) + orth_grad_list = [torch.cat(t, dim=0) for t in out_list] + return torch.cat([torch.empty_like(x).fill_(x.max()) for x in orth_grad_list], dim=0) orthogonalized_opt = OrthogonalizedOptimizer( [test_param], @@ -185,17 +134,14 @@ def dummy_orth_fn(x: torch.Tensor) -> torch.Tensor: use_nesterov=False, weight_decay=0.0, use_decoupled_weight_decay=False, - split_fused=True, - is_fused_fn=lambda x: True, - split_fn=three_way_interleaved_split_fn, - orthogonalize_fn=dummy_orth_fn, fp32_matmul_prec="highest", + scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn, ) orthogonalized_opt.step() assert not torch.allclose(test_param, test_param.grad) - ref_out = torch.cat([dummy_orth_fn(g) for g in three_way_interleaved_split_fn(test_param.grad)], dim=0) + ref_out = dummy_interleaved_split_orth_fn(test_param.grad) torch.testing.assert_close( test_param, ref_out, From 27aa01c866403e55cd7fdee2343c25e10f20c0f0 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 15 Oct 2025 13:42:48 -0700 Subject: [PATCH 4/5] comment update Signed-off-by: Hao Wu --- .../orthogonalized_optimizers/muon.py | 3 ++- .../orthogonalized_optimizer.py | 25 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index 1d96967..b17283c 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -93,7 +93,8 @@ def __init__( def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: logging.debug( - f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, {scale_mode} scale mode, extra_scale_factor={extra_scale_factor}" + f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, " + f"{scale_mode} scale mode, extra_scale_factor={extra_scale_factor}" ) orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=use_syrk) scale_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode=scale_mode) diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index c36e954..a907273 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -36,12 +36,6 @@ weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True. - split_fused: Whether to split fused parameters (QKV, GQA, etc.) for preconditioning, default to be False. - is_fused_fn: Function to check if a parameter is fused parameters (QKV, GQA, etc.). 
-        If multiple types of parameters are fused, the function should return True for all of which needs to be
-        split for preconditioning.
-    split_fn: Function to split the fused parameters (QKV, GQA, etc.) into a list of parameters.
-        It should support all the types of parameters that is_fused_fn returns True for.
     fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
 """

@@ -70,8 +64,7 @@ class OrthogonalizedOptimizer(optim.Optimizer):

     Args:
         {_args_doc}
-        orthogonalize_fn: Function to orthogonalize the updates.
-        scale_factor_fn: Function to compute the scale factor for the update.
+        scaled_orthogonalize_fn: Function to orthogonalize and scale the updates.
         **kwargs: Arguments passed through to the base optimizer.

     Note:
@@ -155,7 +148,8 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:
                 grad = exp_avg

                 with utils.fp32_matmul_precision(self.fp32_matmul_prec):
-                    grad = self.orthogonalize(p, grad)
+                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
+                    grad = self.orthogonalize(p, grad, **group_kwargs)

                 # perform weight update
                 # scale is applied to have update RMS == 1
@@ -163,13 +157,20 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

         return loss

-    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
+    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """Orthogonalize the momentum.

+        The default implementation simply calls scaled_orthogonalize_fn on the gradient. Subclasses can
+        override this method to implement different orthogonalization logic, including splitting fused
+        parameters. For example, an override can inspect attributes of p or entries of kwargs to determine
+        whether the parameter is fused and should be split for preconditioning.
+
         Args:
-            p: The parameter tensor. i is necessary to pass param tensor in addition to momentum because a lot of
-                information is only available in the param tensor, attributes for example.
+            p: The parameter tensor. It is necessary to pass the param tensor in addition to the momentum
+                because some information, such as tensor attributes, is only available on the param tensor.
+                It is not used by this default implementation.
             grad: The momentum tensor.
+            **kwargs: Keyword arguments of the param_group that p belongs to.

         Returns:
             The orthogonalized gradient tensor.

From 0ed453a0ad5df7f79486c855397e6cd1d06c987a Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Fri, 17 Oct 2025 11:01:10 -0700
Subject: [PATCH 5/5] Update docstring and add example to OrthogonalizedOptimizer

Signed-off-by: Hao Wu
---
 .../orthogonalized_optimizer.py | 29 ++++++++++++++++---
 1 file changed, 25 insertions(+), 4 deletions(-)

diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
index a907273..7d49c75 100644
--- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
+++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
@@ -44,7 +44,8 @@ class OrthogonalizedOptimizer(optim.Optimizer):
     """Base class for orthogonalized optimizers.

     This class is a wrapper around a base optimizer that performs orthogonalization on the updates.
-    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the following papers:
+    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
+    following papers:

     - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann Machines.*
       In International Conference on Artificial Intelligence and Statistics (2015a).
@@ -58,9 +59,29 @@ class OrthogonalizedOptimizer(optim.Optimizer):
       arXiv preprint arXiv:1708.00523 (2017). [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

     Note:
-        Orthogonalizing fused parameters separately is supported but with limitations. User must provide
-        a function to check if a weight tensor is fused parameters (QKV, GQA, etc.) as well as the
-        split function to split the fused parameters into a list of parameters.
+        The OrthogonalizedOptimizer base class does not directly support orthogonalizing fused parameters
+        separately. Subclasses can override the orthogonalize method to support this, as in the example below.
+
+        .. code-block:: python
+            :caption: Split QKV example
+
+            class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
+                def __init__(..., split_qkv_shapes):
+                    super().__init__(...)
+                    self.qkv_split_shapes = split_qkv_shapes
+
+                def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+
+                    # An alternative is to pass "is_qkv" through to scaled_orthogonalize_fn and do the
+                    # split inside that callable.
+                    if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
+                        qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
+                        qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
+                        grad = torch.cat(qkv_orthogonalized, dim=0)
+                    else:
+                        grad = self.scaled_orthogonalize_fn(grad)
+
+                    return grad

     Args:
         {_args_doc}
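
Usage sketch (illustrative, not part of the patch series): after this refactor, fused-parameter handling lives
inside the scaled_orthogonalize_fn callable or an orthogonalize override, instead of the old
split_fused/is_fused_fn/split_fn constructor arguments. The import paths below are inferred from the file
layout in the diffs, the layer shape and hyperparameters are made up, and detecting the fused layout by row
count is just one way to realize the "alternative" mentioned in the new docstring example.

import torch

from emerging_optimizers.orthogonalized_optimizers import muon
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer

# Hypothetical fused QKV projection: hidden size 1024, with the Q/K/V row counts below.
qkv_split_shapes = (1024, 512, 512)
fused_qkv = torch.nn.Linear(1024, sum(qkv_split_shapes), bias=False)


def muon_orthogonalize(grad: torch.Tensor) -> torch.Tensor:
    # Newton-Schulz orthogonalization plus the spectral scale factor, mirroring the
    # scaled_orthogonalize_fn that Muon builds in PATCH 3.
    orth = muon.newton_schulz(grad, steps=5, coefficient_type="quintic", use_syrk=False)
    return orth * muon.get_muon_scale_factor(grad.size(-2), grad.size(-1), mode="spectral")


def split_qkv_orthogonalize(grad: torch.Tensor) -> torch.Tensor:
    # Orthogonalize Q, K, and V separately when the gradient has the fused row count,
    # then concatenate along dim=0 to restore the original layout.
    if grad.size(0) == sum(qkv_split_shapes):
        chunks = torch.split(grad, qkv_split_shapes, dim=0)
        return torch.cat([muon_orthogonalize(g) for g in chunks], dim=0)
    return muon_orthogonalize(grad)


opt = OrthogonalizedOptimizer(
    [fused_qkv.weight],
    lr=0.02,
    momentum_beta=0.95,
    use_nesterov=True,
    weight_decay=0.0,
    use_decoupled_weight_decay=True,
    fp32_matmul_prec="highest",
    scaled_orthogonalize_fn=split_qkv_orthogonalize,
)

The SplitQkvOrthogonalizedOptimizer subclass shown in the PATCH 5 docstring reaches the same behavior by
overriding orthogonalize instead of wrapping the callable; the callable-only variant keeps the optimizer class
untouched, at the cost of re-deriving which tensors are fused from their shape or attributes.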