@@ -3447,6 +3447,67 @@ def training_args(
34473447 doc = doc_only_pt_supported
34483448 + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0." ,
34493449 ),
3450+ ],
3451+ [],
3452+ optional = True ,
3453+ ),
3454+ Argument (
3455+ "Muon" ,
3456+ dict ,
3457+ [
3458+ Argument (
3459+ "momentum" ,
3460+ float ,
3461+ optional = True ,
3462+ default = 0.95 ,
3463+ alias = ["muon_momentum" ],
3464+ doc = doc_only_pt_supported
3465+ + "Momentum coefficient for Muon optimizer (>=2D params). "
3466+ "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t." ,
3467+ ),
3468+ Argument (
3469+ "adam_beta1" ,
3470+ float ,
3471+ optional = True ,
3472+ default = 0.9 ,
3473+ doc = doc_only_pt_supported
3474+ + "Adam beta1 coefficient for 1D parameters (biases, norms)." ,
3475+ ),
3476+ Argument (
3477+ "adam_beta2" ,
3478+ float ,
3479+ optional = True ,
3480+ default = 0.95 ,
3481+ doc = doc_only_pt_supported
3482+ + "Adam beta2 coefficient for 1D parameters (biases, norms)." ,
3483+ ),
3484+ Argument (
3485+ "weight_decay" ,
3486+ float ,
3487+ optional = True ,
3488+ default = 0.001 ,
3489+ doc = doc_only_pt_supported
3490+ + "Weight decay coefficient. Applied only to >=2D parameters (Muon path)." ,
3491+ ),
3492+ Argument (
3493+ "lr_adjust" ,
3494+ float ,
3495+ optional = True ,
3496+ default = 10.0 ,
3497+ doc = doc_only_pt_supported
3498+ + "Learning rate adjustment mode for Muon scaling and Adam learning rate. "
3499+ "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. "
3500+ "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. "
3501+ "Default is 10.0 (Adam lr = lr/10)." ,
3502+ ),
3503+ Argument (
3504+ "lr_adjust_coeff" ,
3505+ float ,
3506+ optional = True ,
3507+ default = 0.2 ,
3508+ doc = doc_only_pt_supported
3509+ + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0." ,
3510+ ),
34503511 Argument (
34513512 "min_2d_dim" ,
34523513 int ,
0 commit comments