@@ -20,6 +20,8 @@ def create_tp_moe_wegiht_obj(
2020 network_config : Dict [str , Any ],
2121 layer_num : int ,
2222 quant_cfg : Quantcfg = None ,
23+ fused_gate_up : bool = False ,
24+ gate_up_proj_name : str = None ,
2325) -> Union ["FusedMoeWeightTP" , "FusedAWQMARLINMoeWeightTP" ]:
2426 quant_method = quant_cfg .get_quant_method (layer_num , "fused_moe" )
2527 if quant_method is not None and quant_method .method_name == "awq_marlin" :
@@ -36,6 +38,8 @@ def create_tp_moe_wegiht_obj(
3638 network_config = network_config ,
3739 layer_num = layer_num ,
3840 quant_cfg = quant_cfg ,
41+ fused_gate_up = fused_gate_up ,
42+ gate_up_proj_name = gate_up_proj_name ,
3943 )
4044 else :
4145 return FusedMoeWeightTP (
@@ -51,6 +55,8 @@ def create_tp_moe_wegiht_obj(
5155 network_config = network_config ,
5256 layer_num = layer_num ,
5357 quant_cfg = quant_cfg ,
58+ fused_gate_up = fused_gate_up ,
59+ gate_up_proj_name = gate_up_proj_name ,
5460 )
5561
5662
@@ -69,6 +75,8 @@ def __init__(
6975 network_config : Dict [str , Any ],
7076 layer_num : int ,
7177 quant_cfg : Quantcfg = None ,
78+ fused_gate_up : bool = False ,
79+ gate_up_proj_name : str = None ,
7280 ) -> None :
7381 super ().__init__ ()
7482 self .quant_method = quant_cfg .get_quant_method (layer_num , "fused_moe" )
@@ -79,6 +87,8 @@ def __init__(
7987 self .w1_weight_name = gate_proj_name
8088 self .w2_weight_name = down_proj_name
8189 self .w3_weight_name = up_proj_name
90+ self .fused_gate_up = fused_gate_up
91+ self .gate_up_proj_name = gate_up_proj_name
8292
8393 self .e_score_correction_bias_name = e_score_correction_bias_name
8494 self .weight_prefix = weight_prefix
@@ -181,8 +191,6 @@ def _fuse(self):
181191
182192 inter_shape , hidden_size = self .w2_list [0 ].shape [0 ], self .w2_list [0 ].shape [1 ]
183193 w2 = torch ._utils ._flatten_dense_tensors (self .w2_list ).view (len (self .w2_list ), inter_shape , hidden_size )
184- if self .fused_gate_up :
185- w2 = w2 .transpose (1 , 2 ).contiguous ()
186194 if not self .quantized_weight and self .quant_method is not None :
187195 qw1 , qw1_scale , qw1_zero_point = self .quant_method .quantize (w1 )
188196 qw2 , qw2_scale , qw2_zero_point = self .quant_method .quantize (w2 )
@@ -228,51 +236,49 @@ def _fuse_weight_scale(self):
228236 delattr (self , "experts_up_proj_scales" )
229237 delattr (self , "experts_gate_proj_scales" )
230238
239+ def fused_gate_up_weights_load (self , weights ):
240+ key_gate_up = f"{ self .weight_prefix } .{ self .gate_up_proj_name } " # ...experts.gate_up_proj
241+ key_down = f"{ self .weight_prefix } .{ self .w2_weight_name } "
242+ if (key_gate_up not in weights ) or (key_down not in weights ):
243+ return
244+ gate_up = weights [key_gate_up ]
245+ down = weights [key_down ]
246+ _ , _ , I_double = gate_up .shape
247+ I_single = I_double // 2
248+ start = self .tp_rank_ * self .split_inter_size
249+ end = (self .tp_rank_ + 1 ) * self .split_inter_size
250+
251+ for i_experts in range (self .n_routed_experts ):
252+ gate_up_2d = gate_up [i_experts ]
253+ self .experts_gate_projs [i_experts ] = gate_up_2d [:, :I_single ][:, start :end ].t ().contiguous ()
254+ self .experts_up_projs [i_experts ] = gate_up_2d [:, I_single :][:, start :end ].t ().contiguous ()
255+ self .w2_list [i_experts ] = down [i_experts ].t ()[:, start :end ].contiguous ()
256+
257+ def normal_weights_load (self , weights ):
258+ for i_experts in range (self .n_routed_experts ):
259+
260+ w1_weight = f"{ self .weight_prefix } .{ i_experts } .{ self .w1_weight_name } .weight"
261+ w2_weight = f"{ self .weight_prefix } .{ i_experts } .{ self .w2_weight_name } .weight"
262+ w3_weight = f"{ self .weight_prefix } .{ i_experts } .{ self .w3_weight_name } .weight"
263+
264+ start = self .tp_rank_ * self .split_inter_size
265+ end = (self .tp_rank_ + 1 ) * self .split_inter_size
266+
267+ if w1_weight in weights :
268+ self .experts_gate_projs [i_experts ] = weights [w1_weight ][start :end , :]
269+ if w3_weight in weights :
270+ self .experts_up_projs [i_experts ] = weights [w3_weight ][start :end , :]
271+ if w2_weight in weights :
272+ self .w2_list [i_experts ] = weights [w2_weight ][start :end ]
273+
231274 def load_hf_weights (self , weights ):
232275 if self .e_score_correction_bias_name in weights :
233276 self .e_score_correction_bias = self ._cuda (weights [self .e_score_correction_bias_name ])
234- self .fused_gate_up = self .w3_weight_name is None # gate_up: [E,H,2I] down: [E,I,H]
235- key_gateup_3d = f"{ self .weight_prefix } .{ self .w1_weight_name } " # ...experts.gate_up_proj
236- key_down_3d = f"{ self .weight_prefix } .{ self .w2_weight_name } "
237-
238- if self .fused_gate_up and (key_gateup_3d in weights ) and (key_down_3d in weights ):
239- gate_up_3d = weights [key_gateup_3d ]
240- down_3d = weights [key_down_3d ]
241- assert gate_up_3d .dim () == 3 and down_3d .dim () == 3
242-
243- E_ckpt , H_ , twoE = gate_up_3d .shape
244- assert E_ckpt == self .n_routed_experts , f"experts mismatch: ckpt { E_ckpt } vs cfg { self .n_routed_experts } "
245- Eint_total = twoE // 2
246- start , end = self .tp_rank_ * self .split_inter_size , (self .tp_rank_ + 1 ) * self .split_inter_size
247- assert end <= Eint_total , "TP split exceeds total expert-intermediate size"
248-
249- for i in range (self .n_routed_experts ):
250- gu2d = gate_up_3d [i ]
251- gate2d = gu2d [:, :Eint_total ][:, start :end ].t ().contiguous ()
252- up2d = gu2d [:, Eint_total :][:, start :end ].t ().contiguous ()
253- self .experts_gate_projs [i ] = gate2d
254- self .experts_up_projs [i ] = up2d
255-
256- self .w2_list [i ] = down_3d [i ][start :end , :].contiguous ()
277+ if self .fused_gate_up : # gate_up: [E,H,2I] down: [E,I,H]
278+ self .fused_gate_up_weights_load (weights )
257279 else :
258- for i_experts in range (self .n_routed_experts ):
259- w1_weight = f"{ self .weight_prefix } .{ i_experts } .{ self .w1_weight_name } .weight"
260- w2_weight = f"{ self .weight_prefix } .{ i_experts } .{ self .w2_weight_name } .weight"
261- w3_weight = f"{ self .weight_prefix } .{ i_experts } .{ self .w3_weight_name } .weight"
262-
263- if w1_weight in weights :
264- self .experts_gate_projs [i_experts ] = weights [w1_weight ][
265- self .split_inter_size * self .tp_rank_ : self .split_inter_size * (self .tp_rank_ + 1 ), :
266- ]
267- if w3_weight in weights :
268- self .experts_up_projs [i_experts ] = weights [w3_weight ][
269- self .split_inter_size * self .tp_rank_ : self .split_inter_size * (self .tp_rank_ + 1 ), :
270- ]
271-
272- if w2_weight in weights :
273- self .w2_list [i_experts ] = weights [w2_weight ][
274- :, self .split_inter_size * self .tp_rank_ : self .split_inter_size * (self .tp_rank_ + 1 )
275- ]
280+ self .normal_weights_load (weights )
281+
276282 if self .quant_method is not None :
277283 if self .fused_gate_up :
278284 raise ValueError ("qwen3_vl_moe not support quant now" )
@@ -342,6 +348,8 @@ def __init__(
342348 network_config : Dict [str , Any ],
343349 layer_num : int ,
344350 quant_cfg : Quantcfg = None ,
351+ fused_gate_up : bool = False ,
352+ gate_up_proj_name : str = None ,
345353 ) -> None :
346354 super ().__init__ ()
347355 self .quant_method = quant_cfg .get_quant_method (layer_num , "fused_moe" )
0 commit comments