diff --git a/deepmd/pt/optimizer/__init__.py b/deepmd/pt/optimizer/__init__.py index 4c069cf2ea..1899f27fff 100644 --- a/deepmd/pt/optimizer/__init__.py +++ b/deepmd/pt/optimizer/__init__.py @@ -2,6 +2,9 @@ from .adamuon import ( AdaMuonOptimizer, ) +from .hybrid_muon import ( + HybridMuonOptimizer, +) from .KFWrapper import ( KFOptimizerWrapper, ) @@ -9,4 +12,9 @@ LKFOptimizer, ) -__all__ = ["AdaMuonOptimizer", "KFOptimizerWrapper", "LKFOptimizer"] +__all__ = [ + "AdaMuonOptimizer", + "HybridMuonOptimizer", + "KFOptimizerWrapper", + "LKFOptimizer", +] diff --git a/deepmd/pt/optimizer/hybrid_muon.py b/deepmd/pt/optimizer/hybrid_muon.py new file mode 100644 index 0000000000..abf4d3a572 --- /dev/null +++ b/deepmd/pt/optimizer/hybrid_muon.py @@ -0,0 +1,725 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +""" +HybridMuon optimizer for DeePMD-kit PyTorch backend. + +HybridMuon is a HYBRID optimizer that automatically combines Muon and Adam: +- For >=2D parameters with min(m,n) >= min_2d_dim: Muon update with Newton-Schulz +- For 2D parameters with min(m,n) < min_2d_dim: Adam fallback with update clipping +- For 1D parameters (biases, layer norms): Standard Adam + +This is different from PyTorch's torch.optim.Muon, which ONLY supports 2D parameters +and requires manual configuration of AdamW for 1D parameters. HybridMuon provides +automatic routing based on parameter dimensionality. + +Algorithm +--------- +For >=2D parameters (weight matrices), the Muon update is: + + 1. Momentum update (Nesterov): + m_t = beta * m_{t-1} + (1 - beta) * g_t + update = beta * m_t + (1 - beta) * g_t + + 2. Newton-Schulz orthogonalization (quintic iteration): + X_0 = G / ||G||_F + X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T + Coefficients: a=3.4445, b=-4.7750, c=2.0315 + + 3. Scaling: scale = coeff * sqrt(max(m, n)) [match-RMS mode] + scale = sqrt(max(1, m/n)) [rectangular mode] + + 4. Parameter update: theta -= lr * scale * orth(update) + +For 1D parameters (biases, norms), standard Adam is used. + +Dtype Behavior +-------------- +- Newton-Schulz iterations: always bfloat16 (matches official Muon) +- NS output (bfloat16) directly applied to parameters (PyTorch handles mixed precision) +- Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability +- Muon momentum buffer: follows gradient dtype (grad -> buffer -> update) +- Adam gradients: cast to float32 for update computation + +References +---------- +.. [1] Keller Jordan, "Muon: An optimizer for hidden layers in neural networks." + https://kellerjordan.github.io/posts/muon/ + https://github.com/KellerJordan/Muon +.. [2] Moonshot team, "Muon is Scalable for LLM Training," arXiv:2502.16982, 2025. + https://arxiv.org/abs/2502.16982 +.. [3] Moonlight GitHub Repository. 
+ https://github.com/MoonshotAI/Moonlight +""" + +from __future__ import ( + annotations, +) + +import math +from typing import ( + TYPE_CHECKING, + Any, +) + +import torch +from torch.optim.optimizer import ( + Optimizer, +) + +if TYPE_CHECKING: + from collections.abc import ( + Iterable, + ) + +# ============================================================================ +# Constants +# ============================================================================ + +# Newton-Schulz iteration count +NS_STEPS: int = 5 +# Numerical stability epsilon for norm clamping and Adam +EPS: float = 1e-7 +# Quintic Newton-Schulz polynomial coefficients +NS_COEFF_A: float = 3.4445 +NS_COEFF_B: float = -4.7750 +NS_COEFF_C: float = 2.0315 + + +def _maybe_compile( + fn: callable, +) -> callable: + """Compile a function if torch.compile is available.""" + if not hasattr(torch, "compile"): + return fn + # Skip compile if default device is CUDA but CUDA is unavailable. + if hasattr(torch, "get_default_device"): + default_device = torch.get_default_device() + if default_device.type == "cuda" and not torch.cuda.is_available(): + return fn + return torch.compile(fn, fullgraph=True, dynamic=True) + + +@_maybe_compile +def _zeropower_via_newtonschulz5_2d( + G: torch.Tensor, +) -> torch.Tensor: + """ + Orthogonalize a 2D matrix via quintic Newton-Schulz iteration. + + Mathematical formulation: + X_0 = G / ||G||_F + X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T + Coefficients: a=3.4445, b=-4.7750, c=2.0315 + """ + # === Step 1. Cast to bf16 and transpose tall matrices === + X = G.to(dtype=torch.bfloat16) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.transpose(-2, -1) + + # === Step 2. Normalize Frobenius norm to at most 1 === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) + + # === Step 3. Newton-Schulz iterations with fused GEMM === + for _ in range(NS_STEPS): + A = torch.mm(X, X.transpose(-2, -1)) + gram_update = torch.addmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C) + X = torch.addmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0) + + # === Step 4. Transpose back if needed === + if transposed: + X = X.transpose(-2, -1) + + return X + + +@_maybe_compile +def _zeropower_via_newtonschulz5_3d( + G: torch.Tensor, +) -> torch.Tensor: + """ + Orthogonalize a 3D batch of matrices via quintic Newton-Schulz iteration. + + Mathematical formulation: + X_0 = G / ||G||_F + X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T + Coefficients: a=3.4445, b=-4.7750, c=2.0315 + """ + # === Step 1. Cast to bf16 and transpose tall matrices === + X = G.to(dtype=torch.bfloat16) + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.transpose(-2, -1) + + # === Step 2. Normalize Frobenius norm to at most 1 === + X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS) + + # === Step 3. Newton-Schulz iterations with batched fused GEMM === + for _ in range(NS_STEPS): + A = torch.bmm(X, X.transpose(-2, -1)) + gram_update = torch.baddbmm(A, A, A, beta=NS_COEFF_B, alpha=NS_COEFF_C) + X = torch.baddbmm(X, gram_update, X, beta=NS_COEFF_A, alpha=1.0) + + # === Step 4. Transpose back if needed === + if transposed: + X = X.transpose(-2, -1) + + return X + + +def zeropower_via_newtonschulz5( + G: torch.Tensor, +) -> torch.Tensor: + """ + Compute the zeroth power (orthogonalization) via Newton-Schulz iteration. + + Dispatches to compiled 2D or 3D kernels for best performance. 
+ + Parameters + ---------- + G : torch.Tensor + Input matrix with shape (M, N) or batched input with shape (B, M, N). + + Returns + ------- + torch.Tensor + Orthogonalized tensor in bfloat16 with same shape as input. + + Raises + ------ + ValueError + If input is not 2D or 3D. + """ + if G.ndim == 2: + return _zeropower_via_newtonschulz5_2d(G) + if G.ndim == 3: + return _zeropower_via_newtonschulz5_3d(G) + raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.") + + +def should_fallback_to_adam_for_matrix( + p: torch.Tensor, + min_2d_dim: int, +) -> bool: + """ + Check if a 2D matrix should fallback to Adam due to small dimensions. + + Parameters + ---------- + p : torch.Tensor + Parameter tensor with ndim >= 2. + min_2d_dim : int + Minimum min(m, n) threshold for Muon. Matrices with min(m, n) >= + min_2d_dim use Muon; those with min(m, n) < min_2d_dim use Adam. + + Returns + ------- + bool + True if min(m, n) < min_2d_dim, False otherwise. + + Raises + ------ + ValueError + If tensor has ndim < 2. + """ + # === Step 1. Validate === + if p.ndim < 2: + raise ValueError("Parameter must have ndim >= 2 for Muon suitability check.") + + # === Step 2. Derive matrix shape consistent with Muon reshape === + m = int(p.shape[0]) + n = int(p.numel() // p.shape[0]) + + # === Step 3. Check if any dimension too small for Muon === + return min(m, n) < min_2d_dim + + +class HybridMuonOptimizer(Optimizer): + """ + HybridMuon optimizer with small-2D Adam fallback and 1D Adam path. + + This optimizer applies different update rules based on parameter dimensionality: + - For >=2D parameters with min(m, n) >= min_2d_dim: + Muon update with Newton-Schulz orthogonalization. + - For 2D parameters with min(m, n) < min_2d_dim (small matrices): + Adam update with scaled learning rate and update clipping. + - For 1D parameters (biases, layer norms): + Standard Adam update. + + This hybrid approach is effective because Muon's orthogonalization is designed + for weight matrices, while Adam is more suitable for biases and normalization params. + + Update Rules + ------------ + Muon (>=2D params): + 1. Momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t + 2. Nesterov lookahead: update = beta*m_t + (1-beta)*g_t + 3. Newton-Schulz orthogonalization: orth = NS(update) + 4. Scaling: scale = coeff*sqrt(max(m,n)) or sqrt(max(1, m/n)) + 5. Parameter update: theta -= lr * scale * orth + + Adam (1D params): + Standard Adam with bias correction, all computations in float32. + + Parameters + ---------- + params : iterable + Iterable of parameters to optimize. + lr : float + Learning rate with default 1e-3. + momentum : float + Momentum coefficient for Muon with default 0.95. + weight_decay : float + Weight decay coefficient (applied only to Muon-routed parameters) with default 0.001. + adam_betas : tuple[float, float] + Adam beta coefficients with default (0.9, 0.95). + lr_adjust : float + Learning rate adjustment mode for Muon scaling and Adam learning rate. + - If lr_adjust <= 0: use match-RMS scaling for Muon, + scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly. + - If lr_adjust > 0: use rectangular correction for Muon, + scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust. + Default is 10.0 (Adam lr = lr/10). + lr_adjust_coeff : float + Dual-purpose coefficient with default 0.2: + 1. For Muon (when lr_adjust <= 0): match-RMS scaling factor, + scale = lr_adjust_coeff * sqrt(max(m, n)). + 2. For 2D Adam fallback: learning rate multiplier, + adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1). 
+ The min(., 0.1) cap ensures conservative updates for small matrices. + muon_2d_only : bool + If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). + Parameters with ndim > 2 use Adam without weight decay. + If False, all >=2D parameters use Muon (default behavior). + Default is True. + min_2d_dim : int + Minimum min(m, n) threshold for Muon on 2D matrices. + Matrices with min(m, n) >= min_2d_dim use Muon; + those with min(m, n) < min_2d_dim use Adam fallback. + Must be >= 1. + Set to 1 to disable fallback. + Default is 1. + + Examples + -------- + >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=1e-3) + >>> for epoch in range(epochs): + ... optimizer.zero_grad() + ... loss.backward() + ... optimizer.step() + """ + + def __init__( + self, + params: Iterable[torch.Tensor] | Iterable[dict[str, Any]], + lr: float = 1e-3, + momentum: float = 0.95, + weight_decay: float = 0.001, + adam_betas: tuple[float, float] = (0.9, 0.95), + lr_adjust: float = 10.0, + lr_adjust_coeff: float = 0.2, + muon_2d_only: bool = True, + min_2d_dim: int = 1, + ) -> None: + if min_2d_dim < 1: + raise ValueError("min_2d_dim must be >= 1.") + + defaults = { + "lr": lr, + "momentum": momentum, + "weight_decay": weight_decay, + "adam_betas": adam_betas, + "lr_adjust": lr_adjust, + "lr_adjust_coeff": lr_adjust_coeff, + "muon_2d_only": muon_2d_only, + "min_2d_dim": min_2d_dim, + } + super().__init__(params, defaults) + # Static parameter routing: built once on first step() call. + self._routing_built = False + self._routing: list[dict[str, Any]] = [] + + def _build_param_routing(self) -> None: + """ + Classify parameters into Muon and Adam routes (static routing). + + Routing logic: + - 1D parameters → Adam path + - >2D parameters (when muon_2d_only=True) → Adam path + - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path + - 2D parameters with min(m, n) >= min_2d_dim → Muon path + - >=2D parameters (when muon_2d_only=False) → Muon path + """ + if self._routing_built: + return + + self._routing = [] + for group in self.param_groups: + muon_params: list[dict[str, Any]] = [] + adam_1d: list[dict[str, Any]] = [] + adam_matrix: list[dict[str, Any]] = [] + adam_nd: list[dict[str, Any]] = [] + + min_2d_dim = group["min_2d_dim"] + muon_2d_only = group["muon_2d_only"] + + for p in group["params"]: + # === Step 1. 1D parameters → Adam === + if p.ndim < 2: + adam_1d.append({"param": p}) + continue + + # === Step 2. >2D parameters (when muon_2d_only=True) → Adam === + if muon_2d_only and p.ndim > 2: + adam_nd.append({"param": p}) + continue + + # === Step 3. 2D small matrices → Adam fallback === + if (p.ndim == 2) and should_fallback_to_adam_for_matrix( + p, min_2d_dim=min_2d_dim + ): + adam_matrix.append( + { + "param": p, + "abs_floor": 1e-3 * math.sqrt(float(p.numel())), + } + ) + continue + + # === Step 4. >=2D (or 2D only when muon_2d_only=True) → Muon === + muon_params.append( + { + "param": p, + "rows": int(p.shape[0]), + "cols": int(p.numel() // p.shape[0]), + } + ) + + self._routing.append( + { + "muon_params": muon_params, + "adam_1d": adam_1d, + "adam_matrix": adam_matrix, + "adam_nd": adam_nd, + } + ) + + self._routing_built = True + + @torch.no_grad() + def step( + self, + closure: callable | None = None, + ) -> torch.Tensor | None: + """ + Perform a single optimization step. + + Parameters + ---------- + closure : callable, optional + A closure that reevaluates the model and returns the loss. 
+ + Returns + ------- + torch.Tensor | None + The loss value if closure is provided, otherwise None. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # Build static parameter routing on first call. + self._build_param_routing() + + for group_idx, group in enumerate(self.param_groups): + route = self._routing[group_idx] + lr = group["lr"] + momentum = group["momentum"] + weight_decay = group["weight_decay"] + adam_betas = group["adam_betas"] + lr_adjust = group["lr_adjust"] + lr_adjust_coeff = group["lr_adjust_coeff"] + + # === Step 1. Adam update for 1D parameters (biases, norms, etc.) === + # === Step 1.1. Collect gradients and initialize state === + adam_params: list[torch.Tensor] = [] + adam_grads_fp32: list[torch.Tensor] = [] + adam_exp_avgs: list[torch.Tensor] = [] + adam_exp_avg_sqs: list[torch.Tensor] = [] + adam_states: list[dict[str, Any]] = [] + + for entry in route["adam_1d"]: + p = entry["param"] + grad = p.grad + if grad is None: + continue + + grad_fp32 = grad.float() + + state = self.state[p] + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + + adam_params.append(p) + adam_grads_fp32.append(grad_fp32) + adam_exp_avgs.append(state["exp_avg"]) + adam_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_states.append(state) + + if adam_params: + # === Step 1.2. Update exp_avg / exp_avg_sq === + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + + # exp_avg = beta1 * exp_avg + (1 - beta1) * grad + # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 + torch._foreach_lerp_(adam_exp_avgs, adam_grads_fp32, 1 - adam_betas[0]) + grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32) + torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + + # === Step 1.3. Bias correction and parameter update === + for i, p in enumerate(adam_params): + state = adam_states[i] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr / bias_corr1 + # delta = -step_size * m_hat / (sqrt(v_hat) + eps) + denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) + delta_fp32 = -step_size * (adam_exp_avgs[i] / denom) + p.add_(delta_fp32.to(p.dtype)) + + # === Step 2. Adam update for >2D parameters (when muon_2d_only=True) === + # === Step 2.1. Collect gradients and initialize state === + adam_nd_params: list[torch.Tensor] = [] + adam_nd_grads_fp32: list[torch.Tensor] = [] + adam_nd_exp_avgs: list[torch.Tensor] = [] + adam_nd_exp_avg_sqs: list[torch.Tensor] = [] + adam_nd_states: list[dict[str, Any]] = [] + + for entry in route.get("adam_nd", []): + p = entry["param"] + grad = p.grad + if grad is None: + continue + + grad_fp32 = grad.float() + + state = self.state[p] + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + + adam_nd_params.append(p) + adam_nd_grads_fp32.append(grad_fp32) + adam_nd_exp_avgs.append(state["exp_avg"]) + adam_nd_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_nd_states.append(state) + + if adam_nd_params: + # === Step 2.2. 
Update exp_avg / exp_avg_sq === + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + + # exp_avg = beta1 * exp_avg + (1 - beta1) * grad + # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 + torch._foreach_lerp_( + adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0] + ) + grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32) + torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1]) + + # === Step 2.3. Bias correction and parameter update === + for i, p in enumerate(adam_nd_params): + state = adam_nd_states[i] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr / bias_corr1 + # delta = -step_size * m_hat / (sqrt(v_hat) + eps) + denom = (adam_nd_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) + delta_fp32 = -step_size * (adam_nd_exp_avgs[i] / denom) + p.add_(delta_fp32.to(p.dtype)) + + # === Step 3. Adam update for small 2D matrices (fallback path) === + # === Step 3.1. Collect gradients and initialize state === + adam_matrix_params: list[torch.Tensor] = [] + adam_matrix_grads_fp32: list[torch.Tensor] = [] + adam_matrix_exp_avgs: list[torch.Tensor] = [] + adam_matrix_exp_avg_sqs: list[torch.Tensor] = [] + adam_matrix_states: list[dict[str, Any]] = [] + adam_matrix_abs_floor: list[float] = [] + + for entry in route["adam_matrix"]: + p = entry["param"] + grad = p.grad + if grad is None: + continue + + grad_fp32 = grad.float() + + state = self.state[p] + if "exp_avg" not in state: + state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32) + state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32) + state["beta1_pow"] = 1.0 + state["beta2_pow"] = 1.0 + + state["beta1_pow"] *= adam_betas[0] + state["beta2_pow"] *= adam_betas[1] + + adam_matrix_params.append(p) + adam_matrix_grads_fp32.append(grad_fp32) + adam_matrix_exp_avgs.append(state["exp_avg"]) + adam_matrix_exp_avg_sqs.append(state["exp_avg_sq"]) + adam_matrix_states.append(state) + adam_matrix_abs_floor.append(entry["abs_floor"]) + + if adam_matrix_params: + # === Step 3.2. Update exp_avg / exp_avg_sq with scaled lr === + adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust + adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1) + + # exp_avg = beta1 * exp_avg + (1 - beta1) * grad + # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2 + torch._foreach_lerp_( + adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0] + ) + grad_sq_m = torch._foreach_mul( + adam_matrix_grads_fp32, adam_matrix_grads_fp32 + ) + torch._foreach_lerp_( + adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1] + ) + + # === Step 3.3. Compute unclipped deltas === + raw_deltas: list[torch.Tensor] = [] + for i in range(len(adam_matrix_params)): + state = adam_matrix_states[i] + bias_corr1 = 1 - state["beta1_pow"] + bias_corr2 = 1 - state["beta2_pow"] + step_size = adam_lr_matrix / bias_corr1 + denom = (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS) + raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom)) + + # === Step 3.4. 
Clip updates by relative norm and apply === + max_rel_change = 0.05 + p_norms = torch.stack(torch._foreach_norm(adam_matrix_params)) + delta_norms = torch.stack(torch._foreach_norm(raw_deltas)) + floors = torch.tensor( + adam_matrix_abs_floor, + device=p_norms.device, + dtype=p_norms.dtype, + ) + max_delta = torch.maximum(max_rel_change * p_norms, floors) + scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0) + for i, (p, delta) in enumerate( + zip(adam_matrix_params, raw_deltas, strict=False) + ): + p.add_(delta.mul_(scales_tensor[i]).to(p.dtype)) + + # === Step 4. Muon update for >=2D parameters (weight matrices) === + # === Step 4.1. Collect gradients and initialize momentum === + muon_params_for_decay: list[torch.Tensor] = [] + muon_grads: list[torch.Tensor] = [] + muon_momentum_buffers: list[torch.Tensor] = [] + active_entries: list[tuple[dict[str, Any], torch.Tensor]] = [] + + for entry in route["muon_params"]: + p = entry["param"] + grad = p.grad + if grad is None: + continue + + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + + buf = state["momentum_buffer"] + if grad.dtype != buf.dtype: + grad = grad.to(dtype=buf.dtype) + + muon_params_for_decay.append(p) + muon_grads.append(grad) + muon_momentum_buffers.append(buf) + active_entries.append((entry, grad)) + + # === Step 4.2. Apply weight decay (Muon path only) === + if weight_decay > 0 and muon_params_for_decay: + torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay) + + if not active_entries: + continue + + # === Step 4.3. Momentum update (Nesterov) === + # m_t = beta * m_{t-1} + (1 - beta) * g_t + torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum) + # update = beta * m_t + (1 - beta) * g_t + muon_updates = torch._foreach_lerp( + muon_grads, muon_momentum_buffers, momentum + ) + + # === Step 4.4. Bucket by shape/device/dtype for batched NS === + buckets: dict[ + tuple[int, int, torch.device, torch.dtype], + list[tuple[dict[str, Any], torch.Tensor]], + ] = {} + + for idx, entry_info in enumerate(active_entries): + entry, _ = entry_info + p = entry["param"] + bucket_key = (entry["rows"], entry["cols"], p.device, p.dtype) + if bucket_key not in buckets: + buckets[bucket_key] = [] + buckets[bucket_key].append((entry, muon_updates[idx])) + + # === Step 4.5. 
Newton-Schulz orthogonalization and update === + for (rows, cols, _device, _), bucket_entries in buckets.items(): + # scale = coeff * sqrt(max(m, n)) [match-RMS mode] + # scale = sqrt(max(1, m/n)) [rectangular mode] + if lr_adjust <= 0: + scale = lr_adjust_coeff * math.sqrt(float(max(rows, cols))) + else: + scale = max(1.0, rows / cols) ** 0.5 + + if len(bucket_entries) == 1: + entry, update_tensor = bucket_entries[0] + update_matrix = update_tensor.reshape(rows, cols) + if not update_matrix.is_contiguous(): + update_matrix = update_matrix.contiguous() + + orth = _zeropower_via_newtonschulz5_2d(update_matrix) + orth.mul_(scale) + delta = orth.reshape(entry["param"].shape) + entry["param"].add_(delta, alpha=-lr) + continue + + matrices: list[torch.Tensor] = [] + params: list[torch.Tensor] = [] + orig_shapes: list[tuple[int, ...]] = [] + + for entry, update_tensor in bucket_entries: + update_matrix = update_tensor.reshape(rows, cols) + matrices.append( + update_matrix + if update_matrix.is_contiguous() + else update_matrix.contiguous() + ) + params.append(entry["param"]) + orig_shapes.append(entry["param"].shape) + + stacked = torch.stack(matrices, dim=0) + orth = _zeropower_via_newtonschulz5_3d(stacked) + orth.mul_(scale) + + for i, _ in enumerate(bucket_entries): + params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr) + + return loss diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 0dfbe94b6b..20497a0ceb 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -43,6 +43,7 @@ ) from deepmd.pt.optimizer import ( AdaMuonOptimizer, + HybridMuonOptimizer, KFOptimizerWrapper, LKFOptimizer, ) @@ -158,6 +159,7 @@ def __init__( def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: opt_type = params.get("opt_type", "Adam") opt_param = { + # LKF parameters "kf_blocksize": params.get("kf_blocksize", 5120), "kf_start_pref_e": params.get("kf_start_pref_e", 1), "kf_limit_pref_e": params.get("kf_limit_pref_e", 1), @@ -169,6 +171,10 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]: "momentum": params.get("momentum", 0.95), "adam_beta1": params.get("adam_beta1", 0.9), "adam_beta2": params.get("adam_beta2", 0.95), + "lr_adjust": params.get("lr_adjust", 10.0), + "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2), + "muon_2d_only": params.get("muon_2d_only", True), + "min_2d_dim": params.get("min_2d_dim", 1), } return opt_type, opt_param @@ -648,8 +654,7 @@ def single_model_finetune( missing, unexpected = self.model.load_state_dict(state, strict=False) if missing or unexpected: log.warning( - "Checkpoint loaded non-strictly. " - f"Missing keys: {missing}, Unexpected keys: {unexpected}" + f"Checkpoint loaded non-strictly. 
Missing keys: {missing}, Unexpected keys: {unexpected}" ) # Get model prob for multi-task @@ -735,14 +740,29 @@ def warm_up_linear(step: int, warmup_steps: int) -> float: self.optimizer = AdaMuonOptimizer( self.wrapper.parameters(), lr=self.lr_exp.start_lr, - momentum=float(self.opt_param.get("momentum", 0.95)), - weight_decay=float(self.opt_param.get("weight_decay", 0.001)), + momentum=float(self.opt_param["momentum"]), + weight_decay=float(self.opt_param["weight_decay"]), adam_betas=( - float(self.opt_param.get("adam_beta1", 0.9)), - float(self.opt_param.get("adam_beta2", 0.95)), + float(self.opt_param["adam_beta1"]), + float(self.opt_param["adam_beta2"]), ), - lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)), - lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)), + lr_adjust=float(self.opt_param["lr_adjust"]), + lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), + ) + elif self.opt_type == "HybridMuon": + self.optimizer = HybridMuonOptimizer( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + momentum=float(self.opt_param["momentum"]), + weight_decay=float(self.opt_param["weight_decay"]), + adam_betas=( + float(self.opt_param["adam_beta1"]), + float(self.opt_param["adam_beta2"]), + ), + lr_adjust=float(self.opt_param["lr_adjust"]), + lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]), + muon_2d_only=bool(self.opt_param["muon_2d_only"]), + min_2d_dim=int(self.opt_param["min_2d_dim"]), ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict) @@ -820,7 +840,7 @@ def step(_step_id: int, task_key: str = "Default") -> None: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) fout1.flush() - if self.opt_type in ["Adam", "AdamW", "AdaMuon"]: + if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr @@ -1562,8 +1582,6 @@ def model_change_out_bias( model_type_map = _model.get_type_map() log.info( - f"Change output bias of {model_type_map!s} " - f"from {to_numpy_array(old_bias).reshape(-1)!s} " - f"to {to_numpy_array(new_bias).reshape(-1)!s}." + f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}." ) return _model diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 935762cdc7..8c20bb8bf4 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3425,6 +3425,7 @@ def training_args( float, optional=True, default=0.95, + alias=["muon_momentum"], doc=doc_only_pt_supported + "Momentum coefficient for AdaMuon optimizer.", ), @@ -3474,6 +3475,95 @@ def training_args( [], optional=True, ), + Argument( + "HybridMuon", + dict, + [ + Argument( + "momentum", + float, + optional=True, + default=0.95, + alias=["muon_momentum"], + doc=doc_only_pt_supported + + "Momentum coefficient for HybridMuon optimizer (>=2D params). 
" + "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.", + ), + Argument( + "adam_beta1", + float, + optional=True, + default=0.9, + doc=doc_only_pt_supported + + "Adam beta1 coefficient for 1D parameters (biases, norms).", + ), + Argument( + "adam_beta2", + float, + optional=True, + default=0.95, + doc=doc_only_pt_supported + + "Adam beta2 coefficient for 1D parameters (biases, norms).", + ), + Argument( + "weight_decay", + float, + optional=True, + default=0.001, + doc=doc_only_pt_supported + + "Weight decay coefficient. Applied only to Muon-routed parameters", + ), + Argument( + "lr_adjust", + float, + optional=True, + default=10.0, + doc=doc_only_pt_supported + + "Learning rate adjustment mode for HybridMuon scaling and Adam learning rate. " + "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. " + "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. " + "Default is 10.0 (Adam lr = lr/10).", + ), + Argument( + "lr_adjust_coeff", + float, + optional=True, + default=0.2, + doc=doc_only_pt_supported + + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.", + ), + Argument( + "muon_2d_only", + bool, + optional=True, + default=True, + doc=doc_only_pt_supported + + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). " + + "Parameters with ndim > 2 use Adam without weight decay. " + + "If False, all >=2D parameters use Muon.", + ), + Argument( + "min_2d_dim", + int, + optional=True, + default=1, + alias=["muon_min_2d_dim"], + doc=doc_only_pt_supported + + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. " + "Matrices with min(m, n) >= min_2d_dim use HybridMuon; " + "those with min(m, n) < min_2d_dim use Adam fallback. " + "Set to 1 to disable fallback.", + ), + ], + [], + optional=True, + doc=doc_only_pt_supported + + "HybridMuon optimizer (DeePMD-kit custom implementation). " + + "This is a Hybrid optimizer that automatically combines Muon and Adam. " + + "For >=2D params: Muon update with Newton-Schulz. " + + "For 1D params: Standard Adam. 
" + + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.", + ), ], optional=True, default_tag="Adam", diff --git a/source/tests/pt/test_hybrid_muon.py b/source/tests/pt/test_hybrid_muon.py new file mode 100644 index 0000000000..77973c5728 --- /dev/null +++ b/source/tests/pt/test_hybrid_muon.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +from deepmd.pt.optimizer.hybrid_muon import ( + HybridMuonOptimizer, + zeropower_via_newtonschulz5, +) +from deepmd.pt.utils import ( + env, +) + + +def _bf16_matmul_supported(device: torch.device) -> bool: + """Check if bf16 matmul is reliably supported on the given device.""" + if device.type == "cuda": + if not torch.cuda.is_available(): + return False + # bf16 requires compute capability >= 8.0 (Ampere+) for native support + # or >= 7.0 (Volta) with tensor cores, but may have precision issues + if hasattr(torch.cuda, "is_bf16_supported"): + return torch.cuda.is_bf16_supported() + # Fallback: check compute capability directly + cap = torch.cuda.get_device_capability(device) + return cap[0] >= 8 + # CPU bf16 support: available on x86 with AVX-512 BF16 or ARM with BF16 extension + # Since it's hard to detect reliably, try a small matmul and check for errors + try: + a = torch.randn(4, 4, dtype=torch.bfloat16, device=device) + _ = torch.mm(a, a.T) + return True + except (RuntimeError, TypeError): + return False + + +BF16_SUPPORTED = _bf16_matmul_supported(env.DEVICE) + + +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") +class TestNewtonSchulzOrthogonalization(unittest.TestCase): + """Test Newton-Schulz orthogonalization algorithm.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_orthogonalization(self) -> None: + """Test that NS produces approximately orthogonal output.""" + torch.manual_seed(42) + G = torch.randn(4, 4, dtype=torch.float32, device=self.device) + X = zeropower_via_newtonschulz5(G) + + # X @ X.T should be approximately identity + # Note: NS uses bf16 internally, 5 iterations gives ~0.1-0.3 error + XXT = X.float() @ X.float().T + diag = torch.diag(XXT) + self.assertTrue( + torch.allclose( + diag, torch.ones(4, dtype=torch.float32, device=self.device), atol=0.5 + ), + f"Diagonal not close to 1: {diag}", + ) + off_diag_norm = (XXT - torch.diag(diag)).norm() + self.assertLess( + off_diag_norm, 1.5, f"Off-diagonal norm too large: {off_diag_norm}" + ) + + def test_shape_and_dtype(self) -> None: + """Test that output preserves shape and returns bf16.""" + torch.manual_seed(42) + for shape in [(4, 4), (6, 4), (3, 4, 4)]: + G = torch.randn(*shape, dtype=torch.float32, device=self.device) + X = zeropower_via_newtonschulz5(G) + self.assertEqual(X.shape, G.shape) + self.assertEqual(X.dtype, torch.bfloat16) + + def test_invalid_input(self) -> None: + """Test that <2D input raises ValueError.""" + G_1d = torch.randn(10, dtype=torch.float32, device=self.device) + with self.assertRaises(ValueError): + zeropower_via_newtonschulz5(G_1d) + + +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") +class TestHybridMuonOptimizer(unittest.TestCase): + """Test HybridMuonOptimizer class.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_step(self) -> None: + """Test basic optimizer step changes parameters.""" + torch.manual_seed(42) + model = torch.nn.Sequential( + torch.nn.Linear(10, 20, device=self.device), + torch.nn.ReLU(), + torch.nn.Linear(20, 5, device=self.device), + 
) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) + + x = torch.randn(4, 10, device=self.device) + model(x).sum().backward() + + initial_params = [p.clone() for p in model.parameters()] + optimizer.step() + + for i, (p, init_p) in enumerate( + zip(model.parameters(), initial_params, strict=True) + ): + self.assertFalse(torch.allclose(p, init_p), f"Parameter {i} did not change") + + def test_weight_decay(self) -> None: + """Test weight decay reduces parameter norm.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1) + + initial_norm = model.weight.norm().item() + for _ in range(10): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + model(x).sum().backward() + optimizer.step() + + self.assertLess(model.weight.norm().item(), initial_norm) + + def test_muon_adam_separation(self) -> None: + """Test Muon for 2D params, Adam for 1D params.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) + + x = torch.randn(4, 10, device=self.device) + model(x).sum().backward() + optimizer.step() + + # 2D weight uses Muon (momentum_buffer) + self.assertIn("momentum_buffer", optimizer.state[model.weight]) + self.assertNotIn("exp_avg", optimizer.state[model.weight]) + # 1D bias uses Adam (exp_avg, exp_avg_sq) + self.assertIn("exp_avg", optimizer.state[model.bias]) + self.assertIn("exp_avg_sq", optimizer.state[model.bias]) + self.assertNotIn("momentum_buffer", optimizer.state[model.bias]) + + def test_muon_adam_fallback_small_2d(self) -> None: + """Test Adam fallback for small 2D matrices when min_2d_dim is set.""" + torch.manual_seed(42) + linear_small = torch.nn.Linear(10, 1, bias=False, device=self.device) + linear_large = torch.nn.Linear(10, 10, bias=False, device=self.device) + optimizer = HybridMuonOptimizer( + list(linear_small.parameters()) + list(linear_large.parameters()), + lr=0.02, + min_2d_dim=2, + ) + + x = torch.randn(4, 10, device=self.device) + loss = linear_small(x).sum() + linear_large(x).sum() + loss.backward() + optimizer.step() + + # Small 2D weight should use Adam fallback. + self.assertIn("exp_avg", optimizer.state[linear_small.weight]) + self.assertNotIn("momentum_buffer", optimizer.state[linear_small.weight]) + + # Large 2D weight should use Muon. 
+ self.assertIn("momentum_buffer", optimizer.state[linear_large.weight]) + self.assertNotIn("exp_avg", optimizer.state[linear_large.weight]) + + def test_lr_adjust_modes(self) -> None: + """Test lr_adjust modes: match-RMS (<=0) vs rectangular (>0).""" + torch.manual_seed(42) + + model1 = torch.nn.Linear(10, 20, bias=False, device=self.device) + model2 = torch.nn.Linear(10, 20, bias=False, device=self.device) + model2.load_state_dict(model1.state_dict()) + + opt1 = HybridMuonOptimizer(model1.parameters(), lr=0.02, lr_adjust=0.0) + opt2 = HybridMuonOptimizer(model2.parameters(), lr=0.02, lr_adjust=10.0) + + x = torch.randn(4, 10, device=self.device) + + opt1.zero_grad() + model1(x).sum().backward() + opt1.step() + + opt2.zero_grad() + model2(x).sum().backward() + opt2.step() + + self.assertFalse( + torch.allclose(model1.weight, model2.weight), + "Different lr_adjust modes should produce different updates", + ) + + +@unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device") +class TestHybridMuonOptimizerStateDict(unittest.TestCase): + """Test optimizer state dict save/load.""" + + def setUp(self) -> None: + self.device = env.DEVICE + + def test_state_dict_save_load(self) -> None: + """Test saving and loading optimizer state.""" + torch.manual_seed(42) + model = torch.nn.Linear(10, 10, device=self.device) + optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02) + + for _ in range(3): + optimizer.zero_grad() + x = torch.randn(4, 10, device=self.device) + model(x).sum().backward() + optimizer.step() + + state_dict = optimizer.state_dict() + + optimizer2 = HybridMuonOptimizer(model.parameters(), lr=0.02) + optimizer2.load_state_dict(state_dict) + + # Verify state matches by param id, not iteration order + for p in model.parameters(): + s1 = optimizer.state.get(p, {}) + s2 = optimizer2.state.get(p, {}) + self.assertEqual(len(s1), len(s2)) + for key in s1: + if isinstance(s1[key], torch.Tensor): + self.assertTrue(torch.allclose(s1[key], s2[key])) + else: + self.assertEqual(s1[key], s2[key]) + + +if __name__ == "__main__": + unittest.main()
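For reference, a minimal sketch of how a user would select the new optimizer in the training section of the input script, assuming the standard DeePMD-kit JSON layout (shown here as a Python dict); the key names and defaults mirror the HybridMuon arguments registered in argcheck.py and consumed by get_opt_param() above, and the surrounding keys are left out:

    # Hypothetical training-section snippet (Python dict form of the JSON input).
    # All values shown are the argcheck.py defaults; "opt_type" is the only key
    # that must change to enable the new optimizer.
    training_section = {
        "opt_type": "HybridMuon",
        "momentum": 0.95,        # Nesterov momentum for Muon-routed (>=2D) parameters
        "adam_beta1": 0.9,       # Adam beta1 for 1D / fallback parameters
        "adam_beta2": 0.95,      # Adam beta2 for 1D / fallback parameters
        "weight_decay": 0.001,   # applied only to Muon-routed parameters
        "lr_adjust": 10.0,       # > 0: rectangular Muon scaling, Adam lr = lr / lr_adjust
        "lr_adjust_coeff": 0.2,  # match-RMS coefficient (lr_adjust <= 0); also caps the small-matrix Adam lr
        "muon_2d_only": True,    # route >2D tensors to Adam, without weight decay
        "min_2d_dim": 1,         # 2D matrices with min(m, n) < this fall back to Adam
        # ... remaining training keys unchanged
    }

The base learning rate itself still comes from the learning_rate section (lr_exp.start_lr in training.py); HybridMuon only rescales it per route as described in the module docstring.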