Commit 6c32d4f

Fix the incorrect logic when loading Mixtral series model weights. (#1064)
1 parent 263a7d0 commit 6c32d4f

File tree

2 files changed: 27 additions & 16 deletions

lightllm/models/mixtral/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 2 deletions

@@ -32,8 +32,8 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight: MixtralTransfor
 
         return fused_experts_impl(
             hidden_states=hidden_states,
-            w1=layer_weight.experts.w1,
-            w2=layer_weight.experts.w2,
+            w1=layer_weight.experts.w1[0],
+            w2=layer_weight.experts.w2[0],
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
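
Why the new [0] index: with the reworked expert weights (second file below), layer_weight.experts.w1 and .w2 appear to be list-like containers rather than bare tensors, so the fused-MoE kernel must be handed the first element. A minimal sketch of that access pattern, assuming a one-element list holding the stacked expert weight tensor (the stub class and shapes are illustrative, not from the repo):

    import torch

    class ExpertsStub:
        """Illustrative stand-in for the experts container after this commit:
        each projection is assumed to be stored as a one-element list."""

        def __init__(self, n_experts=8, hidden=16, inter=32):
            # Assumed layout: w1 fuses the gate/up projections, w2 is the down projection.
            self.w1 = [torch.randn(n_experts, 2 * inter, hidden)]
            self.w2 = [torch.randn(n_experts, hidden, inter)]

    experts = ExpertsStub()
    w1 = experts.w1[0]  # the kernel expects the tensor itself, not the wrapping list
    w2 = experts.w2[0]
    print(w1.shape, w2.shape)  # torch.Size([8, 64, 16]) torch.Size([8, 16, 32])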

lightllm/models/mixtral/layer_weights/transformer_layer_weight.py

Lines changed: 25 additions & 14 deletions

@@ -1,3 +1,4 @@
+import os
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
@@ -30,10 +31,10 @@ def _init_weight_names(self):
         self.moe_gate_weight_name = f"model.layers.{self.layer_num_}.block_sparse_moe.gate.weight"
         self.moe_gate_bias_name = None
 
-    def _init_ffn(self, weights):
-        self._init_moe(weights)
+    def _init_ffn(self):
+        self._init_moe()
 
-    def _init_moe(self, weights):
+    def _init_moe(self):
         inter_size = self.network_config_["intermediate_size"]
         split_inter_size = inter_size // self.tp_world_size_
 
@@ -45,16 +46,26 @@ def _init_moe(self, weights):
             layer_num=self.layer_num_,
             name="moe_gate",
             tp_rank=0,
-            tp_size=1,  # no tensor parallelism
+            tp_world_size=1,  # no tensor parallelism
         )
 
-        load_func = FusedMoeWeightEP if enable_env_vars("ETP_MODE_ENABLED") else FusedMoeWeightTP
-        self.experts = load_func(
-            gate_proj_name="w1",
-            down_proj_name="w2",
-            up_proj_name="w3",
-            weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts",
-            n_routed_experts=self.n_routed_experts,
-            split_inter_size=split_inter_size,
-            data_type=self.data_type_,
-        )
+        moe_mode = os.getenv("MOE_MODE", "TP")
+        assert moe_mode in ["TP"], f"Unsupported moe mode: {moe_mode}"
+
+        if moe_mode == "TP":
+            self.experts = FusedMoeWeightTP(
+                gate_proj_name="w1",
+                down_proj_name="w2",
+                up_proj_name="w3",
+                e_score_correction_bias_name="",
+                weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts",
+                n_routed_experts=self.n_routed_experts,
+                split_inter_size=split_inter_size,
+                data_type=self.data_type_,
+                network_config=self.network_config_,
+                layer_num=self.layer_num_,
+                quant_cfg=self.quant_cfg,
+                num_fused_shared_experts=0,
+            )
+        else:
+            raise ValueError(f"Unsupported moe mode: {moe_mode}")
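
For reference, the loading path is now selected by the MOE_MODE environment variable instead of the old enable_env_vars("ETP_MODE_ENABLED") check, and only the tensor-parallel path is wired up. A minimal sketch of that gating, mirroring the assert/raise pair added in _init_moe (the resolve_moe_mode helper name is invented for illustration; the FusedMoeWeightTP construction is omitted):

    import os

    def resolve_moe_mode() -> str:
        # Defaults to tensor-parallel ("TP") expert weights; any other value is
        # rejected until an expert-parallel path is reintroduced.
        moe_mode = os.getenv("MOE_MODE", "TP")
        if moe_mode not in ("TP",):
            raise ValueError(f"Unsupported moe mode: {moe_mode}")
        return moe_mode

    print(resolve_moe_mode())  # -> "TP" when MOE_MODE is unset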

0 commit comments