 from lightllm.utils.dist_utils import get_world_size, get_rank
 import threading
 from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+import os
 
 try:
     HAS_VLLM = True
@@ -28,6 +29,8 @@ def __init__(
         self.tp_rank_ = get_rank()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
+        self.expert_gate_up_proj_etp = None
+        self.expert_down_proj_etp = None
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.lock = threading.Lock()
@@ -36,9 +39,10 @@ def set_quant_method(self, quant_method):
         if isinstance(quant_method, vLLMFP8w8a8QuantizationMethod):
             self.quant_method = quant_method
             if self.quant_method is not None:
-                self.quant_method.is_moe = True
+                self.quant_method.is_moe = True
 
     def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=input_tensor,
             router_logits=router_logits,
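For readers unfamiliar with the routing step: FusedMoE.select_experts here comes from vLLM. A minimal sketch of what plain top-k expert selection computes, ignoring the grouped-top-k parameters and not reproducing vLLM's actual implementation, looks like this:

import torch

def naive_select_experts(router_logits, top_k, renormalize=True):
    # Dense softmax over all experts, then keep the top_k per token.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Rescale the kept weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids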
@@ -95,27 +99,90 @@ def _fuse(self):
             delattr(self, "experts_up_projs")
             delattr(self, "experts_gate_projs")
 
+
+    def _load_hf_weights_etp(self, weights):
+        world_size_ = get_world_size()
+        assert self.n_routed_experts % world_size_ == 0
+        n_expert_ep = self.n_routed_experts // world_size_
+
+        # TP to EP: each rank keeps n_expert_ep whole experts instead of a
+        # slice of every expert.
+        expert_gate_up_proj_last = None
+        expert_down_proj_last = None
+
+        for i_experts_ep in range(n_expert_ep):
+            expert_up_proj = None
+            expert_gate_proj = None
+            expert_gate_up_proj = None
+            expert_down_proj = None
+            i_experts = i_experts_ep + n_expert_ep * self.tp_rank_
+
+            if f"{self.weight_prefix}.{i_experts}.up_proj.weight" in weights:
+                expert_up_proj = weights[f"{self.weight_prefix}.{i_experts}.up_proj.weight"]
+
+            if f"{self.weight_prefix}.{i_experts}.gate_proj.weight" in weights:
+                expert_gate_proj = weights[f"{self.weight_prefix}.{i_experts}.gate_proj.weight"]
+
+            if expert_gate_proj is not None and expert_up_proj is not None:
+                expert_gate_up_proj = torch.cat([expert_gate_proj, expert_up_proj], dim=0)
+                # Note: experts_gate_projs is reused here as scratch space for
+                # the fused gate_up tensors, indexed by the local EP slot.
+                self.experts_gate_projs[i_experts_ep] = expert_gate_up_proj
+                expert_gate_up_proj_last = expert_gate_up_proj
+
+            if f"{self.weight_prefix}.{i_experts}.down_proj.weight" in weights:
+                expert_down_proj = weights[f"{self.weight_prefix}.{i_experts}.down_proj.weight"]
+                # Note: experts_up_projs is likewise reused as scratch space
+                # for the down_proj tensors.
+                self.experts_up_projs[i_experts_ep] = expert_down_proj
+                expert_down_proj_last = expert_down_proj
+
+        with self.lock:
+            if expert_gate_up_proj_last is not None:
+                # Pack into one contiguous buffer; weights may arrive across
+                # several calls, so slots for still-missing experts stay zero.
+                if self.expert_gate_up_proj_etp is None:
+                    self.expert_gate_up_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_gate_up_proj_last.shape,
+                        dtype=expert_gate_up_proj_last.dtype,
+                    ).cuda(self.tp_rank_)
+
+                for i_experts_ep in range(n_expert_ep):
+                    if self.experts_gate_projs[i_experts_ep] is not None:
+                        self.expert_gate_up_proj_etp[i_experts_ep, :] = self.experts_gate_projs[i_experts_ep]
+
+            if expert_down_proj_last is not None:
+                # Same packing for the down projections.
+                if self.expert_down_proj_etp is None:
+                    self.expert_down_proj_etp = torch.zeros(
+                        (n_expert_ep,) + expert_down_proj_last.shape,
+                        dtype=expert_down_proj_last.dtype,
+                    ).cuda(self.tp_rank_)
+
+                for i_experts_ep in range(n_expert_ep):
+                    if self.experts_up_projs[i_experts_ep] is not None:
+                        self.expert_down_proj_etp[i_experts_ep, :] = self.experts_up_projs[i_experts_ep]
+
+
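To make the TP-to-EP index mapping above concrete, here is a small standalone sketch; the expert count and world size are illustrative values, not taken from any model config:

# Hypothetical sizes for illustration only.
n_routed_experts = 64
world_size = 8
n_expert_ep = n_routed_experts // world_size  # 8 whole experts per rank

for tp_rank in range(world_size):
    # Same arithmetic as i_experts = i_experts_ep + n_expert_ep * tp_rank_.
    owned = [i_ep + n_expert_ep * tp_rank for i_ep in range(n_expert_ep)]
    print(tp_rank, owned)  # rank 0 -> experts 0..7, rank 1 -> 8..15, ...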
     def load_hf_weights(self, weights):
-        for i_experts in range(self.n_routed_experts):
-            w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
-            w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
-            w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
-
-            if w1_weight in weights:
-                self.experts_gate_projs[i_experts] = weights[w1_weight][
-                    self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                ]
-            if w3_weight in weights:
-                self.experts_up_projs[i_experts] = weights[w3_weight][
-                    self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                ]
-
-            if w2_weight in weights:
-                self.w2_list[i_experts] = weights[w2_weight][
-                    :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-                ]
-
-        self._fuse()
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            self._load_hf_weights_etp(weights)
+        else:
+            for i_experts in range(self.n_routed_experts):
+                w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+                w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+                w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+
+                if w1_weight in weights:
+                    self.experts_gate_projs[i_experts] = weights[w1_weight][
+                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+                    ]
+                if w3_weight in weights:
+                    self.experts_up_projs[i_experts] = weights[w3_weight][
+                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
+                    ]
+
+                if w2_weight in weights:
+                    self.w2_list[i_experts] = weights[w2_weight][
+                        :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
+                    ]
+
+            self._fuse()
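For context, load_hf_weights dispatches on the ETP_MODE_ENABLED environment variable. A minimal usage sketch follows; load_with_mode is a hypothetical helper, and moe_weight stands in for a constructed instance of this class:

import os

def load_with_mode(moe_weight, weights, etp: bool):
    # Toggle the dispatch inside load_hf_weights via the environment variable.
    if etp:
        os.environ["ETP_MODE_ENABLED"] = "true"   # whole experts per rank (EP)
    else:
        os.environ.pop("ETP_MODE_ENABLED", None)  # sliced experts per rank (TP)
    moe_weight.load_hf_weights(weights)
    assert moe_weight.verify_load()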
 
     def _cuda(self, cpu_tensor):
         if self.tp_rank_ is None:
@@ -124,4 +191,7 @@ def _cuda(self, cpu_tensor):
             return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
 
     def verify_load(self):
-        return self.w1 is not None and self.w2 is not None
+        if os.environ.get("ETP_MODE_ENABLED") == "true":
+            return True
+        else:
+            return self.w1 is not None and self.w2 is not None
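Note that in ETP mode verify_load() reports success unconditionally. A stricter variant could confirm the packed buffers were actually materialized; this is a sketch under that assumption, not the repository's code:

def verify_load_strict(self):
    # Hypothetical stricter check for ETP mode (illustrative only).
    if os.environ.get("ETP_MODE_ENABLED") == "true":
        return (self.expert_gate_up_proj_etp is not None
                and self.expert_down_proj_etp is not None)
    return self.w1 is not None and self.w2 is not None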