Commit 7db2abc

feat: add Muon optimizer integration and ShardingV3 support
Muon optimizer integration:
- Create Muon optimizer in trainer when `optim=muon`, with per-head QKV metadata annotation for fused QKV weight orthogonalisation
- Handle Muon's `_moment_acc_str` (vs AdamW's `_moment1_acc_str`) in optimizer state save/restore
- Add Muon `_muon_update`/`_apply_optimize` offload support in `offload_optimizer.py`

ShardingV3 support:
- Add `sharding_v3` training argument and `FLAGS_sharding_v3` environment variable dispatch
- Implement the `DygraphShardingOptimizerV3` init path in `trainer_utils.py`
- Add V3 reshard logic (`reshard/sharding_v3.py`) for checkpoint save/restore
- Adapt `sharding_io.py`, `zero_cost_checkpoint.py`, and `moe_hybrid_parallel_optimizer.py` for V3 optimizer unwrapping

Tests:
- Add Muon smoke tests (`tests/muon/`) exercising both the V2 and V3 sharding paths on 2 GPUs with AMP O2
1 parent ed15c99 commit 7db2abc
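
For orientation, a minimal sketch of how the two features would be enabled together. The `optim="muon"`, `sharding_v3`, and `split_param` names come from this commit; the remaining arguments, the model, and the dataset are placeholders, so treat this as an assumption-laden illustration rather than a supported recipe.

# Hypothetical usage sketch; only optim="muon", sharding_v3=True and the
# split_param requirement are taken from this commit - everything else
# (paths, model, dataset, neighbouring flags) is illustrative.
from paddleformers.trainer import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./ckpt",
    optim="muon",          # resolves to OptimizerNames.MUON (see trainer_utils.py below)
    sharding="stage1",
    split_param=True,      # sharding_v3 asserts split_param=True (see training_args.py below)
    sharding_v3=True,      # exports FLAGS_sharding_v3=1 during argument post-processing
    bf16=True,
    fp16_opt_level="O2",   # matches the AMP O2 smoke tests mentioned above
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds)  # model/train_ds assumed in scope
trainer.train()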

File tree: 13 files changed, +588 −18 lines changed


paddleformers/trainer/trainer.py

Lines changed: 43 additions & 11 deletions
@@ -87,6 +87,13 @@
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import (
     DygraphShardingOptimizerV2,
 )
+
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer_v3 import (
+        DygraphShardingOptimizerV3,
+    )
+except ImportError:
+    DygraphShardingOptimizerV3 = None
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (
     fused_allreduce_gradients,
 )
@@ -211,7 +218,7 @@
 from .utils import reshard as reshard_util
 from .utils.async_save import AsyncSaver
 from .utils.ckpt_converter import CheckpointConverter
-from .utils.reshard import SHARDING_STRATEGY_V1, split_opt_state
+from .utils.reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V3, split_opt_state
 from .utils.sharding_io import GroupGetter, to_device

 try:
@@ -1215,7 +1222,10 @@ def get_metadata_file_name(path):
         enable_bf16_opt = (
             not isinstance(self.model, LoRAModel)
             and self.args.bf16
-            and isinstance(self.optimizer._inner_opt, DygraphShardingOptimizerV2)
+            and isinstance(
+                self.optimizer._inner_opt,
+                (DygraphShardingOptimizerV2,) + ((DygraphShardingOptimizerV3,) if DygraphShardingOptimizerV3 else ()),
+            )
         )
         logger.debug(f"sharded_model_from_ema: {self.args.sharded_model_from_ema}")
         logger.debug(f"enable_bf16_opt: {enable_bf16_opt}")
@@ -1277,11 +1287,12 @@ def recover_params_from_master_weight(opt_state_dict, group):
            del node_model_state_tmp
            sharding_strategy = reshard_util.get_sharding_strategy(self.optimizer)
            logger.debug(f"sharding_strategy: {sharding_strategy}")
-           restore_func = (
-               reshard_util.sharding_v1.restore
-               if sharding_strategy == SHARDING_STRATEGY_V1
-               else reshard_util.sharding_v2.restore
-           )
+           if sharding_strategy == SHARDING_STRATEGY_V1:
+               restore_func = reshard_util.sharding_v1.restore
+           elif sharding_strategy == SHARDING_STRATEGY_V3:
+               restore_func = reshard_util.sharding_v3.restore
+           else:
+               restore_func = reshard_util.sharding_v2.restore
            node_model_state = restore_func(node_model_state, self.model, self.optimizer)
            node_model_state.unpack_keys()
            master_weights = node_model_state.master_weights
@@ -1993,7 +2004,8 @@ def _inner_training_loop(
                    steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
-                   self.timers and self.timers("read-data").start()
+                   if self.args.ignore_data_skip:
+                       self.timers and self.timers("read-data").start()
                    # Reset data loading timer for skipped steps
                    _data_load_start_time = time.time()
                    continue
@@ -2930,6 +2942,15 @@ def apply_decay_param_fun(x):
        if hasattr(optimizer_cls, "_create_master_weight") and self.args.fp16_opt_level == "O2":
            optimizer_kwargs["multi_precision"] = True

+       if self.args.optim.value == "muon":
+           # Attach per-head metadata to fused QKV weights so the Muon
+           # optimizer can orthogonalise each head independently.
+           for name, param in self.model.named_parameters():
+               if "qkv_proj.weight" in name and len(param.shape) == 2:
+                   param.needs_qkv_split = True
+                   param.head_num = self.model.config.num_attention_heads
+                   param.kv_head_num = self.model.config.num_key_value_heads
+
        self.optimizer = optimizer_cls(
            learning_rate=self.lr_scheduler if lr_scheduler is None else lr_scheduler,
            apply_decay_param_fun=apply_decay_param_fun,
@@ -2947,6 +2968,7 @@ def apply_decay_param_fun(x):
    def _apply_to_optimizer(self, action):
        attributes = [
            ("_accumulators", "_moment1_acc_str"),
+           ("_accumulators", "_moment_acc_str"),  # Muon uses _moment_acc_str instead of _moment1_acc_str
            ("_accumulators", "_moment2_acc_str"),
            ("_master_weights",),
            ("_accumulators_holder",),
@@ -3070,6 +3092,18 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:

        optimizer_cls = AdamWCustom
        optimizer_kwargs.update(adam_kwargs)
+   elif args.optim == OptimizerNames.MUON:
+       from paddle.optimizer import Muon
+
+       logger.info("Creating Muon optimizer")
+       muon_kwargs = {
+           **adam_kwargs,
+           "momentum": 0.95,
+           "muon_version": 3,
+           "is_split_qkv": True,
+       }
+       optimizer_cls = Muon
+       optimizer_kwargs.update(muon_kwargs)
    else:
        raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")

@@ -4031,9 +4065,7 @@ def _save_checkpoint(self, model, metrics=None):
                global_rank, os.path.join(signal_dir, f".master_weight.done.{global_rank}")
            )

-       if self.args.save_checkpoint_format == "unified_checkpoint" and (
-           self.args.offload_optim or self.args.tensorwise_offload_optimizer
-       ):
+       if self.args.offload_optim or self.args.tensorwise_offload_optimizer:
            self._offload_optimizer()
        self.runtime_timer.stop()
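
A detail worth calling out in the `enable_bf16_opt` hunk: `DygraphShardingOptimizerV3` is None when the optional import fails, so the code builds the `isinstance` tuple by concatenation instead of listing the class directly. A standalone illustration of the pattern (the class names here are stand-ins, not the Paddle classes):

class OptV2:  # stand-in for DygraphShardingOptimizerV2
    pass

OptV3 = None  # simulates the ImportError fallback on Paddle builds without V3

# (OptV3,) is appended only when the optional import succeeded, so the tuple
# handed to isinstance never contains None (which would raise a TypeError).
classes = (OptV2,) + ((OptV3,) if OptV3 else ())
print(isinstance(OptV2(), classes))  # True, with or without the V3 class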

paddleformers/trainer/trainer_utils.py

Lines changed: 80 additions & 0 deletions
@@ -45,6 +45,13 @@
     DygraphShardingOptimizer,
     DygraphShardingOptimizerV2,
 )
+
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer_v3 import (
+        DygraphShardingOptimizerV3,
+    )
+except ImportError:
+    DygraphShardingOptimizerV3 = None
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
     GroupShardedOptimizerStage2,
@@ -498,6 +505,7 @@ class OptimizerNames(ExplicitEnum):
     ADAFACTOR = "adafactor"
     ADAMW_MINI = "adamw_mini"
     ADAMW_CUSTOM = "adamw_custom"
+    MUON = "muon"


 class ShardingOption(ExplicitEnum):
@@ -1502,6 +1510,12 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
        return

    elif DygraphShardingOptimizerV2 is not None and isinstance(inner_opt, DygraphShardingOptimizerV2):
+       # Unwrap to the innermost optimizer (e.g. Muon inside a sharding wrapper).
+       core_opt = optimizer._inner_opt
+       while hasattr(core_opt, "_inner_opt"):
+           core_opt = core_opt._inner_opt
+       is_muon_opt = type(core_opt).__name__ == "Muon"
+
        parameter_list = []
        for buffer in optimizer._comm_buffer_list:
            for param_name, grad_view in buffer._sharding_param_grad_view.items():
@@ -1515,11 +1529,77 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
                    slice_param = paddle.slice(param_buffer, axes=[0], starts=[param_begin], ends=[param_end])
                    assert slice_param.numel().item() > 0
                    slice_param.name = param_name
+                   # Preserve original shape so Muon's should_use_muon() can identify 2-D weights.
+                   if is_muon_opt and hasattr(grad_view, "_param") and grad_view._param is not None:
+                       slice_param.original_shape = grad_view._param.shape
                    parameter_list.append(slice_param)

        optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), parameter_list)
        return

+   elif DygraphShardingOptimizerV3 is not None and isinstance(inner_opt, DygraphShardingOptimizerV3):
+       # Unwrap to the innermost optimizer (e.g. Muon inside a V3 sharding wrapper).
+       core_opt = inner_opt._inner_opt
+       while hasattr(core_opt, "_inner_opt"):
+           core_opt = core_opt._inner_opt
+       is_muon_opt = type(core_opt).__name__ == "Muon"
+
+       parameter_list = []
+
+       # --- 1D params: build shard-sized slice params from FusedCommBuffer ---
+       # (same logic as the V2 branch above, using _comm_buffer_list)
+       # NOTE on naming: V3's sharded_state_dict strips the "_moment1_0" suffix
+       # via _split_state_name to recover the static name, and its
+       # param_slice_info keys are the original param names (the "slice@"
+       # prefix is added back inside sharded_state_dict), so the accumulator
+       # keys created here must be based on the original param names.
+       for buffer in optimizer._comm_buffer_list:
+           for param_name, grad_view in buffer._sharding_param_grad_view.items():
+               if param_name not in static_to_struct_mapping:
+                   continue
+               struct_name = static_to_struct_mapping[param_name]
+               if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
+                   continue
+               param_buffer = grad_view._param_buffer
+               param_begin = grad_view._param_begin
+               param_end = grad_view._param_end
+               if param_begin >= 0 and param_end > 0 and param_end > param_begin:
+                   slice_param = paddle.slice(param_buffer, axes=[0], starts=[param_begin], ends=[param_end])
+                   assert slice_param.numel().item() > 0
+                   # Use the original param name (no "slice@" prefix), consistent
+                   # with V3's _create_slice_param and the V2 branch above.
+                   slice_param.name = param_name
+                   parameter_list.append(slice_param)
+
+       # --- 2D non-MoE params: local rank's full tensors (Muon) ---
+       local_2d = optimizer._rank2params_2d.get(optimizer._sharding_rank, [])
+       for param in local_2d:
+           param_name = param.name
+           if param_name not in static_to_struct_mapping:
+               continue
+           struct_name = static_to_struct_mapping[param_name]
+           if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
+               continue
+           parameter_list.append(param)
+
+       # --- 2D MoE expert params: local rank's full tensors (Muon) ---
+       if optimizer._moe_sharding_world_size > 1:
+           moe_rank = optimizer._moe_sharding_rank
+       else:
+           moe_rank = 0
+       local_2d_moe = optimizer._rank2params_2d_moe.get(moe_rank, [])
+       for param in local_2d_moe:
+           param_name = param.name
+           if param_name not in static_to_struct_mapping:
+               continue
+           struct_name = static_to_struct_mapping[param_name]
+           if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
+               continue
+           parameter_list.append(param)
+
+       optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), parameter_list)
+       return
+
    elif isinstance(optimizer, GroupShardedOptimizerStage2):
        local_params = optimizer._segment_params()[optimizer._rank]
        for p in local_params:
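
Both new branches begin with the same loop to unwrap to the innermost optimizer before deciding whether Muon-specific handling applies. A self-contained sketch of that idea with mock wrapper classes (the real chain is built from Fleet's sharding wrappers around `paddle.optimizer.Muon`):

class Muon:  # mock stand-in for paddle.optimizer.Muon
    pass

class ShardingWrapper:  # mock; sharding wrappers expose the wrapped optimizer as _inner_opt
    def __init__(self, inner):
        self._inner_opt = inner

def innermost(opt):
    # Same loop as the V2/V3 branches above: follow _inner_opt to the core.
    while hasattr(opt, "_inner_opt"):
        opt = opt._inner_opt
    return opt

core = innermost(ShardingWrapper(ShardingWrapper(Muon())))
print(type(core).__name__ == "Muon")  # True - this is the is_muon_opt check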

paddleformers/trainer/training_args.py

Lines changed: 21 additions & 0 deletions
@@ -1528,6 +1528,17 @@ class TrainingArguments:
            "help": "Enable parameter sharding to distribute model parameters across devices, reducing memory footprint per GPU (ZeRO-style optimization)."
        },
    )
+   sharding_v3: bool = field(
+       default=False,
+       metadata={
+           "help": (
+               "Enable ShardingV3 (hybrid tensor-wise + element-wise sharding) for the Muon optimizer. "
+               "2D Muon parameters are assigned as whole tensors to ranks (no sharding gather), "
+               "while non-2D AdamW parameters use element-wise splitting for memory balance. "
+               "Requires split_param=True and the Muon optimizer; sets FLAGS_sharding_v3=1."
+           )
+       },
+   )
    sd_sharding_comm_overlap: bool = field(
        default=False,
        metadata={
@@ -2095,6 +2106,16 @@ def is_context_parallel_supported():
                strategy.hybrid_configs["sharding_configs"].split_param = True
            assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"

+           if self.sharding_v3:
+               os.environ["FLAGS_sharding_v3"] = "1"
+               assert self.split_param, "sharding_v3 requires split_param=True"
+               logger.info("ShardingV3 enabled via sharding_v3=True")
+           else:
+               os.environ["FLAGS_sharding_v3"] = "0"
+
+           if self.tensorwise_offload_optimizer:
+               os.environ["FLAGS_tensorwise_offload_optimizer"] = "1"
+
            if self.sd_release_grads:
                strategy.hybrid_configs["sharding_configs"].release_gradients = True
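
The help string summarises the V3 placement policy: 2-D (Muon) weights live whole on a single rank while other parameters are split element-wise. A toy partitioner sketching that policy under those stated assumptions (illustrative only, not the `DygraphShardingOptimizerV3` implementation):

import math

def toy_v3_placement(params, world_size):
    """params: dict mapping name -> shape. 2-D tensors are assigned whole to
    the least-loaded rank (so no sharding gather is needed at step time);
    everything else is split element-wise for memory balance."""
    load = [0] * world_size
    plan = {}
    for name, shape in params.items():
        numel = math.prod(shape)
        if len(shape) == 2:  # Muon-updated matrix: whole-tensor assignment
            rank = load.index(min(load))
            plan[name] = f"whole tensor on rank {rank}"
            load[rank] += numel
        else:  # AdamW-updated vector/scalar: element-wise split
            plan[name] = f"~{numel // world_size} elements per rank"
            for r in range(world_size):
                load[r] += numel // world_size
    return plan

print(toy_v3_placement({"qkv_proj.weight": (4096, 12288), "norm.weight": (4096,)}, world_size=2))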

paddleformers/trainer/utils/offload_optimizer.py

Lines changed: 54 additions & 0 deletions
@@ -92,6 +92,60 @@ def new_insert_sync(self, sync_var, *args, **kwargs):

    setattr(opt_type, "_insert_sync", new_insert_sync)

+   # Step 4: mock Muon._muon_update and Muon._apply_optimize
+   # Muon's _muon_update is pure Python (paddle.lerp + paddle.assign),
+   # so it bypasses the _C_ops.adamw_ patch above. We need explicit
+   # reload/offload for Muon's momentum_buffer and master_weights.
+   try:
+       from paddle.optimizer.muon import Muon
+
+       # 4a: Patch _muon_update (staticmethod) - per-param momentum offload
+       origin_muon_update = Muon._muon_update
+
+       def new_muon_update(param, grad, lr, momentum_buffer, *args, **kwargs):
+           reload(momentum_buffer)
+           ret = origin_muon_update(param, grad, lr, momentum_buffer, *args, **kwargs)
+           is_offload_opt = getattr(param, "is_offload_opt", True)
+           if is_offload_opt:
+               offload(momentum_buffer)
+           return ret
+
+       Muon._muon_update = staticmethod(new_muon_update)
+
+       # 4b: Patch _apply_optimize - reload/offload master_weights around Muon updates
+       origin_muon_apply = Muon._apply_optimize
+
+       def new_muon_apply(self, loss, startup_program, params_grads):
+           # Reload master_weights to GPU before the Muon update
+           # (needed after checkpoint restore, where master_weights may be on CPU/pinned memory)
+           mw_dict = getattr(self, "_master_weights", None)
+           if mw_dict:
+               for param, grad in params_grads:
+                   if grad is None:
+                       continue
+                   mw = mw_dict.get(param.name)
+                   if mw is not None and isinstance(mw, paddle.Tensor):
+                       reload(mw)
+
+           ret = origin_muon_apply(self, loss, startup_program, params_grads)
+
+           # Offload master_weights back to pinned CPU memory after the Muon update
+           if mw_dict:
+               for param, grad in params_grads:
+                   if grad is None:
+                       continue
+                   mw = mw_dict.get(param.name)
+                   if mw is not None and isinstance(mw, paddle.Tensor):
+                       is_offload_opt = getattr(param, "is_offload_opt", True)
+                       if is_offload_opt:
+                           offload(mw)
+           return ret
+
+       Muon._apply_optimize = new_muon_apply
+
+   except ImportError:
+       pass
+

 def hack_offload_optimizer_eb5():
     # Step 1: mock _add_accumulator
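
Both patches above share one shape: reload state onto the device, run the original update, then offload the state again if the parameter opts in. A reduced sketch of that wrapper pattern with mocked `reload`/`offload` helpers (the real ones in this file move tensors between GPU and pinned CPU memory):

import functools

def reload(tensor):   # mock: the real helper copies the tensor back to GPU
    print("reload", tensor)

def offload(tensor):  # mock: the real helper moves the tensor to pinned CPU memory
    print("offload", tensor)

def with_state_offload(update_fn):
    """Keep the optimizer state resident only for the duration of the update;
    this is the shape of the _muon_update patch above."""
    @functools.wraps(update_fn)
    def wrapper(param, state, *args, **kwargs):
        reload(state)
        ret = update_fn(param, state, *args, **kwargs)
        if getattr(param, "is_offload_opt", True):  # same opt-in attribute as above
            offload(state)
        return ret
    return wrapper

@with_state_offload
def fake_update(param, state):
    return "updated"

class FakeParam:
    is_offload_opt = True

fake_update(FakeParam(), "momentum_buffer")  # prints "reload", then "offload"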

paddleformers/trainer/utils/reshard/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from . import pp_reshard, sharding_v1, sharding_v2
+from . import pp_reshard, sharding_v1, sharding_v2, sharding_v3
 from .common import (
     SHARDING_STRATEGY_V1,
     SHARDING_STRATEGY_V2,
+    SHARDING_STRATEGY_V3,
     NodeModelState,
     all_gather_state_dict,
     convert_opt_name_to_tname,

paddleformers/trainer/utils/reshard/common.py

Lines changed: 16 additions & 0 deletions
@@ -21,6 +21,13 @@
     DygraphShardingOptimizer,
     DygraphShardingOptimizerV2,
 )
+
+try:
+    from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer_v3 import (
+        DygraphShardingOptimizerV3,
+    )
+except ImportError:
+    DygraphShardingOptimizerV3 = None
 from paddle.distributed.fleet.utils.log_util import logger

 from paddleformers.utils.tools import get_env_device
@@ -29,6 +36,7 @@

 SHARDING_STRATEGY_V1 = "ShardingV1"
 SHARDING_STRATEGY_V2 = "ShardingV2"
+SHARDING_STRATEGY_V3 = "ShardingV3"


 def is_sharding_opt(optimizer):
@@ -45,10 +53,18 @@ def check(cls):
     if check(DygraphShardingOptimizerV2):
         return True

+    if DygraphShardingOptimizerV3 is not None:
+        if check(DygraphShardingOptimizerV3):
+            return True
+
     return False


 def get_sharding_strategy(optimizer):
+    if DygraphShardingOptimizerV3 is not None:
+        tmp = unwrap_optimizer(optimizer, DygraphShardingOptimizerV3)
+        if tmp is not None:
+            return SHARDING_STRATEGY_V3
     if DygraphShardingOptimizerV2 is not None:
         tmp = unwrap_optimizer(optimizer, DygraphShardingOptimizerV2)
         if tmp is not None:
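
`get_sharding_strategy` probes V3 before V2, so a V3 run can never be misreported under another strategy. It leans on `unwrap_optimizer` returning the first wrapper of the requested type or None; presumably the behaviour is close to this sketch (a mock, not the paddleformers implementation):

def unwrap_optimizer_sketch(optimizer, target_cls):
    # Follow the _inner_opt chain; return the first instance of target_cls,
    # or None if the chain never contains one.
    opt = optimizer
    while opt is not None:
        if isinstance(opt, target_cls):
            return opt
        opt = getattr(opt, "_inner_opt", None)
    return None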
