
Commit 546b763

xxyux and claude committed
feat: support muon optimizer in PaddleFormers trainer
- Update trainer.py to integrate Muon optimizer support
- Update trainer_utils.py with Muon-related utilities

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 765d4ed commit 546b763

File tree

2 files changed: +33 −1 lines changed


paddleformers/trainer/trainer.py

Lines changed: 23 additions & 1 deletion
@@ -1991,7 +1991,8 @@ def _inner_training_loop(
                     steps_trained_progress_bar.update(1)
                 if steps_trained_in_current_epoch == 0:
                     self._load_rng_state(resume_from_checkpoint)
-                self.timers and self.timers("read-data").start()
+                if self.args.ignore_data_skip:
+                    self.timers and self.timers("read-data").start()
                 # Reset data loading timer for skipped steps
                 _data_load_start_time = time.time()
                 continue
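For orientation, this hunk sits inside the resume fast-forward loop, where batches that were already trained before the checkpoint are consumed and discarded. A minimal sketch of that loop shape (simplified names, not the actual PaddleFormers code):

# Sketch of the resume skip loop the hunk above modifies (names simplified;
# load_rng_state and train_step are placeholders for the real trainer calls).
for step, batch in enumerate(train_dataloader):
    if steps_trained_in_current_epoch > 0:
        steps_trained_in_current_epoch -= 1  # batch predates the checkpoint
        if steps_trained_in_current_epoch == 0:
            load_rng_state(resume_from_checkpoint)  # restore RNG at the boundary
        continue  # skip without running a training step
    train_step(batch)

The change makes the read-data timer restart during skipped steps conditional on ignore_data_skip instead of unconditional.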
@@ -2912,6 +2913,15 @@ def apply_decay_param_fun(x):
         if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
             optimizer_kwargs["multi_precision"] = True

+        if self.args.optim.value == "muon":
+            # Attach per-head metadata to fused QKV weights so the Muon
+            # optimizer can orthogonalise each head independently.
+            for name, param in self.model.named_parameters():
+                if "qkv_proj.weight" in name and len(param.shape) == 2:
+                    param.needs_qkv_split = True
+                    param.head_num = self.model.config.num_attention_heads
+                    param.kv_head_num = self.model.config.num_key_value_heads
+
         self.optimizer = optimizer_cls(
             learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
             apply_decay_param_fun=apply_decay_param_fun,
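Paddle's Muon internals are not part of this diff, so the following is only a hypothetical sketch of how a consumer of this metadata could split a fused QKV weight, assuming the fused output dimension is laid out as [Q | K | V]; split_fused_qkv is an invented helper name:

import paddle

def split_fused_qkv(weight, head_num, kv_head_num):
    # Hypothetical helper (not Paddle's actual Muon code): slice the fused
    # output dimension, assumed to be ordered [Q | K | V], into the three
    # projections; a real kernel would further reshape each block per head,
    # as the diff comment indicates.
    out_dim = weight.shape[-1]
    head_dim = out_dim // (head_num + 2 * kv_head_num)
    q_end = head_num * head_dim
    k_end = q_end + kv_head_num * head_dim
    return weight[:, :q_end], weight[:, q_end:k_end], weight[:, k_end:]

Muon's Newton–Schulz orthogonalisation treats each parameter as one matrix, so splitting first keeps the per-projection structure intact rather than orthogonalising across the concatenated block.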
@@ -3052,6 +3062,18 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:

         optimizer_cls = AdamWCustom
         optimizer_kwargs.update(adam_kwargs)
+    elif args.optim == OptimizerNames.MUON:
+        from paddle.optimizer import Muon
+
+        logger.info("Creating Muon optimizer")
+        muon_kwargs = {
+            **adam_kwargs,
+            "momentum": 0.95,
+            "muon_version": 3,
+            "is_split_qkv": True,
+        }
+        optimizer_cls = Muon
+        optimizer_kwargs.update(muon_kwargs)
     else:
         raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
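Downstream, opting into the optimizer rides the existing --optim plumbing. A hedged usage sketch (import path and surrounding setup assumed; model and train_ds are placeholders):

from paddleformers.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./checkpoints",
    optim="muon",  # resolves to OptimizerNames.MUON, added below in trainer_utils.py
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds)
trainer.train()

Note that momentum=0.95, muon_version=3, and is_split_qkv=True are hard-coded above; the patch does not expose them as training arguments, so changing them means editing get_optimizer_cls_and_kwargs.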

paddleformers/trainer/trainer_utils.py

Lines changed: 10 additions & 0 deletions
@@ -498,6 +498,7 @@ class OptimizerNames(ExplicitEnum):
     ADAFACTOR = "adafactor"
     ADAMW_MINI = "adamw_mini"
     ADAMW_CUSTOM = "adamw_custom"
+    MUON = "muon"


 class ShardingOption(ExplicitEnum):
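Because OptimizerNames is an ExplicitEnum, the string "muon" and the member are interchangeable, which is what connects a CLI value like --optim muon to the args.optim == OptimizerNames.MUON branch in trainer.py. A quick illustration, assuming the usual ExplicitEnum semantics:

from paddleformers.trainer.trainer_utils import OptimizerNames

assert OptimizerNames("muon") is OptimizerNames.MUON
assert OptimizerNames.MUON.value == "muon"
# A typo such as OptimizerNames("moun") raises a ValueError listing the
# valid choices, so unsupported names fail fast during argument parsing.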
@@ -1502,6 +1503,12 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
         return

     elif DygraphShardingOptimizerV2 is not None and isinstance(inner_opt, DygraphShardingOptimizerV2):
+        # Unwrap to the innermost optimizer (e.g. Muon inside a sharding wrapper).
+        core_opt = optimizer._inner_opt
+        while hasattr(core_opt, "_inner_opt"):
+            core_opt = core_opt._inner_opt
+        is_muon_opt = type(core_opt).__name__ == "Muon"
+
         parameter_list = []
         for buffer in optimizer._comm_buffer_list:
             for param_name, grad_view in buffer._sharding_param_grad_view.items():
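The unwrap loop is the usual pattern for nested optimizer wrappers; factored out as a standalone helper it would read (a sketch, not code from the patch):

def unwrap_optimizer(opt):
    # Walk down any chain of wrappers (sharding, hybrid-parallel, etc.)
    # that expose the wrapped optimizer as `_inner_opt`.
    while hasattr(opt, "_inner_opt"):
        opt = opt._inner_opt
    return opt

is_muon_opt = type(unwrap_optimizer(optimizer._inner_opt)).__name__ == "Muon"

Comparing the class name as a string, rather than isinstance against paddle.optimizer.Muon, keeps the utility importable on Paddle builds that do not ship Muon.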
@@ -1515,6 +1522,9 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
                 slice_param = paddle.slice(param_buffer, axes=[0], starts=[param_begin], ends=[param_end])
                 assert slice_param.numel().item() > 0
                 slice_param.name = param_name
+                # Preserve original shape so Muon's should_use_muon() can identify 2-D weights.
+                if is_muon_opt and hasattr(grad_view, "_param") and grad_view._param is not None:
+                    slice_param.original_shape = grad_view._param.shape
                 parameter_list.append(slice_param)

         optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), parameter_list)
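should_use_muon is named only in the comment; its body is not part of this diff. A hypothetical reconstruction of the kind of check it implies:

def should_use_muon(param):
    # Hypothetical sketch: Muon's orthogonalised update only applies to
    # matrices; other parameters fall back to an AdamW-style update.
    # Sharded slices are flattened 1-D buffers, so prefer the
    # original_shape attribute preserved above when it exists.
    shape = getattr(param, "original_shape", param.shape)
    return len(shape) == 2

Without the original_shape attribute, every sharded slice would look 1-D, and Muon would silently skip the very weight matrices it is meant to handle.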
