@@ -209,20 +209,41 @@ def _init_qkvo(self):
209209 )
210210
211211 def _load_mlp (self , mlp_prefix ):
212- self .gate_up_proj = MultiROWMMWeight (
213- weight_names = [f"{ mlp_prefix } .gate_proj.weight" , f"{ mlp_prefix } .up_proj.weight" ],
214- data_type = self .data_type_ ,
215- quant_cfg = self .quant_cfg ,
216- layer_num = self .layer_num_ ,
217- name = "gate_up_proj" ,
218- )
219- self .down_proj = COLMMWeight (
220- weight_name = f"{ mlp_prefix } .down_proj.weight" ,
221- data_type = self .data_type_ ,
222- quant_cfg = self .quant_cfg ,
223- layer_num = self .layer_num_ ,
224- name = "down_proj" ,
225- )
212+ moe_mode = os .getenv ("MOE_MODE" , "TP" )
213+ if self .is_moe and moe_mode == "EP" :
214+ self .gate_up_proj = MultiROWMMWeight (
215+ weight_names = [f"{ mlp_prefix } .gate_proj.weight" , f"{ mlp_prefix } .up_proj.weight" ],
216+ data_type = self .data_type_ ,
217+ quant_cfg = self .quant_cfg ,
218+ layer_num = self .layer_num_ ,
219+ name = "gate_up_proj" ,
220+ tp_rank = 0 ,
221+ tp_world_size = 1 ,
222+ )
223+ self .down_proj = COLMMWeight (
224+ weight_name = f"{ mlp_prefix } .down_proj.weight" ,
225+ data_type = self .data_type_ ,
226+ quant_cfg = self .quant_cfg ,
227+ layer_num = self .layer_num_ ,
228+ name = "down_proj" ,
229+ tp_rank = 0 ,
230+ tp_world_size = 1 ,
231+ )
232+ else :
233+ self .gate_up_proj = MultiROWMMWeight (
234+ weight_names = [f"{ mlp_prefix } .gate_proj.weight" , f"{ mlp_prefix } .up_proj.weight" ],
235+ data_type = self .data_type_ ,
236+ quant_cfg = self .quant_cfg ,
237+ layer_num = self .layer_num_ ,
238+ name = "gate_up_proj" ,
239+ )
240+ self .down_proj = COLMMWeight (
241+ weight_name = f"{ mlp_prefix } .down_proj.weight" ,
242+ data_type = self .data_type_ ,
243+ quant_cfg = self .quant_cfg ,
244+ layer_num = self .layer_num_ ,
245+ name = "down_proj" ,
246+ )
226247
227248 def _init_moe (self ):
228249 moe_intermediate_size = self .network_config_ ["moe_intermediate_size" ]
0 commit comments