+import os
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.envs_utils import enable_env_vars
 from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
@@ -30,10 +31,10 @@ def _init_weight_names(self):
         self.moe_gate_weight_name = f"model.layers.{self.layer_num_}.block_sparse_moe.gate.weight"
         self.moe_gate_bias_name = None
 
-    def _init_ffn(self, weights):
-        self._init_moe(weights)
+    def _init_ffn(self):
+        self._init_moe()
 
-    def _init_moe(self, weights):
+    def _init_moe(self):
         inter_size = self.network_config_["intermediate_size"]
         split_inter_size = inter_size // self.tp_world_size_
 
@@ -45,16 +46,26 @@ def _init_moe(self, weights):
             layer_num=self.layer_num_,
             name="moe_gate",
             tp_rank=0,
-            tp_size=1,  # no tensor parallelism
+            tp_world_size=1,  # no tensor parallelism
         )
 
-        load_func = FusedMoeWeightEP if enable_env_vars("ETP_MODE_ENABLED") else FusedMoeWeightTP
-        self.experts = load_func(
-            gate_proj_name="w1",
-            down_proj_name="w2",
-            up_proj_name="w3",
-            weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts",
-            n_routed_experts=self.n_routed_experts,
-            split_inter_size=split_inter_size,
-            data_type=self.data_type_,
-        )
+        moe_mode = os.getenv("MOE_MODE", "TP")
+        assert moe_mode in ["TP"], f"Unsupported moe mode: {moe_mode}"
+
+        if moe_mode == "TP":
+            self.experts = FusedMoeWeightTP(
+                gate_proj_name="w1",
+                down_proj_name="w2",
+                up_proj_name="w3",
+                e_score_correction_bias_name="",
+                weight_prefix=f"model.layers.{self.layer_num_}.block_sparse_moe.experts",
+                n_routed_experts=self.n_routed_experts,
+                split_inter_size=split_inter_size,
+                data_type=self.data_type_,
+                network_config=self.network_config_,
+                layer_num=self.layer_num_,
+                quant_cfg=self.quant_cfg,
+                num_fused_shared_experts=0,
+            )
+        else:
+            raise ValueError(f"Unsupported moe mode: {moe_mode}")
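
For reviewers who want to try the new switch outside the server, here is a minimal sketch of the mode selection and the tensor-parallel split introduced above, assuming only the behavior visible in this diff (the helper name pick_moe_mode and the numeric values are illustrative, not part of the change):

import os

# Mirrors the selection logic added in _init_moe: MOE_MODE defaults to
# "TP", and any other value fails fast at the assert.
def pick_moe_mode() -> str:  # hypothetical helper, not part of this commit
    moe_mode = os.getenv("MOE_MODE", "TP")
    assert moe_mode in ["TP"], f"Unsupported moe mode: {moe_mode}"
    return moe_mode

# Mirrors split_inter_size = inter_size // self.tp_world_size_: each
# tensor-parallel rank owns an equal slice of the MoE intermediate size.
inter_size = 14336   # Mixtral-style intermediate_size, illustrative only
tp_world_size = 2    # illustrative
split_inter_size = inter_size // tp_world_size

print(pick_moe_mode(), split_inter_size)  # -> TP 7168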