
Commit 1978c7f: rename custom muon to hybridmuon
1 parent bacae68

5 files changed, +59 -37 lines


deepmd/pt/optimizer/__init__.py

Lines changed: 9 additions & 4 deletions
@@ -2,14 +2,19 @@
 from .adamuon import (
     AdaMuonOptimizer,
 )
+from .hybrid_muon import (
+    HybridMuonOptimizer,
+)
 from .KFWrapper import (
     KFOptimizerWrapper,
 )
 from .LKF import (
     LKFOptimizer,
 )
-from .muon import (
-    MuonOptimizer,
-)
 
-__all__ = ["AdaMuonOptimizer", "KFOptimizerWrapper", "LKFOptimizer", "MuonOptimizer"]
+__all__ = [
+    "AdaMuonOptimizer",
+    "HybridMuonOptimizer",
+    "KFOptimizerWrapper",
+    "LKFOptimizer",
+]
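
With this commit, HybridMuonOptimizer is exported at the package level. A minimal usage sketch against the new import path (the toy model and hyperparameters are illustrative, not taken from the commit; as the tests below suggest, the Muon path uses bf16 matmuls, so hardware support may be required):

import torch

from deepmd.pt.optimizer import HybridMuonOptimizer

# Toy model: the 10x10 weight is routed to the Muon path, the bias to Adam.
model = torch.nn.Linear(10, 10)
optimizer = HybridMuonOptimizer(model.parameters(), lr=1e-3)

optimizer.zero_grad()
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
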
deepmd/pt/optimizer/hybrid_muon.py

Lines changed: 21 additions & 10 deletions
@@ -1,10 +1,15 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 """
-Muon optimizer for DeePMD-kit PyTorch backend.
+HybridMuon optimizer for DeePMD-kit PyTorch backend.
 
-Muon is an optimizer that applies Newton-Schulz orthogonalization to the gradient
-before using momentum, resulting in orthogonalized updates for weight matrices.
-This can improve training stability and convergence for certain architectures.
+HybridMuon is a HYBRID optimizer that automatically combines Muon and Adam:
+- For >=2D parameters with min(m,n) >= min_2d_dim: Muon update with Newton-Schulz
+- For 2D parameters with min(m,n) < min_2d_dim: Adam fallback with update clipping
+- For 1D parameters (biases, layer norms): Standard Adam
+
+This is different from PyTorch's torch.optim.Muon, which ONLY supports 2D parameters
+and requires manual configuration of AdamW for 1D parameters. HybridMuon provides
+automatic routing based on parameter dimensionality.
 
 Algorithm
 ---------
@@ -33,9 +38,15 @@
 - Muon gradients: cast to parameter dtype before momentum update
 - Adam gradients: cast to float32 for update computation
 
-Reference
----------
-https://github.com/KellerJordan/Muon
+References
+----------
+.. [1] Keller Jordan, "Muon: An optimizer for hidden layers in neural networks."
+       https://kellerjordan.github.io/posts/muon/
+       https://github.com/KellerJordan/Muon
+.. [2] Moonshot team, "Muon is Scalable for LLM Training," arXiv:2502.16982, 2025.
+       https://arxiv.org/abs/2502.16982
+.. [3] Moonlight GitHub Repository.
+       https://github.com/MoonshotAI/Moonlight
 """
 
 from __future__ import (
@@ -223,9 +234,9 @@ def should_fallback_to_adam_for_matrix(
     return min(m, n) < min_2d_dim
 
 
-class MuonOptimizer(Optimizer):
+class HybridMuonOptimizer(Optimizer):
     """
-    Muon optimizer with small-2D Adam fallback and 1D Adam path.
+    HybridMuon optimizer with small-2D Adam fallback and 1D Adam path.
 
     This optimizer applies different update rules based on parameter dimensionality:
     - For >=2D parameters with min(m, n) >= min_2d_dim:
@@ -286,7 +297,7 @@ class MuonOptimizer(Optimizer):
 
     Examples
     --------
-    >>> optimizer = MuonOptimizer(model.parameters(), lr=1e-3)
+    >>> optimizer = HybridMuonOptimizer(model.parameters(), lr=1e-3)
     >>> for epoch in range(epochs):
     ...     optimizer.zero_grad()
     ...     loss.backward()
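
The dimensionality-based routing described in the new docstring can be read as a small decision function. The sketch below is an illustration of that rule only; route_param is a hypothetical helper, not part of the commit, and the actual implementation uses should_fallback_to_adam_for_matrix, whose body is visible as context in the third hunk. Shape handling for parameters with more than two dimensions is not shown in this diff, so the sketch restricts itself to 1D and 2D tensors:

import torch


def route_param(p: torch.Tensor, min_2d_dim: int = 1) -> str:
    """Illustrative routing rule from the HybridMuon docstring (1D/2D tensors)."""
    if p.ndim <= 1:
        return "adam"            # biases, layer norms: standard Adam path
    m, n = p.shape[0], p.shape[1]
    if min(m, n) < min_2d_dim:   # mirrors should_fallback_to_adam_for_matrix
        return "adam_fallback"   # small matrices: Adam with update clipping
    return "muon"                # full matrices: Newton-Schulz orthogonalized update


layer = torch.nn.Linear(10, 10)
print(route_param(layer.weight, min_2d_dim=2))  # muon
print(route_param(layer.bias, min_2d_dim=2))    # adam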

deepmd/pt/train/training.py

Lines changed: 4 additions & 4 deletions
@@ -43,9 +43,9 @@
 )
 from deepmd.pt.optimizer import (
     AdaMuonOptimizer,
+    HybridMuonOptimizer,
     KFOptimizerWrapper,
     LKFOptimizer,
-    MuonOptimizer,
 )
 from deepmd.pt.train.wrapper import (
     ModelWrapper,
@@ -730,8 +730,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)),
                 lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)),
             )
-        elif self.opt_type == "Muon":
-            self.optimizer = MuonOptimizer(
+        elif self.opt_type == "HybridMuon":
+            self.optimizer = HybridMuonOptimizer(
                 self.wrapper.parameters(),
                 lr=self.lr_exp.start_lr,
                 momentum=float(self.opt_param.get("momentum", 0.95)),
@@ -820,7 +820,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                     print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
                     fout1.write(print_str)
                     fout1.flush()
-            if self.opt_type in ["Adam", "AdamW", "AdaMuon", "Muon"]:
+            if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]:
                 cur_lr = self.scheduler.get_last_lr()[0]
                 if _step_id < self.warmup_steps:
                     pref_lr = _lr.start_lr
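
Condensed, the branch renamed in the second hunk builds the optimizer from the trainer's configuration. The helper below is a hypothetical stand-alone rewrite of that dispatch; opt_type, opt_param, and the momentum default of 0.95 are taken from the hunk, while the function itself and its error handling are illustrative:

from deepmd.pt.optimizer import HybridMuonOptimizer


def build_optimizer(opt_type, opt_param, parameters, start_lr):
    # The elif now matches "HybridMuon" instead of the old "Muon" string.
    if opt_type == "HybridMuon":
        return HybridMuonOptimizer(
            parameters,
            lr=start_lr,
            momentum=float(opt_param.get("momentum", 0.95)),
        )
    raise ValueError(f"Unsupported opt_type: {opt_type}")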

deepmd/utils/argcheck.py

Lines changed: 12 additions & 6 deletions
@@ -3452,7 +3452,7 @@ def training_args(
             optional=True,
         ),
         Argument(
-            "Muon",
+            "HybridMuon",
             dict,
             [
                 Argument(
@@ -3462,7 +3462,7 @@ def training_args(
                     default=0.95,
                     alias=["muon_momentum"],
                     doc=doc_only_pt_supported
-                    + "Momentum coefficient for Muon optimizer (>=2D params). "
+                    + "Momentum coefficient for HybridMuon optimizer (>=2D params). "
                     "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.",
                 ),
                 Argument(
@@ -3487,15 +3487,15 @@ def training_args(
                     optional=True,
                     default=0.001,
                     doc=doc_only_pt_supported
-                    + "Weight decay coefficient. Applied only to >=2D parameters (Muon path).",
+                    + "Weight decay coefficient. Applied only to >=2D parameters (HybridMuon path).",
                 ),
                 Argument(
                     "lr_adjust",
                     float,
                     optional=True,
                     default=10.0,
                     doc=doc_only_pt_supported
-                    + "Learning rate adjustment mode for Muon scaling and Adam learning rate. "
+                    + "Learning rate adjustment mode for HybridMuon scaling and Adam learning rate. "
                     "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. "
                     "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. "
                     "Default is 10.0 (Adam lr = lr/10).",
@@ -3515,14 +3515,20 @@ def training_args(
                     default=1,
                     alias=["muon_min_2d_dim"],
                     doc=doc_only_pt_supported
-                    + "Minimum min(m, n) threshold for Muon on 2D matrices. "
-                    "Matrices with min(m, n) >= min_2d_dim use Muon; "
+                    + "Minimum min(m, n) threshold for HybridMuon on 2D matrices. "
+                    "Matrices with min(m, n) >= min_2d_dim use HybridMuon; "
                     "those with min(m, n) < min_2d_dim use Adam fallback. "
                     "Set to 1 to disable fallback.",
                 ),
             ],
             [],
             optional=True,
+            doc=doc_only_pt_supported
+            + "HybridMuon optimizer (DeePMD-kit custom implementation). "
+            + "This is a Hybrid optimizer that automatically combines Muon and Adam. "
+            + "For >=2D params: Muon update with Newton-Schulz. "
+            + "For 1D params: Standard Adam. "
+            + "This is DIFFERENT from PyTorch's torch.optim.Muon which ONLY supports 2D parameters.",
         ),
     ],
     optional=True,
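
The lr_adjust documentation above defines two scaling modes. A worked sketch of just that arithmetic follows; the function name and the coeff argument (standing in for lr_adjust_coeff) are illustrative, while the formulas come directly from the argument docs:

import math


def muon_scale_and_adam_lr(lr: float, lr_adjust: float, coeff: float, m: int, n: int):
    """Worked form of the two lr_adjust modes described in the argument docs."""
    if lr_adjust <= 0:
        # Match-RMS scaling; Adam uses the base learning rate directly.
        return coeff * math.sqrt(max(m, n)), lr
    # Rectangular correction; Adam learning rate is divided by lr_adjust.
    return math.sqrt(max(1.0, m / n)), lr / lr_adjust


# Defaults (lr_adjust=10.0, coeff=0.2) on a 64x32 matrix:
# Muon scale = sqrt(max(1, 64/32)) ~= 1.414, Adam lr = lr / 10.
print(muon_scale_and_adam_lr(1e-3, 10.0, 0.2, 64, 32))
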
Lines changed: 13 additions & 13 deletions
@@ -3,8 +3,8 @@
 
 import torch
 
-from deepmd.pt.optimizer.muon import (
-    MuonOptimizer,
+from deepmd.pt.optimizer.hybrid_muon import (
+    HybridMuonOptimizer,
     zeropower_via_newtonschulz5,
 )
 from deepmd.pt.utils import (
@@ -82,8 +82,8 @@ def test_invalid_input(self) -> None:
 
 
 @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device")
-class TestMuonOptimizer(unittest.TestCase):
-    """Test MuonOptimizer class."""
+class TestHybridMuonOptimizer(unittest.TestCase):
+    """Test HybridMuonOptimizer class."""
 
     def setUp(self) -> None:
         self.device = env.DEVICE
@@ -96,7 +96,7 @@ def test_step(self) -> None:
             torch.nn.ReLU(),
             torch.nn.Linear(20, 5, device=self.device),
         )
-        optimizer = MuonOptimizer(model.parameters(), lr=0.02)
+        optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02)
 
         x = torch.randn(4, 10, device=self.device)
         model(x).sum().backward()
@@ -111,7 +111,7 @@ def test_weight_decay(self) -> None:
         """Test weight decay reduces parameter norm."""
         torch.manual_seed(42)
         model = torch.nn.Linear(10, 10, device=self.device)
-        optimizer = MuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1)
+        optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02, weight_decay=0.1)
 
         initial_norm = model.weight.norm().item()
         for _ in range(10):
@@ -126,7 +126,7 @@ def test_muon_adam_separation(self) -> None:
         """Test Muon for 2D params, Adam for 1D params."""
         torch.manual_seed(42)
         model = torch.nn.Linear(10, 10, device=self.device)
-        optimizer = MuonOptimizer(model.parameters(), lr=0.02)
+        optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02)
 
         x = torch.randn(4, 10, device=self.device)
         model(x).sum().backward()
@@ -145,7 +145,7 @@ def test_muon_adam_fallback_small_2d(self) -> None:
         torch.manual_seed(42)
         linear_small = torch.nn.Linear(10, 1, bias=False, device=self.device)
         linear_large = torch.nn.Linear(10, 10, bias=False, device=self.device)
-        optimizer = MuonOptimizer(
+        optimizer = HybridMuonOptimizer(
             list(linear_small.parameters()) + list(linear_large.parameters()),
             lr=0.02,
             min_2d_dim=2,
@@ -172,8 +172,8 @@ def test_lr_adjust_modes(self) -> None:
         model2 = torch.nn.Linear(10, 20, bias=False, device=self.device)
         model2.load_state_dict(model1.state_dict())
 
-        opt1 = MuonOptimizer(model1.parameters(), lr=0.02, lr_adjust=0.0)
-        opt2 = MuonOptimizer(model2.parameters(), lr=0.02, lr_adjust=10.0)
+        opt1 = HybridMuonOptimizer(model1.parameters(), lr=0.02, lr_adjust=0.0)
+        opt2 = HybridMuonOptimizer(model2.parameters(), lr=0.02, lr_adjust=10.0)
 
         x = torch.randn(4, 10, device=self.device)
 
@@ -192,7 +192,7 @@ def test_lr_adjust_modes(self) -> None:
 
 
 @unittest.skipIf(not BF16_SUPPORTED, "bf16 matmul not supported on this device")
-class TestMuonOptimizerStateDict(unittest.TestCase):
+class TestHybridMuonOptimizerStateDict(unittest.TestCase):
     """Test optimizer state dict save/load."""
 
     def setUp(self) -> None:
@@ -202,7 +202,7 @@ def test_state_dict_save_load(self) -> None:
         """Test saving and loading optimizer state."""
         torch.manual_seed(42)
         model = torch.nn.Linear(10, 10, device=self.device)
-        optimizer = MuonOptimizer(model.parameters(), lr=0.02)
+        optimizer = HybridMuonOptimizer(model.parameters(), lr=0.02)
 
         for _ in range(3):
             optimizer.zero_grad()
@@ -212,7 +212,7 @@ def test_state_dict_save_load(self) -> None:
 
         state_dict = optimizer.state_dict()
 
-        optimizer2 = MuonOptimizer(model.parameters(), lr=0.02)
+        optimizer2 = HybridMuonOptimizer(model.parameters(), lr=0.02)
         optimizer2.load_state_dict(state_dict)
 
         # Verify state matches by param id, not iteration order
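
Both the module docstring and the test import refer to zeropower_via_newtonschulz5, the Newton-Schulz orthogonalization at the core of the Muon path. For reference, here is a sketch following the quintic iteration published in the KellerJordan/Muon repository cited above; the coefficients and bf16 cast match that reference, but the version shipped in this commit may differ in details such as input validation:

import torch


def newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize G via a quintic Newton-Schulz iteration."""
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()                       # bf16 matmul, hence the BF16_SUPPORTED skip above
    if G.size(-2) > G.size(-1):
        X = X.mT                           # iterate on the wide orientation
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # keep spectral norm <= 1
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X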
