
Commit 74cf72d

refactor(pt): align HybridMuon dtype behavior with official PyTorch Muon
Changes:

1. Remove dtype conversion: NS output (bfloat16) is now directly applied to parameters, matching torch.optim.Muon behavior where PyTorch handles mixed precision automatically.
2. Add muon_2d_only parameter (default True): when True, only 2D parameters use Muon; >2D parameters use Adam without weight decay. This matches PyTorch's official torch.optim.Muon, which only supports 2D matrices.
3. Merge NS_EPS and ADAM_EPS into a single EPS constant (both 1e-7).
4. Update dtype documentation to reflect actual behavior:
   - NS output (bfloat16) directly applied to parameters
   - Muon momentum buffer follows gradient dtype (not param dtype)
5. Update weight_decay docstring from ">=2D params" to "Muon-routed parameters" for accuracy with muon_2d_only=True.
1 parent 1978c7f commit 74cf72d
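For reference, the dispatch rule introduced by muon_2d_only can be summarized as below. This is a minimal standalone sketch of the routing described in this commit, not the optimizer's actual code: the helper name route_param and its return labels are hypothetical, and the small-matrix test is assumed to reduce to min(m, n) < min_2d_dim as the docstring states.

import torch

def route_param(p: torch.Tensor, muon_2d_only: bool = True, min_2d_dim: int = 1) -> str:
    # Hypothetical helper mirroring the routing rules in _build_param_routing.
    if p.ndim < 2:
        return "adam_1d"      # biases, scalars -> Adam
    if muon_2d_only and p.ndim > 2:
        return "adam_nd"      # e.g. conv kernels -> Adam, no weight decay
    if p.ndim == 2 and min(p.shape) < min_2d_dim:
        return "adam_matrix"  # small 2D matrices -> scaled Adam fallback
    return "muon"             # 2D (or >=2D when muon_2d_only=False) -> Muon

# e.g. a (64, 3, 3, 3) conv weight is routed to "adam_nd" when muon_2d_only=True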


3 files changed (+110, -30 lines)


deepmd/pt/optimizer/hybrid_muon.py

Lines changed: 97 additions & 29 deletions
@@ -34,8 +34,9 @@
 Dtype Behavior
 --------------
 - Newton-Schulz iterations: always bfloat16 (matches official Muon)
+- NS output (bfloat16) directly applied to parameters (PyTorch handles mixed precision)
 - Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability
-- Muon gradients: cast to parameter dtype before momentum update
+- Muon momentum buffer: follows gradient dtype (grad -> buffer -> update)
 - Adam gradients: cast to float32 for update computation

 References
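The first new point is what allows the explicit dtype casts removed later in this diff to be dropped: PyTorch's in-place add_ promotes a lower-precision delta to the parameter's dtype on its own. A minimal standalone check, not taken from the repository:

import torch

p = torch.randn(4, 4, dtype=torch.float32)        # parameter stays float32
delta = torch.randn(4, 4, dtype=torch.bfloat16)   # NS output is bfloat16
p.add_(delta, alpha=-0.01)                         # in-place add promotes bf16 -> fp32
print(p.dtype)                                     # torch.float32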
@@ -75,10 +76,8 @@

 # Newton-Schulz iteration count
 NS_STEPS: int = 5
-# Numerical stability epsilon for norm clamping
-NS_EPS: float = 1e-7
-# Adam epsilon for numerical stability
-ADAM_EPS: float = 1e-7
+# Numerical stability epsilon for norm clamping and Adam
+EPS: float = 1e-7
 # Quintic Newton-Schulz polynomial coefficients
 NS_COEFF_A: float = 3.4445
 NS_COEFF_B: float = -4.7750
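For readers unfamiliar with these constants: they are coefficients of the quintic polynomial applied at each Newton-Schulz step. A rough sketch of one iteration in the style of the reference Muon implementation follows; the third coefficient sits outside this hunk, and the value used here (about 2.0315 in upstream Muon) is an assumption, as is the exact loop structure in this file.

import torch

def ns_step(X: torch.Tensor, a: float = 3.4445, b: float = -4.7750, c: float = 2.0315) -> torch.Tensor:
    # One quintic Newton-Schulz step: X <- a*X + (b*A + c*A@A) @ X, with A = X @ X^T
    A = X @ X.T
    B = b * A + c * (A @ A)
    return a * X + B @ X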
@@ -118,7 +117,7 @@ def _zeropower_via_newtonschulz5_2d(
 X = X.transpose(-2, -1)

 # === Step 2. Normalize Frobenius norm to at most 1 ===
-X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)
+X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)

 # === Step 3. Newton-Schulz iterations with fused GEMM ===
 for _ in range(NS_STEPS):
@@ -152,7 +151,7 @@ def _zeropower_via_newtonschulz5_3d(
 X = X.transpose(-2, -1)

 # === Step 2. Normalize Frobenius norm to at most 1 ===
-X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)
+X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)

 # === Step 3. Newton-Schulz iterations with batched fused GEMM ===
 for _ in range(NS_STEPS):
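In both functions the rename is purely cosmetic; the clamp itself is what matters, since it bounds the Frobenius norm at 1 before the iterations and avoids a division by zero for an all-zero update. A tiny standalone illustration (not repository code):

import torch

EPS = 1e-7
X = torch.zeros(3, 3)  # an all-zero momentum/update
X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=EPS)
print(torch.isnan(X).any().item())  # False: without the clamp this would be 0/0 -> NaN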
@@ -270,7 +269,7 @@ class HybridMuonOptimizer(Optimizer):
 momentum : float
     Momentum coefficient for Muon with default 0.95.
 weight_decay : float
-    Weight decay coefficient (applied only to >=2D params) with default 0.001.
+    Weight decay coefficient (applied only to Muon-routed parameters) with default 0.001.
 adam_betas : tuple[float, float]
     Adam beta coefficients with default (0.9, 0.95).
 lr_adjust : float
@@ -287,6 +286,11 @@ class HybridMuonOptimizer(Optimizer):
     2. For 2D Adam fallback: learning rate multiplier,
        adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1).
        The min(., 0.1) cap ensures conservative updates for small matrices.
+muon_2d_only : bool
+    If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon).
+    Parameters with ndim > 2 use Adam without weight decay.
+    If False, all >=2D parameters use Muon (default behavior).
+    Default is True.
 min_2d_dim : int
     Minimum min(m, n) threshold for Muon on 2D matrices.
     Matrices with min(m, n) >= min_2d_dim use Muon;
@@ -313,6 +317,7 @@ def __init__(
     adam_betas: tuple[float, float] = (0.9, 0.95),
     lr_adjust: float = 10.0,
     lr_adjust_coeff: float = 0.2,
+    muon_2d_only: bool = True,
     min_2d_dim: int = 1,
 ) -> None:
     if min_2d_dim < 1:
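Constructed directly, the new keyword slots in alongside the existing ones. The sketch below is hypothetical usage only: the import path and the lr keyword are assumptions (neither is shown in this hunk), and the toy model exists just to provide 1D, 2D, and 4D parameters.

import torch
# assumed import path; the class lives in deepmd/pt/optimizer/hybrid_muon.py
from deepmd.pt.optimizer.hybrid_muon import HybridMuonOptimizer

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Flatten(), torch.nn.Linear(8, 4))

opt = HybridMuonOptimizer(
    model.parameters(),
    lr=2e-3,              # assuming the usual lr keyword
    muon_2d_only=True,    # 4D conv kernel -> Adam (no weight decay); 2D linear weight -> Muon
    min_2d_dim=1,
)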
@@ -325,6 +330,7 @@ def __init__(
     "adam_betas": adam_betas,
     "lr_adjust": lr_adjust,
     "lr_adjust_coeff": lr_adjust_coeff,
+    "muon_2d_only": muon_2d_only,
     "min_2d_dim": min_2d_dim,
 }
 super().__init__(params, defaults)
@@ -337,9 +343,11 @@ def _build_param_routing(self) -> None:
 Classify parameters into Muon and Adam routes (static routing).

 Routing logic:
-- >=2D parameters with min(m, n) >= min_2d_dim → Muon path
-- 2D parameters with min(m, n) < min_2d_dim → Adam fallback path
 - 1D parameters → Adam path
+- >2D parameters (when muon_2d_only=True) → Adam path
+- 2D parameters with min(m, n) < min_2d_dim → Adam fallback path
+- 2D parameters with min(m, n) >= min_2d_dim → Muon path
+- >=2D parameters (when muon_2d_only=False) → Muon path
 """
 if self._routing_built:
     return
@@ -349,14 +357,23 @@ def _build_param_routing(self) -> None:
 muon_params: list[dict[str, Any]] = []
 adam_1d: list[dict[str, Any]] = []
 adam_matrix: list[dict[str, Any]] = []
+adam_nd: list[dict[str, Any]] = []

 min_2d_dim = group["min_2d_dim"]
+muon_2d_only = group["muon_2d_only"]

 for p in group["params"]:
+    # === Step 1. 1D parameters → Adam ===
     if p.ndim < 2:
         adam_1d.append({"param": p})
         continue

+    # === Step 2. >2D parameters (when muon_2d_only=True) → Adam ===
+    if muon_2d_only and p.ndim > 2:
+        adam_nd.append({"param": p})
+        continue
+
+    # === Step 3. 2D small matrices → Adam fallback ===
     if (p.ndim == 2) and should_fallback_to_adam_for_matrix(
         p, min_2d_dim=min_2d_dim
     ):
@@ -368,6 +385,7 @@ def _build_param_routing(self) -> None:
 )
 continue

+# === Step 4. >=2D (or 2D only when muon_2d_only=True) → Muon ===
 muon_params.append(
     {
         "param": p,
@@ -381,6 +399,7 @@ def _build_param_routing(self) -> None:
         "muon_params": muon_params,
         "adam_1d": adam_1d,
         "adam_matrix": adam_matrix,
+        "adam_nd": adam_nd,
     }
 )

@@ -470,12 +489,67 @@ def step(
 bias_corr2 = 1 - state["beta2_pow"]
 step_size = adam_lr / bias_corr1
 # delta = -step_size * m_hat / (sqrt(v_hat) + eps)
-denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
+denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
 delta_fp32 = -step_size * (adam_exp_avgs[i] / denom)
 p.add_(delta_fp32.to(p.dtype))

-# === Step 2. Adam update for small 2D matrices (fallback path) ===
+# === Step 2. Adam update for >2D parameters (when muon_2d_only=True) ===
 # === Step 2.1. Collect gradients and initialize state ===
+adam_nd_params: list[torch.Tensor] = []
+adam_nd_grads_fp32: list[torch.Tensor] = []
+adam_nd_exp_avgs: list[torch.Tensor] = []
+adam_nd_exp_avg_sqs: list[torch.Tensor] = []
+adam_nd_states: list[dict[str, Any]] = []
+
+for entry in route.get("adam_nd", []):
+    p = entry["param"]
+    grad = p.grad
+    if grad is None:
+        continue
+
+    grad_fp32 = grad.float()
+
+    state = self.state[p]
+    if "exp_avg" not in state:
+        state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
+        state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
+        state["beta1_pow"] = 1.0
+        state["beta2_pow"] = 1.0
+
+    state["beta1_pow"] *= adam_betas[0]
+    state["beta2_pow"] *= adam_betas[1]
+
+    adam_nd_params.append(p)
+    adam_nd_grads_fp32.append(grad_fp32)
+    adam_nd_exp_avgs.append(state["exp_avg"])
+    adam_nd_exp_avg_sqs.append(state["exp_avg_sq"])
+    adam_nd_states.append(state)
+
+if adam_nd_params:
+    # === Step 2.2. Update exp_avg / exp_avg_sq ===
+    adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust
+
+    # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
+    # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
+    torch._foreach_lerp_(
+        adam_nd_exp_avgs, adam_nd_grads_fp32, 1 - adam_betas[0]
+    )
+    grad_sq = torch._foreach_mul(adam_nd_grads_fp32, adam_nd_grads_fp32)
+    torch._foreach_lerp_(adam_nd_exp_avg_sqs, grad_sq, 1 - adam_betas[1])
+
+    # === Step 2.3. Bias correction and parameter update ===
+    for i, p in enumerate(adam_nd_params):
+        state = adam_nd_states[i]
+        bias_corr1 = 1 - state["beta1_pow"]
+        bias_corr2 = 1 - state["beta2_pow"]
+        step_size = adam_lr / bias_corr1
+        # delta = -step_size * m_hat / (sqrt(v_hat) + eps)
+        denom = (adam_nd_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
+        delta_fp32 = -step_size * (adam_nd_exp_avgs[i] / denom)
+        p.add_(delta_fp32.to(p.dtype))
+
+# === Step 3. Adam update for small 2D matrices (fallback path) ===
+# === Step 3.1. Collect gradients and initialize state ===
 adam_matrix_params: list[torch.Tensor] = []
 adam_matrix_grads_fp32: list[torch.Tensor] = []
 adam_matrix_exp_avgs: list[torch.Tensor] = []
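The two _foreach_lerp_ calls in the new Step 2.2 implement exactly the exponential moving averages written in the comments, since lerp(a, b, w) = a + w*(b - a) = (1 - w)*a + w*b. A small standalone check (not repository code):

import torch

beta1 = 0.9
exp_avg = torch.randn(5)
grad = torch.randn(5)

reference = beta1 * exp_avg + (1 - beta1) * grad  # textbook EMA update
lerped = exp_avg.lerp(grad, 1 - beta1)            # what _foreach_lerp_ does per tensor
print(torch.allclose(reference, lerped))          # True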
@@ -509,7 +583,7 @@ def step(
 adam_matrix_abs_floor.append(entry["abs_floor"])

 if adam_matrix_params:
-    # === Step 2.2. Update exp_avg / exp_avg_sq with scaled lr ===
+    # === Step 3.2. Update exp_avg / exp_avg_sq with scaled lr ===
     adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust
     adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1)

@@ -525,19 +599,17 @@ def step(
     adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1]
 )

-# === Step 2.3. Compute unclipped deltas ===
+# === Step 3.3. Compute unclipped deltas ===
 raw_deltas: list[torch.Tensor] = []
 for i in range(len(adam_matrix_params)):
     state = adam_matrix_states[i]
     bias_corr1 = 1 - state["beta1_pow"]
     bias_corr2 = 1 - state["beta2_pow"]
     step_size = adam_lr_matrix / bias_corr1
-    denom = (
-        (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
-    )
+    denom = (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(EPS)
     raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom))

-# === Step 2.4. Clip updates by relative norm and apply ===
+# === Step 3.4. Clip updates by relative norm and apply ===
 max_rel_change = 0.05
 p_norms = torch.stack(torch._foreach_norm(adam_matrix_params))
 delta_norms = torch.stack(torch._foreach_norm(raw_deltas))
@@ -553,8 +625,8 @@ def step(
 ):
     p.add_(delta.mul_(scales_tensor[i]).to(p.dtype))

-# === Step 3. Muon update for >=2D parameters (weight matrices) ===
-# === Step 3.1. Collect gradients and initialize momentum ===
+# === Step 4. Muon update for >=2D parameters (weight matrices) ===
+# === Step 4.1. Collect gradients and initialize momentum ===
 muon_params_for_decay: list[torch.Tensor] = []
 muon_grads: list[torch.Tensor] = []
 muon_momentum_buffers: list[torch.Tensor] = []
@@ -579,22 +651,22 @@ def step(
     muon_momentum_buffers.append(buf)
     active_entries.append((entry, grad))

-# === Step 3.2. Apply weight decay (Muon path only) ===
+# === Step 4.2. Apply weight decay (Muon path only) ===
 if weight_decay > 0 and muon_params_for_decay:
     torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay)

 if not active_entries:
     continue

-# === Step 3.3. Momentum update (Nesterov) ===
+# === Step 4.3. Momentum update (Nesterov) ===
 # m_t = beta * m_{t-1} + (1 - beta) * g_t
 torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum)
 # update = beta * m_t + (1 - beta) * g_t
 muon_updates = torch._foreach_lerp(
     muon_grads, muon_momentum_buffers, momentum
 )

-# === Step 3.4. Bucket by shape/device/dtype for batched NS ===
+# === Step 4.4. Bucket by shape/device/dtype for batched NS ===
 buckets: dict[
     tuple[int, int, torch.device, torch.dtype],
     list[tuple[dict[str, Any], torch.Tensor]],
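The Nesterov-style momentum in Step 4.3 can be unpacked the same way: after m_t = lerp(m_{t-1}, g_t, 1-beta), the applied update lerp(g_t, m_t, beta) equals (1-beta)*g_t + beta*m_t, i.e. the fresh gradient is blended back on top of the updated momentum. A standalone check (beta plays the role of momentum; not repository code):

import torch

beta = 0.95
m_prev = torch.randn(4)
g = torch.randn(4)

m = m_prev.lerp(g, 1 - beta)            # m_t = beta * m_{t-1} + (1 - beta) * g_t
update = g.lerp(m, beta)                # beta * m_t + (1 - beta) * g_t (Nesterov look-ahead)
expected = beta * m + (1 - beta) * g
print(torch.allclose(update, expected)) # True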
@@ -608,8 +680,8 @@ def step(
     buckets[bucket_key] = []
 buckets[bucket_key].append((entry, muon_updates[idx]))

-# === Step 3.5. Newton-Schulz orthogonalization and update ===
-for (rows, cols, _device, dtype), bucket_entries in buckets.items():
+# === Step 4.5. Newton-Schulz orthogonalization and update ===
+for (rows, cols, _device, _), bucket_entries in buckets.items():
     # scale = coeff * sqrt(max(m, n)) [match-RMS mode]
     # scale = sqrt(max(1, m/n)) [rectangular mode]
     if lr_adjust <= 0:
@@ -626,8 +698,6 @@ def step(
 orth = _zeropower_via_newtonschulz5_2d(update_matrix)
 orth.mul_(scale)
 delta = orth.reshape(entry["param"].shape)
-if delta.dtype != dtype:
-    delta = delta.to(dtype)
 entry["param"].add_(delta, alpha=-lr)
 continue

@@ -648,8 +718,6 @@ def step(
 stacked = torch.stack(matrices, dim=0)
 orth = _zeropower_via_newtonschulz5_3d(stacked)
 orth.mul_(scale)
-if orth.dtype != dtype:
-    orth = orth.to(dtype)

 for i, _ in enumerate(bucket_entries):
     params[i].add_(orth[i].reshape(orig_shapes[i]), alpha=-lr)

deepmd/pt/train/training.py

Lines changed: 2 additions & 0 deletions
@@ -173,6 +173,7 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
     "adam_beta2": params.get("adam_beta2", 0.95),
     "lr_adjust": params.get("lr_adjust", 10.0),
     "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
+    "muon_2d_only": params.get("muon_2d_only", True),
     "min_2d_dim": params.get("min_2d_dim", 1),
 }
 return opt_type, opt_param
@@ -742,6 +743,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
     ),
     lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)),
     lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)),
+    muon_2d_only=bool(self.opt_param.get("muon_2d_only", True)),
     min_2d_dim=int(self.opt_param.get("min_2d_dim", 1)),
 )
 if optimizer_state_dict is not None and self.restart_training:

deepmd/utils/argcheck.py

Lines changed: 11 additions & 1 deletion
@@ -3487,7 +3487,7 @@ def training_args(
     optional=True,
     default=0.001,
     doc=doc_only_pt_supported
-    + "Weight decay coefficient. Applied only to >=2D parameters (HybridMuon path).",
+    + "Weight decay coefficient. Applied only to Muon-routed parameters",
 ),
 Argument(
     "lr_adjust",
@@ -3508,6 +3508,16 @@ def training_args(
     doc=doc_only_pt_supported
     + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
 ),
+Argument(
+    "muon_2d_only",
+    bool,
+    optional=True,
+    default=True,
+    doc=doc_only_pt_supported
+    + "If True, only 2D parameters use Muon (matching PyTorch's torch.optim.Muon). "
+    + "Parameters with ndim > 2 use Adam without weight decay. "
+    + "If False, all >=2D parameters use Muon.",
+),
 Argument(
     "min_2d_dim",
     int,
