@@ -765,7 +765,7 @@ def __init__(self, config, num_experts, expert_hidden_size, **kwargs):
765765 # [hidden_size, n_expert]
766766 self .weight = paddle .create_parameter (
767767 shape = [expert_hidden_size , num_experts ],
768- dtype = paddle . get_default_dtype () ,
768+ dtype = "float32" ,
769769 is_bias = False ,
770770 default_initializer = nn .initializer .Constant (1.0 ),
771771 )
@@ -1031,14 +1031,18 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
10311031 base_actions .pop ("embed_tokens.weight" )
10321032
10331033 # Column Linear
1034- base_actions ["layers.0.self_attn.q_proj.weight" ] = partial (fn , is_column = True )
1035- base_actions ["layers.0.self_attn.q_proj.bias" ] = partial (fn , is_column = True )
1036- # if we have enough num_key_value_heads to split, then split it.
1037- if config .num_key_value_heads % config .tensor_parallel_degree == 0 :
1038- base_actions ["layers.0.self_attn.k_proj.weight" ] = partial (fn , is_column = True )
1039- base_actions ["layers.0.self_attn.v_proj.weight" ] = partial (fn , is_column = True )
1040- base_actions ["layers.0.self_attn.k_proj.bias" ] = partial (fn , is_column = True )
1041- base_actions ["layers.0.self_attn.v_proj.bias" ] = partial (fn , is_column = True )
1034+ if config .fuse_attention_qkv :
1035+ base_actions ["layers.0.self_attn.qkv_proj.weight" ] = partial (fn , is_column = True )
1036+ base_actions ["layers.0.self_attn.qkv_proj.bias" ] = partial (fn , is_column = True )
1037+ else :
1038+ base_actions ["layers.0.self_attn.q_proj.weight" ] = partial (fn , is_column = True )
1039+ base_actions ["layers.0.self_attn.q_proj.bias" ] = partial (fn , is_column = True )
1040+ # if we have enough num_key_value_heads to split, then split it.
1041+ if config .num_key_value_heads % config .tensor_parallel_degree == 0 :
1042+ base_actions ["layers.0.self_attn.k_proj.weight" ] = partial (fn , is_column = True )
1043+ base_actions ["layers.0.self_attn.v_proj.weight" ] = partial (fn , is_column = True )
1044+ base_actions ["layers.0.self_attn.k_proj.bias" ] = partial (fn , is_column = True )
1045+ base_actions ["layers.0.self_attn.v_proj.bias" ] = partial (fn , is_column = True )
10421046
10431047 for key , action in base_actions .items ():
10441048 if "layers.0." in key :
@@ -1047,11 +1051,20 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
10471051 final_actions [key ] = action
10481052
10491053 # Add tp split for expert params.
1050- base_actions = {
1051- "layers.0.mlp.experts.0.gate_proj.weight" : partial (fn , is_column = True ),
1052- "layers.0.mlp.experts.0.down_proj.weight" : partial (fn , is_column = False ),
1053- "layers.0.mlp.experts.0.up_proj.weight" : partial (fn , is_column = True ),
1054- }
1054+ if config .fuse_attention_ffn :
1055+ base_actions = {
1056+ "layers.0.mlp.experts.0.gate_up_fused_proj.weight" : partial (
1057+ fn , is_column = True , is_naive_2fuse = True
1058+ ),
1059+ "layers.0.mlp.experts.0.down_proj.weight" : partial (fn , is_column = False ),
1060+ }
1061+ else :
1062+ # Add tp split for expert params.
1063+ base_actions = {
1064+ "layers.0.mlp.experts.0.gate_proj.weight" : partial (fn , is_column = True ),
1065+ "layers.0.mlp.experts.0.up_proj.weight" : partial (fn , is_column = True ),
1066+ "layers.0.mlp.experts.0.down_proj.weight" : partial (fn , is_column = False ),
1067+ }
10551068 for key , action in base_actions .items ():
10561069 for i in range (num_layers ):
10571070 newkey = key .replace ("layers.0." , f"layers.{ i } ." )
@@ -1060,11 +1073,19 @@ def get_tensor_parallel_split_mappings(num_layers, num_experts):
10601073 final_actions [newkey2 ] = action
10611074
10621075 # Add tp split for shared expert params.
1063- base_actions = {
1064- "layers.0.mlp.shared_expert.gate_proj.weight" : partial (fn , is_column = True ),
1065- "layers.0.mlp.shared_expert.up_proj.weight" : partial (fn , is_column = True ),
1066- "layers.0.mlp.shared_expert.down_proj.weight" : partial (fn , is_column = False ),
1067- }
1076+ if config .fuse_attention_ffn :
1077+ base_actions = {
1078+ "layers.0.mlp.shared_expert.gate_up_fused_proj.weight" : partial (
1079+ fn , is_column = True , is_naive_2fuse = True
1080+ ),
1081+ "layers.0.mlp.shared_expert.down_proj.weight" : partial (fn , is_column = False ),
1082+ }
1083+ else :
1084+ base_actions = {
1085+ "layers.0.mlp.shared_expert.gate_proj.weight" : partial (fn , is_column = True ),
1086+ "layers.0.mlp.shared_expert.up_proj.weight" : partial (fn , is_column = True ),
1087+ "layers.0.mlp.shared_expert.down_proj.weight" : partial (fn , is_column = False ),
1088+ }
10681089 for key , action in base_actions .items ():
10691090 if "layers.0." in key :
10701091 for i in range (num_layers ):
@@ -1101,24 +1122,24 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen2MoeConfig, is_fuse=False
11011122 ]
11021123
11031124 fuse_gate_up_keys = (
1104- "layers.0.mlp.gate_proj.weight" ,
1105- "layers.0.mlp.up_proj.weight" ,
1106- "layers.0.mlp.gate_up_fused_proj.weight" ,
1125+ "layers.0.mlp.experts.0.gate_proj.weight" ,
1126+ "layers.0.mlp.experts.0.up_proj.weight" ,
1127+ "layers.0.mlp.experts.0.gate_up_fused_proj.weight" ,
11071128 )
11081129 num_heads = config .num_attention_heads
11091130 num_key_value_heads = getattr (config , "num_key_value_heads" , num_heads )
11101131 fuse_attention_qkv = getattr (config , "fuse_attention_qkv" , False )
11111132 fuse_attention_ffn = getattr (config , "fuse_attention_ffn" , False )
1133+ num_experts = getattr (config , "num_experts" , 128 )
11121134
11131135 final_actions = {}
11141136 if is_fuse :
11151137             if fuse_attention_qkv :
11161138                 for i in range (config .num_hidden_layers ):
11171139                     for fuse_keys in fuse_qkv_keys :
11181140                         keys = tuple ([key .replace ("layers.0." , f"layers.{ i } ." ) for key in fuse_keys ])
11191141                         final_actions [keys ] = partial (
11201142                             fn , is_qkv = True , num_heads = num_heads , num_key_value_heads = num_key_value_heads
11211143                         )
11221144             if fuse_attention_ffn :
11231145                 for i in range (config .num_hidden_layers ):
1124-                    keys = tuple ([key .replace ("layers.0." , f"layers.{ i } ." ) for key in fuse_gate_up_keys ])
1146+                    keys = [key .replace ("layers.0." , f"layers.{ i } ." ) for key in fuse_gate_up_keys ]
1147+                    for j in range (num_experts ):
1148+                        experts_keys = tuple ([key .replace ("experts.0." , f"experts.{ j } ." ) for key in keys ])
1149+                        final_actions [experts_keys ] = fn
@@ -1137,8 +1158,10 @@ def _get_fuse_or_split_param_mappings(cls, config: Qwen2MoeConfig, is_fuse=False
11371158 )
11381159 if not fuse_attention_ffn :
11391160 for i in range (config .num_hidden_layers ):
1140- keys = tuple ([key .replace ("layers.0." , f"layers.{ i } ." ) for key in fuse_gate_up_keys ])
1141- final_actions [keys ] = partial (fn , split_nums = 2 )
1161+ keys = [key .replace ("layers.0." , f"layers.{ i } ." ) for key in fuse_gate_up_keys ]
1162+ for j in range (num_experts ):
1163+ experts_keys = tuple ([key .replace ("experts.0." , f"experts.{ j } ." ) for key in keys ])
1164+ final_actions [experts_keys ] = partial (fn , split_nums = 2 )
11421165 return final_actions
11431166
11441167 def _init_weights (self , layer ):
0 commit comments