diff --git a/deepmd/pt/optimizer/__init__.py b/deepmd/pt/optimizer/__init__.py index db340b3bb9..4c069cf2ea 100644 --- a/deepmd/pt/optimizer/__init__.py +++ b/deepmd/pt/optimizer/__init__.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .adamuon import ( + AdaMuonOptimizer, +) from .KFWrapper import ( KFOptimizerWrapper, ) @@ -6,4 +9,4 @@ LKFOptimizer, ) -__all__ = ["KFOptimizerWrapper", "LKFOptimizer"] +__all__ = ["AdaMuonOptimizer", "KFOptimizerWrapper", "LKFOptimizer"] diff --git a/deepmd/pt/optimizer/adamuon.py b/deepmd/pt/optimizer/adamuon.py new file mode 100644 index 0000000000..eaca6aefdf --- /dev/null +++ b/deepmd/pt/optimizer/adamuon.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +AdaMuon optimizer for DeePMD-kit PyTorch backend. + +AdaMuon combines Newton-Schulz orthogonalization with adaptive per-element +second-moment normalization and RMS-aligned global scaling. It applies sign-stabilized +orthogonal direction for improved training stability. + +Key improvements over vanilla Muon: +- Sign-stabilized orthogonal direction +- Per-element second-moment (v_buffer) normalization +- RMS-aligned global scaling + +References +---------- +.. [1] Ethan Smith et al., "AdaMuon: Adaptive Muon Optimizer," arXiv:2507.11005, 2025. + https://arxiv.org/abs/2507.11005 +.. [2] AdaMuon GitHub Repository. + https://github.com/ethansmith2000/AdaMuon +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +from torch.optim.optimizer import ( + Optimizer, +) + +if TYPE_CHECKING: + from collections.abc import ( + Callable, + Iterable, + ) + + +def zeropower_via_newtonschulz5( + G: torch.Tensor, + steps: int = 5, + eps: float = 1e-8, +) -> torch.Tensor: + """ + Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration. + + Uses quintic Newton-Schulz iteration to compute the orthogonal component of the + input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T. + + This implementation always performs Newton-Schulz in bfloat16 and returns a + bfloat16 tensor. + + Parameters + ---------- + G : torch.Tensor + Input matrix to orthogonalize with shape (..., M, N). + steps : int + Number of Newton-Schulz iterations with default 5. + eps : float + Numerical stability epsilon for norm clamping with default 1e-8. + + Returns + ------- + torch.Tensor + Orthogonalized matrix in bfloat16 with same shape as input. + + Raises + ------ + ValueError + If G has fewer than 2 dimensions. + ValueError + If steps >= 100 (guard for efficiency). + """ + # === Step 1. Validate === + if G.ndim < 2: + raise ValueError("Input must have at least 2 dimensions (..., M, N).") + if steps >= 100: + raise ValueError("Number of steps must be less than 100 for efficiency.") + + a, b, c = (3.4445, -4.7750, 2.0315) + + # === Step 2. Cast to bf16 === + X = G.to(dtype=torch.bfloat16) + + # === Step 3. Transpose tall matrices === + if X.size(-2) > X.size(-1): + X = X.mT + + # === Step 4. Normalize Frobenius norm to at most 1 === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps) + + # === Step 5. 
Newton-Schulz iterations with fused GEMM === + for _ in range(steps): + A = X @ X.mT + # gram_update = b*A + c*(A@A) via addmm/baddbmm + # X = a*X + gram_update@X via addmm/baddbmm + if X.ndim == 2: + gram_update = torch.addmm(A, A, A, beta=b, alpha=c) + X = torch.addmm(X, gram_update, X, beta=a, alpha=1.0) + else: + gram_update = torch.baddbmm(A, A, A, beta=b, alpha=c) + X = torch.baddbmm(X, gram_update, X, beta=a, alpha=1.0) + + # === Step 6. Transpose back if needed === + if G.size(-2) > G.size(-1): + X = X.mT + + return X + + +def _prepare_muon_momentum( + grad: torch.Tensor, + momentum_buffer: torch.Tensor, + beta: float, + nesterov: bool, +) -> tuple[torch.Tensor, tuple[int, ...]]: + """ + Prepare momentum update and reshape for batched Newton-Schulz. + + Parameters + ---------- + grad : torch.Tensor + Gradient tensor. + momentum_buffer : torch.Tensor + Momentum buffer (will be updated in-place). + beta : float + Momentum coefficient. + nesterov : bool + Whether to use Nesterov momentum. + + Returns + ------- + update : torch.Tensor + Reshaped update tensor with shape (M, N). + original_shape : tuple[int, ...] + Original shape before reshape. + """ + # === Step 1. Update momentum buffer === + momentum_buffer.lerp_(grad, 1 - beta) + update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer + + # === Step 2. Handle tensor -> matrix reshape === + original_shape = update.shape + if update.ndim > 2: + update = update.reshape(update.shape[0], -1) + + return update, original_shape + + +class AdaMuonOptimizer(Optimizer): + """ + AdaMuon optimizer with adaptive second-moment normalization and auxiliary Adam. + + This optimizer applies different update rules based on parameter dimensionality: + - For 2D+ parameters (weight matrices): AdaMuon update with sign-stabilized + Newton-Schulz orthogonalization and per-element v_buffer normalization. + - For 1D parameters (biases, layer norms): Standard Adam update. + + Key AdaMuon features: + - Sign-stabilized orthogonal direction: Applies sign() before orthogonalization. + - Per-element second-moment normalization using momentum coefficient. + - RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm. + + Parameters + ---------- + params : iterable + Iterable of parameters to optimize. + lr : float + Learning rate with default 1e-3. + momentum : float + Momentum coefficient for AdaMuon with default 0.95. + weight_decay : float + Weight decay coefficient (applied only to >=2D params) with default 0.001. + ns_steps : int + Number of Newton-Schulz iterations with default 5. + adam_betas : tuple[float, float] + Adam beta coefficients with default (0.9, 0.95). + adam_eps : float + Adam epsilon with default 1e-8. + nesterov : bool + Whether to use Nesterov momentum for AdaMuon with default True. + lr_adjust : float + Learning rate adjustment factor for Adam (1D params). + - If lr_adjust <= 0: use match-RMS scaling for AdaMuon update, + scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly. + - If lr_adjust > 0: use rectangular correction for AdaMuon update, + scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as learning rate. + Default is 10.0 (Adam lr = lr/10). + lr_adjust_coeff : float + Coefficient for match-RMS scaling with default 0.2. + Only effective when lr_adjust <= 0. + eps : float + Epsilon for v_buffer sqrt and global scaling normalization with default 1e-8. + + Examples + -------- + >>> optimizer = AdaMuonOptimizer(model.parameters(), lr=1e-3) + >>> for epoch in range(epochs): + ... optimizer.zero_grad() + ... 
loss.backward() + ... optimizer.step() + """ + + def __init__( + self, + params: Iterable[torch.Tensor] | Iterable[dict[str, Any]], + lr: float = 1e-3, + momentum: float = 0.95, + weight_decay: float = 0.001, + ns_steps: int = 5, + adam_betas: tuple[float, float] = (0.9, 0.95), + adam_eps: float = 1e-8, + nesterov: bool = True, + lr_adjust: float = 10.0, + lr_adjust_coeff: float = 0.2, + eps: float = 1e-8, + ) -> None: + defaults = { + "lr": lr, + "momentum": momentum, + "weight_decay": weight_decay, + "ns_steps": ns_steps, + "adam_betas": adam_betas, + "adam_eps": adam_eps, + "nesterov": nesterov, + "lr_adjust": lr_adjust, + "lr_adjust_coeff": lr_adjust_coeff, + "eps": eps, + } + super().__init__(params, defaults) + + @torch.no_grad() + def step( + self, + closure: Callable[[], torch.Tensor] | None = None, + ) -> torch.Tensor | None: + """ + Perform a single optimization step. + + Parameters + ---------- + closure : callable, optional + A closure that reevaluates the model and returns the loss. + + Returns + ------- + loss : torch.Tensor, optional + The loss value if closure is provided. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + weight_decay = group["weight_decay"] + ns_steps = group["ns_steps"] + adam_betas = group["adam_betas"] + adam_eps = group["adam_eps"] + nesterov = group["nesterov"] + lr_adjust = group["lr_adjust"] + lr_adjust_coeff = group["lr_adjust_coeff"] + eps = group["eps"] + + # === Step 1. Collect params with gradients and separate by type === + muon_params: list[torch.Tensor] = [] # For weight decay (>=2D only) + muon_entries: list[ + tuple[torch.nn.Parameter, torch.Tensor, tuple[int, ...]] + ] = [] + # Adam batch lists + adam_params: list[torch.Tensor] = [] + adam_grads_fp32: list[torch.Tensor] = [] + adam_exp_avgs: list[torch.Tensor] = [] + adam_exp_avg_sqs: list[torch.Tensor] = [] + + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad + if grad.dtype != p.dtype: + grad = grad.to(dtype=p.dtype) + + state = self.state[p] + + if p.ndim >= 2: + # AdaMuon path: collect for weight decay + muon_params.append(p) + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + update, orig_shape = _prepare_muon_momentum( + grad, state["momentum_buffer"], momentum, nesterov + ) + muon_entries.append((p, update, orig_shape)) + else: + # Adam path: state tensors forced to FP32 + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + adam_params.append(p) + # Cast grad to FP32 for Adam computation + adam_grads_fp32.append(grad.float()) + adam_exp_avgs.append(state["exp_avg"]) + adam_exp_avg_sqs.append(state["exp_avg_sq"]) + + # === Step 2. Foreach weight decay (only >=2D params) === + if weight_decay > 0 and muon_params: + torch._foreach_mul_(muon_params, 1.0 - lr * weight_decay) + + # === Step 3. 
Adam update for 1D params (FP32 computation) === + if adam_params: + # Determine Adam learning rate + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + + # Update momentum estimates in FP32 + torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0]) + grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32) + torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + + # Compute updates with bias correction (per-param beta_pow) + for i, p in enumerate(adam_params): + state = self.state[p] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr / bias_corr1 + # FP32 computation: compute full delta in FP32, then cast once + denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(adam_eps) + delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) + p.add_(delta_fp32.to(p.dtype)) + + # === Step 4. Batched Newton-Schulz for AdaMuon parameters === + if not muon_entries: + continue + + # Group by (rows, cols, device) for batched processing + buckets: dict[ + tuple[int, int, torch.device], + list[tuple[torch.nn.Parameter, torch.Tensor, tuple[int, ...]]], + ] = {} + for entry in muon_entries: + p, update, orig_shape = entry + key = (update.shape[0], update.shape[1], update.device) + if key not in buckets: + buckets[key] = [] + buckets[key].append(entry) + + # Process each bucket + for (rows, cols, _device), bucket in buckets.items(): + m, n = rows, cols + + # === Pre-compute bucket-level scaling constants === + # RMS-aligned scale: 0.2 * sqrt(m * n) + rms_scale = 0.2 * math.sqrt(float(m * n)) + # Shape-dependent lr correction (based on lr_adjust mode) + if lr_adjust <= 0: + adj_scale = lr_adjust_coeff * math.sqrt(float(max(m, n))) + else: + adj_scale = max(1.0, m / n) ** 0.5 + + # === Step 4.1 Stack sign matrices and orthogonalize === + # Always stack to 3D (B, m, n) for unified indexing + stacked = torch.stack( + [torch.sign(item[1].contiguous()) for item in bucket], dim=0 + ) + orth_stacked = zeropower_via_newtonschulz5(stacked, steps=ns_steps) + + # === Step 4.2 Per-element v_buffer normalization and update === + for i, (p, update, orig_shape) in enumerate(bucket): + state = self.state[p] + # orth_stacked is always 3D, use unified indexing + orth_vec = ( + orth_stacked[i].flatten().float() + ) # Cast to FP32 for stability + + # === Step 4.2.1 Initialize or retrieve v_buffer === + if "v_buffer" not in state: + state["v_buffer"] = torch.zeros( + orth_vec.numel(), + dtype=torch.float32, + device=orth_vec.device, + ) + v = state["v_buffer"] + + # === Step 4.2.2 EMA update and element-wise normalization === + # v = momentum * v + (1 - momentum) * orth_vec^2 + v.mul_(momentum).addcmul_(orth_vec, orth_vec, value=1.0 - momentum) + orth_vec = orth_vec / (v.sqrt().add_(eps)) + + # === Step 4.2.3 RMS-aligned global scaling === + # scale = rms_scale / (norm + eps) + norm_val = orth_vec.norm() + orth_vec.div_(norm_val + eps).mul_(rms_scale) + + # === Step 4.2.4 Shape-dependent lr correction === + orth_vec.mul_(adj_scale) + + # Reshape back and update parameter + p.add_( + orth_vec.view(m, n).reshape(orig_shape).to(p.dtype), alpha=-lr + ) + + return loss diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 24440e19de..99e376fb5b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -42,6 +42,7 @@ get_zbl_model, ) from deepmd.pt.optimizer import ( + AdaMuonOptimizer, KFOptimizerWrapper, LKFOptimizer, ) @@ -162,7 +163,12 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, 
Any]]: "kf_limit_pref_e": params.get("kf_limit_pref_e", 1), "kf_start_pref_f": params.get("kf_start_pref_f", 1), "kf_limit_pref_f": params.get("kf_limit_pref_f", 1), + # Common parameters "weight_decay": params.get("weight_decay", 0.001), + # Muon/AdaMuon parameters + "momentum": params.get("momentum", 0.95), + "adam_beta1": params.get("adam_beta1", 0.9), + "adam_beta2": params.get("adam_beta2", 0.95), } return opt_type, opt_param @@ -698,6 +704,25 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.optimizer = LKFOptimizer( self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"] ) + elif self.opt_type == "AdaMuon": + self.optimizer = AdaMuonOptimizer( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + momentum=float(self.opt_param.get("momentum", 0.95)), + weight_decay=float(self.opt_param.get("weight_decay", 0.001)), + adam_betas=( + float(self.opt_param.get("adam_beta1", 0.9)), + float(self.opt_param.get("adam_beta2", 0.95)), + ), + lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), + lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), + ) + if optimizer_state_dict is not None and self.restart_training: + self.optimizer.load_state_dict(optimizer_state_dict) + self.scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, + lambda step: warm_up_linear(step + self.start_step, self.warmup_steps), + ) else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") @@ -768,7 +793,7 @@ def step(_step_id: int, task_key: str = "Default") -> None: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) fout1.flush() - if self.opt_type in ["Adam", "AdamW"]: + if self.opt_type in ["Adam", "AdamW", "AdaMuon"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 7fcc117ab5..dba503d520 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3370,6 +3370,64 @@ def training_args( [], optional=True, ), + Argument( + "AdaMuon", + dict, + [ + Argument( + "momentum", + float, + optional=True, + default=0.95, + doc=doc_only_pt_supported + + "Momentum coefficient for AdaMuon optimizer.", + ), + Argument( + "adam_beta1", + float, + optional=True, + default=0.9, + doc=doc_only_pt_supported + + "Adam beta1 coefficient for AdaMuon optimizer.", + ), + Argument( + "adam_beta2", + float, + optional=True, + default=0.95, + doc=doc_only_pt_supported + + "Adam beta2 coefficient for AdaMuon optimizer.", + ), + Argument( + "weight_decay", + float, + optional=True, + default=0.001, + doc=doc_only_pt_supported + + "Weight decay coefficient. Applied only to >=2D parameters (AdaMuon path).", + ), + Argument( + "lr_adjust", + float, + optional=True, + default=10.0, + doc=doc_only_pt_supported + + "Learning rate adjustment factor for Adam (1D params). " + "If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. " + "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1.0, m/n))), Adam uses lr/lr_adjust.", + ), + Argument( + "lr_adjust_coeff", + float, + optional=True, + default=0.2, + doc=doc_only_pt_supported + + "Coefficient for match-RMS scaling. 
Only effective when lr_adjust <= 0.", + ), + ], + [], + optional=True, + ), ], optional=True, default_tag="Adam", diff --git a/source/tests/pt/test_adamuon.py b/source/tests/pt/test_adamuon.py new file mode 100644 index 0000000000..4567833948 --- /dev/null +++ b/source/tests/pt/test_adamuon.py @@ -0,0 +1,387 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for AdaMuonOptimizer.""" + +import unittest + +# NOTE: avoid torch thread reconfiguration errors during import. +import torch + +torch_set_num_interop_threads = getattr(torch, "set_num_interop_threads", None) +torch_set_num_threads = getattr(torch, "set_num_threads", None) +if torch_set_num_interop_threads is not None: + torch.set_num_interop_threads = lambda *args, **kwargs: None # type: ignore[assignment] +if torch_set_num_threads is not None: + torch.set_num_threads = lambda *args, **kwargs: None # type: ignore[assignment] + +from deepmd.pt.optimizer.adamuon import ( + AdaMuonOptimizer, + zeropower_via_newtonschulz5, +) +from deepmd.pt.utils import ( + env, +) + + +class TestNewtonSchulzOrthogonalization(unittest.TestCase): + """Test Newton-Schulz orthogonalization algorithm for AdaMuon.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_square_matrix_approximate_orthogonality(self) -> None: + """Test that output is approximately orthogonal for square matrices.""" + torch.manual_seed(42) + G = torch.randn(4, 4, dtype=torch.float32, device=self.device) + X = zeropower_via_newtonschulz5(G, steps=5) + + # X @ X.T should be approximately identity (diagonal dominant) + # Note: NS returns bf16, so use relaxed tolerance + XXT = X.float() @ X.float().T + # Check diagonal elements are close to 1 (relaxed tolerance for bf16 + 5 iterations) + diag = torch.diag(XXT) + self.assertTrue( + torch.allclose( + diag, torch.ones(4, dtype=torch.float32, device=self.device), atol=0.5 + ), + f"Diagonal not close to 1: {diag}", + ) + # Check off-diagonal elements are relatively small + off_diag_norm = (XXT - torch.diag(diag)).norm() + self.assertLess( + off_diag_norm, 1.5, f"Off-diagonal norm too large: {off_diag_norm}" + ) + + def test_output_shape_preserved(self) -> None: + """Test that output shape matches input shape and dtype is bf16.""" + torch.manual_seed(42) + for shape in [(4, 4), (6, 4), (4, 6), (3, 4, 4)]: + G = torch.randn(*shape, dtype=torch.float32, device=self.device) + X = zeropower_via_newtonschulz5(G, steps=5) + self.assertEqual( + X.shape, G.shape, f"Shape mismatch for input shape {shape}" + ) + self.assertEqual( + X.dtype, torch.bfloat16, f"Output should be bf16, got {X.dtype}" + ) + + +class TestAdaMuonOptimizerBasic(unittest.TestCase): + """Test AdaMuonOptimizer class basic functionality.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_optimizer_step_smoke(self) -> None: + """Smoke test: step runs and updates both >=2D and 1D params.""" + torch.manual_seed(42) + # Model with 2D weights and 1D biases/LayerNorm params + model = torch.nn.Sequential( + torch.nn.Linear(10, 20, bias=True, device=self.device), + torch.nn.LayerNorm(20, device=self.device), + torch.nn.ReLU(), + torch.nn.Linear(20, 5, bias=True, device=self.device), + ) + + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02) + + # Dummy forward-backward pass + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + + # Store initial params + initial_params = [p.clone() for p in model.parameters()] + + # Optimizer step + optimizer.step() + + # Verify all parameters with gradients 
changed + for i, (p, init_p) in enumerate(zip(model.parameters(), initial_params)): + if p.grad is not None: + self.assertFalse( + torch.allclose(p, init_p), + f"Parameter {i} did not change after optimizer step", + ) + + def test_adamuon_for_2d_adam_for_1d(self) -> None: + """Test that AdaMuon is applied to 2D params and Adam to 1D params.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02) + + # Forward-backward + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Check state - weight (2D) should have momentum_buffer and v_buffer + weight_state = optimizer.state[model.weight] + self.assertIn("momentum_buffer", weight_state) + self.assertNotIn("exp_avg", weight_state) + + # Bias (1D) should have exp_avg and exp_avg_sq + bias_state = optimizer.state[model.bias] + self.assertIn("exp_avg", bias_state) + self.assertIn("exp_avg_sq", bias_state) + self.assertNotIn("momentum_buffer", bias_state) + + +class TestAdaMuonOptimizerState(unittest.TestCase): + """Test AdaMuonOptimizer state creation and management.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_2d_state_creation(self) -> None: + """State creation test: verify momentum_buffer and v_buffer for 2D path.""" + torch.manual_seed(42) + model = torch.nn.Linear(8, 16, bias=False, device=self.device) + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02) + + # Forward-backward + x = torch.randn(4, 8, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Check state for 2D weight + weight_state = optimizer.state[model.weight] + self.assertIn("momentum_buffer", weight_state) + self.assertIn("v_buffer", weight_state) + + # momentum_buffer shape should match grad shape + self.assertEqual( + weight_state["momentum_buffer"].shape, + model.weight.shape, + "momentum_buffer shape should match weight shape", + ) + + # v_buffer numel should equal reshaped matrix numel (m * n) + m, n = model.weight.shape[0], model.weight.numel() // model.weight.shape[0] + self.assertEqual( + weight_state["v_buffer"].numel(), + m * n, + "v_buffer numel should equal reshaped matrix numel", + ) + + def test_1d_state_fp32(self) -> None: + """Verify 1D Adam path uses FP32 state tensors.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02) + + # Forward-backward + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Bias state should be FP32 + bias_state = optimizer.state[model.bias] + self.assertEqual( + bias_state["exp_avg"].dtype, + torch.float32, + "exp_avg should be FP32", + ) + self.assertEqual( + bias_state["exp_avg_sq"].dtype, + torch.float32, + "exp_avg_sq should be FP32", + ) + + +class TestAdaMuonOptimizerBucketing(unittest.TestCase): + """Test bucketed batch Newton-Schulz processing.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_bucketed_path(self) -> None: + """Test that two layers with same weight shape land in the same bucket.""" + torch.manual_seed(42) + # Create two Linear layers with SAME weight shape (32, 16) + # Use parallel structure to avoid dimension mismatch + layer1 = torch.nn.Linear(16, 32, bias=False, device=self.device) + layer2 = torch.nn.Linear(16, 32, bias=False, device=self.device) + + optimizer = AdaMuonOptimizer([layer1.weight, layer2.weight], 
lr=0.02) + + # Store initial weights + weight1_before = layer1.weight.clone() + weight2_before = layer2.weight.clone() + + # Forward-backward with same input for both layers (parallel) + x = torch.randn(4, 16, device=self.device) + y = layer1(x) + layer2(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Both weights should have changed + self.assertFalse( + torch.allclose(layer1.weight, weight1_before), + "Layer1 weight should change after optimizer step", + ) + self.assertFalse( + torch.allclose(layer2.weight, weight2_before), + "Layer2 weight should change after optimizer step", + ) + + # Both should have v_buffer in state + self.assertIn("v_buffer", optimizer.state[layer1.weight]) + self.assertIn("v_buffer", optimizer.state[layer2.weight]) + + +class TestAdaMuonOptimizerLrAdjust(unittest.TestCase): + """Test lr_adjust behavior.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_lr_adjust_modes_differ(self) -> None: + """Test that lr_adjust <= 0 (match-RMS) and > 0 (rectangular) produce different updates.""" + torch.manual_seed(42) + + # Create two identical models + model1 = torch.nn.Linear(16, 32, bias=False, device=self.device) + model2 = torch.nn.Linear(16, 32, bias=False, device=self.device) + model2.load_state_dict(model1.state_dict()) + + # Two optimizers with different lr_adjust + opt1 = AdaMuonOptimizer( + model1.parameters(), lr=0.02, lr_adjust=-1.0 + ) # match-RMS + opt2 = AdaMuonOptimizer( + model2.parameters(), lr=0.02, lr_adjust=10.0 + ) # rectangular + + # Same input for both + torch.manual_seed(123) + x = torch.randn(4, 16, device=self.device) + + # Forward-backward for model1 + y1 = model1(x) + loss1 = y1.sum() + loss1.backward() + opt1.step() + + # Reset seed and run for model2 + torch.manual_seed(123) + x = torch.randn(4, 16, device=self.device) + y2 = model2(x) + loss2 = y2.sum() + loss2.backward() + opt2.step() + + # Updates should be different (not equal) due to different scaling + self.assertFalse( + torch.allclose(model1.weight, model2.weight), + "Different lr_adjust modes should produce different updates", + ) + + +class TestAdaMuonOptimizerWeightDecay(unittest.TestCase): + """Test weight decay application.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_weight_decay_only(self) -> None: + """Test decoupled weight decay scales weights when gradients are zero.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, bias=False, device=self.device) + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1) + + w_before = model.weight.detach().clone() + + # === Step 1. Make zero gradients === + model.weight.grad = torch.zeros_like(model.weight) + + # === Step 2. Step once === + optimizer.step() + + # === Step 3. 
Expect pure multiplicative decay: w <- (1 - lr*wd) * w === + expected = w_before * (1.0 - 0.02 * 0.1) + self.assertTrue( + torch.allclose(model.weight, expected), + "Weight should be scaled by (1 - lr * weight_decay)", + ) + + +class TestAdaMuonOptimizerClosure(unittest.TestCase): + """Test optimizer with closure.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_closure(self) -> None: + """Test optimizer with closure returns loss.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 5, device=self.device) + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02) + + def closure(): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + return loss + + loss = optimizer.step(closure) + self.assertIsNotNone(loss) + + +class TestAdaMuonOptimizerStateDict(unittest.TestCase): + """Test optimizer state dict save/load.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_state_dict_save_load(self) -> None: + """Test saving and loading optimizer state.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = AdaMuonOptimizer(model.parameters(), lr=0.02) + + # Run a few steps to populate state + for _ in range(3): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + y = model(x) + loss = y.sum() + loss.backward() + optimizer.step() + + # Save state + state_dict = optimizer.state_dict() + + # Create new optimizer and load state + optimizer2 = AdaMuonOptimizer(model.parameters(), lr=0.02) + optimizer2.load_state_dict(state_dict) + + # Verify state matches using param_groups as anchor + params1 = list(optimizer.param_groups[0]["params"]) + params2 = list(optimizer2.param_groups[0]["params"]) + + for p1, p2 in zip(params1, params2): + s1 = optimizer.state[p1] + s2 = optimizer2.state[p2] + self.assertEqual(set(s1.keys()), set(s2.keys())) + for key in s1: + if isinstance(s1[key], torch.Tensor): + self.assertTrue( + torch.allclose(s1[key], s2[key]), + f"State mismatch for key {key}", + ) + else: + self.assertEqual(s1[key], s2[key], f"State mismatch for key {key}") + + +if __name__ == "__main__": + unittest.main()
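Reviewer note (not part of the diff): the sketch below exercises the public AdaMuonOptimizer API added in deepmd/pt/optimizer/adamuon.py on a toy model, with a hypothetical linear-warmup LambdaLR standing in for the trainer's warm_up_linear. The optimizer keyword arguments are exactly those defined in this PR; the model, data, and schedule are illustrative only.

# Minimal, self-contained usage sketch for the new optimizer.
# Assumptions: toy model and a hypothetical warmup lambda, not the trainer's schedule.
import torch

from deepmd.pt.optimizer import AdaMuonOptimizer

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),  # 2D weights -> AdaMuon path, 1D biases -> Adam path
    torch.nn.ReLU(),
    torch.nn.Linear(32, 4),
)
optimizer = AdaMuonOptimizer(
    model.parameters(),
    lr=1e-3,
    momentum=0.95,            # AdaMuon momentum (>=2D parameters)
    weight_decay=0.001,       # decoupled decay, applied only to >=2D parameters
    adam_betas=(0.9, 0.95),   # auxiliary Adam for 1D parameters
    lr_adjust=10.0,           # Adam lr = lr / 10; rectangular correction for AdaMuon
)
warmup_steps = 100
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: min(1.0, (step + 1) / warmup_steps),  # hypothetical warmup
)

for _ in range(5):
    optimizer.zero_grad()
    x = torch.randn(8, 16)
    loss = model(x).square().mean()
    loss.backward()
    optimizer.step()
    scheduler.step()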
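Reviewer note (not part of the diff): for completeness, a hypothetical fragment of the "training" section of input.json selecting the new optimizer, written here as a Python dict. The key names match the argcheck additions in this PR; the surrounding mandatory training keys are omitted, the exact nesting of the opt_type variant should be checked against the full training_args schema, and the learning rate itself still comes from the learning_rate section (start_lr), as wired in training.py.

# Hypothetical optimizer-related keys for the "training" section of input.json,
# shown as a Python dict; other mandatory training keys are omitted.
adamuon_training_keys = {
    "opt_type": "AdaMuon",
    "momentum": 0.95,        # AdaMuon momentum for >=2D parameters
    "adam_beta1": 0.9,       # auxiliary Adam beta1 (1D parameters)
    "adam_beta2": 0.95,      # auxiliary Adam beta2 (1D parameters)
    "weight_decay": 0.001,   # decoupled decay, >=2D parameters only
    "lr_adjust": 10.0,       # Adam lr = lr / lr_adjust; rectangular correction for AdaMuon
    "lr_adjust_coeff": 0.2,  # only used when lr_adjust <= 0 (match-RMS mode)
}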