
Commit a425470

[RL] fix qwen load when fuse is enabled and modify gate precision to fp32 (#10842)
1 parent 6d6078e commit a425470

2 files changed: +90 −46 lines changed

paddlenlp/transformers/qwen2_moe/modeling.py

Lines changed: 52 additions & 29 deletions
@@ -765,7 +765,7 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
         # [hidden_size, n_expert]
         self.weight = paddle.create_parameter(
             shape=[expert_hidden_size, num_experts],
-            dtype=paddle.get_default_dtype(),
+            dtype="float32",
             is_bias=False,
             default_initializer=nn.initializer.Constant(1.0),
         )
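A plausible motivation for this dtype change, shown as a standalone sketch (all shapes and values below are illustrative, not taken from the model): keeping the router gate weight in float32 lets expert-selection logits be computed in full precision even when the rest of the model runs in bf16, where near-tied logits can flip expert choices after rounding.

import paddle
import paddle.nn.functional as F

# Illustrative shapes only: [tokens, hidden] activations in bf16,
# [hidden, n_expert] gate weight kept in float32 as in the change above.
hidden_states = paddle.randn([4, 2048]).cast("bfloat16")
gate_weight = paddle.randn([2048, 60])  # float32 under the default dtype

# Cast the activations up rather than the weight down, so routing
# logits are accumulated in full precision.
logits = paddle.matmul(hidden_states.cast("float32"), gate_weight)
routing_probs = F.softmax(logits, axis=-1)  # float32 routing probabilities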
@@ -1031,14 +1031,18 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
             base_actions.pop("embed_tokens.weight")

             # Column Linear
-            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
-            base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
-            # if we have enough num_key_value_heads to split, then split it.
-            if config.num_key_value_heads % config.tensor_parallel_degree == 0:
-                base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
-                base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
-                base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
-                base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)
+            if config.fuse_attention_qkv:
+                base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(fn, is_column=True)
+            else:
+                base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
+                # if we have enough num_key_value_heads to split, then split it.
+                if config.num_key_value_heads % config.tensor_parallel_degree == 0:
+                    base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
+                    base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
+                    base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
+                    base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

             for key, action in base_actions.items():
                 if "layers.0." in key:
@@ -1047,11 +1051,20 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
                     final_actions[key] = action

             # Add tp split for expert params.
-            base_actions = {
-                "layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
-                "layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True),
-            }
+            if config.fuse_attention_ffn:
+                base_actions = {
+                    "layers.0.mlp.experts.0.gate_up_fused_proj.weight": partial(
+                        fn, is_column=True, is_naive_2fuse=True
+                    ),
+                    "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
+                }
+            else:
+                # Add tp split for expert params.
+                base_actions = {
+                    "layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True),
+                    "layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True),
+                    "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
+                }
             for key, action in base_actions.items():
                 for i in range(num_layers):
                     newkey = key.replace("layers.0.", f"layers.{i}.")
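A hedged sketch of what the is_naive_2fuse=True flag plausibly implies for the fused expert weight (sizes invented): gate_proj and up_proj are concatenated along the output dimension, so a correct column split must shard each half separately and re-join the per-rank pieces, rather than cutting the concatenated matrix in the middle.

import paddle

hidden, intermediate, tp_degree = 2048, 5632, 2
gate_up = paddle.randn([hidden, 2 * intermediate])  # [gate | up] fused

# A naive middle cut would give rank 0 all of gate and rank 1 all of up.
# Instead, split each half across ranks and re-concatenate per rank.
gate, up = paddle.split(gate_up, 2, axis=-1)
gate_shards = paddle.split(gate, tp_degree, axis=-1)
up_shards = paddle.split(up, tp_degree, axis=-1)
rank_weights = [paddle.concat([g, u], axis=-1) for g, u in zip(gate_shards, up_shards)]
print([list(w.shape) for w in rank_weights])  # [[2048, 5632], [2048, 5632]]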
@@ -1060,11 +1073,19 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
                     final_actions[newkey2] = action

             # Add tp split for shared expert params.
-            base_actions = {
-                "layers.0.mlp.shared_expert.gate_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.shared_expert.up_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.shared_expert.down_proj.weight": partial(fn, is_column=False),
-            }
+            if config.fuse_attention_ffn:
+                base_actions = {
+                    "layers.0.mlp.shared_expert.gate_up_fused_proj.weight": partial(
+                        fn, is_column=True, is_naive_2fuse=True
+                    ),
+                    "layers.0.mlp.shared_expert.down_proj.weight": partial(fn, is_column=False),
+                }
+            else:
+                base_actions = {
+                    "layers.0.mlp.shared_expert.gate_proj.weight": partial(fn, is_column=True),
+                    "layers.0.mlp.shared_expert.up_proj.weight": partial(fn, is_column=True),
+                    "layers.0.mlp.shared_expert.down_proj.weight": partial(fn, is_column=False),
+                }
             for key, action in base_actions.items():
                 if "layers.0." in key:
                     for i in range(num_layers):
@@ -1101,24 +1122,24 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen2MoeConfig, is_fuse=False
         ]

         fuse_gate_up_keys = (
-            "layers.0.mlp.gate_proj.weight",
-            "layers.0.mlp.up_proj.weight",
-            "layers.0.mlp.gate_up_fused_proj.weight",
+            "layers.0.mlp.experts.0.gate_proj.weight",
+            "layers.0.mlp.experts.0.up_proj.weight",
+            "layers.0.mlp.experts.0.gate_up_fused_proj.weight",
         )
         num_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
         fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
         fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
+        num_experts = getattr(config, "num_experts", 128)

         final_actions = {}
         if is_fuse:
             if fuse_attention_qkv:
                 for i in range(config.num_hidden_layers):
-                    for fuse_keys in fuse_qkv_keys:
-                        keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys])
-                        final_actions[keys] = partial(
-                            fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
-                        )
+                    keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]
+                    for j in range(num_experts):
+                        experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys])
+                        final_actions[experts_keys] = fn
             if fuse_attention_ffn:
                 for i in range(config.num_hidden_layers):
                     keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
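The core of the fix in this hunk, restated as a self-contained sketch (layer and expert counts invented): MoE checkpoints carry one gate/up pair per expert, so the fuse-mapping keys must be fanned out over experts as well as layers, whereas the old keys only substituted the layer index.

num_hidden_layers, num_experts = 2, 4  # invented sizes

fuse_gate_up_keys = (
    "layers.0.mlp.experts.0.gate_proj.weight",
    "layers.0.mlp.experts.0.up_proj.weight",
    "layers.0.mlp.experts.0.gate_up_fused_proj.weight",
)

final_actions = {}
for i in range(num_hidden_layers):
    keys = [k.replace("layers.0.", f"layers.{i}.") for k in fuse_gate_up_keys]
    for j in range(num_experts):
        experts_keys = tuple(k.replace("experts.0.", f"experts.{j}.") for k in keys)
        final_actions[experts_keys] = "fuse"  # stand-in for the real fn

print(len(final_actions))  # 8 entries: one per (layer, expert) pair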
@@ -1137,8 +1158,10 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen2MoeConfig, is_fuse=False
                     )
         if not fuse_attention_ffn:
             for i in range(config.num_hidden_layers):
-                keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
-                final_actions[keys] = partial(fn, split_nums=2)
+                keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]
+                for j in range(num_experts):
+                    experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys])
+                    final_actions[experts_keys] = partial(fn, split_nums=2)
         return final_actions

     def _init_weights(self, layer):
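And the reverse direction guarded by if not fuse_attention_ffn above, as a hedged sketch (shapes invented): loading a checkpoint saved with fused gate_up weights into a model built with separate projections means each fused per-expert matrix has to be cut back into two, which is what split_nums=2 requests from the mapping function.

import paddle

hidden, intermediate = 2048, 768  # invented sizes
gate_up_ckpt = paddle.randn([hidden, 2 * intermediate])  # fused checkpoint weight

# split_nums=2: one fused tensor becomes the gate_proj / up_proj pair.
gate_proj, up_proj = paddle.split(gate_up_ckpt, 2, axis=-1)
print(list(gate_proj.shape), list(up_proj.shape))  # [2048, 768] [2048, 768]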

paddlenlp/transformers/qwen3_moe/modeling.py

Lines changed: 38 additions & 17 deletions
@@ -414,11 +414,18 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
             base_actions.pop("embed_tokens.weight")

             # Column Linear
-            base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
-            # if we have enough num_key_value_heads to split, then split it.
-            if config.num_key_value_heads % config.tensor_parallel_degree == 0:
-                base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
-                base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
+            if config.fuse_attention_qkv:
+                base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(fn, is_column=True)
+            else:
+                base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True)
+                base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True)
+                # if we have enough num_key_value_heads to split, then split it.
+                if config.num_key_value_heads % config.tensor_parallel_degree == 0:
+                    base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True)
+                    base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True)
+                    base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True)
+                    base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True)

             for key, action in base_actions.items():
                 if "layers.0." in key:
@@ -427,11 +434,20 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
                     final_actions[key] = action

             # Add tp split for expert params.
-            base_actions = {
-                "layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True),
-                "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
-                "layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True),
-            }
+            if config.fuse_attention_ffn:
+                base_actions = {
+                    "layers.0.mlp.experts.0.gate_up_fused_proj.weight": partial(
+                        fn, is_column=True, is_naive_2fuse=True
+                    ),
+                    "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
+                }
+            else:
+                # Add tp split for expert params.
+                base_actions = {
+                    "layers.0.mlp.experts.0.gate_proj.weight": partial(fn, is_column=True),
+                    "layers.0.mlp.experts.0.up_proj.weight": partial(fn, is_column=True),
+                    "layers.0.mlp.experts.0.down_proj.weight": partial(fn, is_column=False),
+                }
             for key, action in base_actions.items():
                 for i in range(num_layers):
                     newkey = key.replace("layers.0.", f"layers.{i}.")
@@ -471,14 +487,15 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen3MoeConfig, is_fuse=False
         ]

         fuse_gate_up_keys = (
-            "layers.0.mlp.gate_proj.weight",
-            "layers.0.mlp.up_proj.weight",
-            "layers.0.mlp.gate_up_fused_proj.weight",
+            "layers.0.mlp.experts.0.gate_proj.weight",
+            "layers.0.mlp.experts.0.up_proj.weight",
+            "layers.0.mlp.experts.0.gate_up_fused_proj.weight",
         )
         num_heads = config.num_attention_heads
         num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
         fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
         fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
+        num_experts = getattr(config, "num_experts", 128)

         final_actions = {}
         if is_fuse:
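A small but load-bearing detail in this hunk: num_experts is read with a hard-coded fallback of 128, so the per-expert fan-out still works when a config object lacks the attribute; the fallback is an assumption baked into the mapping code, not a value read from any checkpoint. The class below is a hypothetical stand-in, not the real Qwen3MoeConfig:

class DummyConfig:  # hypothetical stand-in for Qwen3MoeConfig
    num_experts = 64

print(getattr(DummyConfig(), "num_experts", 128))  # 64: the config value wins
print(getattr(object(), "num_experts", 128))       # 128: the fallback default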
@@ -491,8 +508,10 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen3MoeConfig, is_fuse=False
                     )
         if fuse_attention_ffn:
             for i in range(config.num_hidden_layers):
-                keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
-                final_actions[keys] = fn
+                keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]
+                for j in range(num_experts):
+                    experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys])
+                    final_actions[experts_keys] = fn
         else:
             if not fuse_attention_qkv:
                 for i in range(config.num_hidden_layers):
@@ -507,8 +526,10 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen3MoeConfig, is_fuse=False
                     )
         if not fuse_attention_ffn:
             for i in range(config.num_hidden_layers):
-                keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
-                final_actions[keys] = partial(fn, split_nums=2)
+                keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]
+                for j in range(num_experts):
+                    experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys])
+                    final_actions[experts_keys] = partial(fn, split_nums=2)
         return final_actions

     def _init_weights(self, layer):
