Skip to content

Commit 4ee9791

Browse files
author
wangzaijun
committed
fix
1 parent 58a7cb2 commit 4ee9791

File tree

1 file changed

+35
-14
lines changed

1 file changed

+35
-14
lines changed

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -209,20 +209,41 @@ def _init_qkvo(self):
209209
)
210210

211211
def _load_mlp(self, mlp_prefix):
212-
self.gate_up_proj = MultiROWMMWeight(
213-
weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
214-
data_type=self.data_type_,
215-
quant_cfg=self.quant_cfg,
216-
layer_num=self.layer_num_,
217-
name="gate_up_proj",
218-
)
219-
self.down_proj = COLMMWeight(
220-
weight_name=f"{mlp_prefix}.down_proj.weight",
221-
data_type=self.data_type_,
222-
quant_cfg=self.quant_cfg,
223-
layer_num=self.layer_num_,
224-
name="down_proj",
225-
)
212+
moe_mode = os.getenv("MOE_MODE", "TP")
213+
if self.is_moe and moe_mode == "EP":
214+
self.gate_up_proj = MultiROWMMWeight(
215+
weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
216+
data_type=self.data_type_,
217+
quant_cfg=self.quant_cfg,
218+
layer_num=self.layer_num_,
219+
name="gate_up_proj",
220+
tp_rank=0,
221+
tp_world_size=1,
222+
)
223+
self.down_proj = COLMMWeight(
224+
weight_name=f"{mlp_prefix}.down_proj.weight",
225+
data_type=self.data_type_,
226+
quant_cfg=self.quant_cfg,
227+
layer_num=self.layer_num_,
228+
name="down_proj",
229+
tp_rank=0,
230+
tp_world_size=1,
231+
)
232+
else:
233+
self.gate_up_proj = MultiROWMMWeight(
234+
weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"],
235+
data_type=self.data_type_,
236+
quant_cfg=self.quant_cfg,
237+
layer_num=self.layer_num_,
238+
name="gate_up_proj",
239+
)
240+
self.down_proj = COLMMWeight(
241+
weight_name=f"{mlp_prefix}.down_proj.weight",
242+
data_type=self.data_type_,
243+
quant_cfg=self.quant_cfg,
244+
layer_num=self.layer_num_,
245+
name="down_proj",
246+
)
226247

227248
def _init_moe(self):
228249
moe_intermediate_size = self.network_config_["moe_intermediate_size"]

0 commit comments

Comments
 (0)