Skip to content

Commit c389ffc

Browse files
committed
fix(pt): compatible with AdaMuon
1 parent e9bbef8 commit c389ffc

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

deepmd/pt/train/training.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -727,6 +727,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
727 727
float(self.opt_param.get("adam_beta1", 0.9)),
728 728
float(self.opt_param.get("adam_beta2", 0.95)),
729 729
),
730+
lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)),
731+
lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)),
730 732
)
731 733
elif self.opt_type == "Muon":
732 734
self.optimizer = MuonOptimizer(

deepmd/utils/argcheck.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3447,6 +3447,67 @@ def training_args(
3447 3447
doc=doc_only_pt_supported
3448 3448
+ "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
3449 3449
),
3450+
],
3451+
[],
3452+
optional=True,
3453+
),
3454+
Argument(
3455+
"Muon",
3456+
dict,
3457+
[
3458+
Argument(
3459+
"momentum",
3460+
float,
3461+
optional=True,
3462+
default=0.95,
3463+
alias=["muon_momentum"],
3464+
doc=doc_only_pt_supported
3465+
+ "Momentum coefficient for Muon optimizer (>=2D params). "
3466+
"Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.",
3467+
),
3468+
Argument(
3469+
"adam_beta1",
3470+
float,
3471+
optional=True,
3472+
default=0.9,
3473+
doc=doc_only_pt_supported
3474+
+ "Adam beta1 coefficient for 1D parameters (biases, norms).",
3475+
),
3476+
Argument(
3477+
"adam_beta2",
3478+
float,
3479+
optional=True,
3480+
default=0.95,
3481+
doc=doc_only_pt_supported
3482+
+ "Adam beta2 coefficient for 1D parameters (biases, norms).",
3483+
),
3484+
Argument(
3485+
"weight_decay",
3486+
float,
3487+
optional=True,
3488+
default=0.001,
3489+
doc=doc_only_pt_supported
3490+
+ "Weight decay coefficient. Applied only to >=2D parameters (Muon path).",
3491+
),
3492+
Argument(
3493+
"lr_adjust",
3494+
float,
3495+
optional=True,
3496+
default=10.0,
3497+
doc=doc_only_pt_supported
3498+
+ "Learning rate adjustment mode for Muon scaling and Adam learning rate. "
3499+
"If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. "
3500+
"If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. "
3501+
"Default is 10.0 (Adam lr = lr/10).",
3502+
),
3503+
Argument(
3504+
"lr_adjust_coeff",
3505+
float,
3506+
optional=True,
3507+
default=0.2,
3508+
doc=doc_only_pt_supported
3509+
+ "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
3510+
),
3450 3511
Argument(
3451 3512
"min_2d_dim",
3452 3513
int,

0 commit comments

Comments
 (0)