Commit a69b2bc

remove shared expert
1 parent eaadb4e commit a69b2bc

File tree

1 file changed (+0, -29 lines)


lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 29 deletions
@@ -28,7 +28,6 @@ def __init__(self, layer_num, network_config, mode=[]):
         )
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.norm_topk_prob = network_config["norm_topk_prob"]
-        self.n_shared_experts = network_config.get("n_shared_experts", None)
         super().__init__(layer_num, network_config, mode)
         self.head_dim_ = network_config["head_dim"]
         self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
@@ -153,10 +152,6 @@ def overlap_tpsp_token_forward(
             infer_state1.hook()
             infer_state1.hook = None

-        # 0 shared expert
-        if self.n_shared_experts is not None:
-            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
-
         # 0 dispatch
         (
             _0_recv_x,
@@ -188,10 +183,6 @@ def overlap_tpsp_token_forward(
             infer_state.hook()
             infer_state.hook = None

-        # 1 shared expert
-        if self.n_shared_experts is not None:
-            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
-
         # 1 dispatch
         (
             _1_recv_x,
@@ -227,9 +218,6 @@ def overlap_tpsp_token_forward(
         # 0 hook
         if getattr(infer_state, "hook", None) is not None:
             infer_state.hook()
-            # _0_ffn_out *= self.routed_scaling_factor
-            if self.n_shared_experts is not None:
-                _0_ffn_out.add_(_0_shared_output)
             input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
             infer_state.hook = None

@@ -241,9 +229,6 @@ def overlap_tpsp_token_forward(
         def _1_hook_post():
             _1_hook()
             nonlocal _1_ffn_out
-            # _1_ffn_out *= self.routed_scaling_factor
-            if self.n_shared_experts is not None:
-                _1_ffn_out.add_(_1_shared_output)
             input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
             return

@@ -327,14 +312,6 @@ def overlap_tpsp_context_forward(

         _1_overlap_event = Buffer.capture()

-        # 0 shared expert
-        if self.n_shared_experts is not None:
-            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
-
-        # 1 shared expert
-        if self.n_shared_experts is not None:
-            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
-
         # 0 moe calu
         _0_moe_out = layer_weight.experts.prefilled_group_gemm(
             _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
@@ -373,9 +350,6 @@ def overlap_tpsp_context_forward(

         _1_combine_event = Buffer.capture()

-        # _0_ffn_out *= self.routed_scaling_factor
-        if self.n_shared_experts is not None:
-            _0_ffn_out.add_(_0_shared_output)
         input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))

         # 1 combine execute
@@ -384,9 +358,6 @@ def overlap_tpsp_context_forward(
         def _1_hook_post():
             _1_hook()
             nonlocal _1_ffn_out
-            # _1_ffn_out *= self.routed_scaling_factor
-            if self.n_shared_experts is not None:
-                _1_ffn_out.add_(_1_shared_output)
             input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
             return

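Every removed block above is an instance of the same shared-expert pattern: when the config provides `n_shared_experts`, a dense FFN (the shared expert) runs on the same hidden states as the routed experts, and its output is added in place onto the routed MoE output (`_ffn_out.add_(_shared_output)`) before the residual add into `input_embdings`. The sketch below illustrates that pattern only; the class name, toy top-1 routing, and plain `nn.Linear` experts are hypothetical stand-ins for lightllm's `LlamaTransformerLayerInfer._ffn` and fused group-GEMM experts, not the removed implementation.

# Sketch only: a toy MoE block showing the shared-expert pattern this commit removes.
import torch
import torch.nn as nn


class ToyMoEWithSharedExpert(nn.Module):
    def __init__(self, hidden_size, num_experts, n_shared_experts=None):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts)
        self.routed_experts = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)]
        )
        # Optional shared expert, gated by config like n_shared_experts in the removed code.
        self.shared_expert = nn.Linear(hidden_size, hidden_size) if n_shared_experts else None

    def forward(self, x):
        # Top-1 routing (simplified from top-k): each token goes to its best-scoring expert.
        scores = self.gate(x).softmax(dim=-1)
        top_w, top_idx = scores.max(dim=-1)
        ffn_out = torch.zeros_like(x)
        for e, expert in enumerate(self.routed_experts):
            mask = top_idx == e
            if mask.any():
                ffn_out[mask] = expert(x[mask]) * top_w[mask].unsqueeze(-1)
        # The in-place add is what the deleted lines did with _ffn_out.add_(_shared_output):
        # the shared expert sees every token and its output is summed onto the routed output.
        if self.shared_expert is not None:
            ffn_out.add_(self.shared_expert(x))
        return ffn_out

After the commit, only the routed dispatch/combine path remains in this layer; since the removed branches were guarded by `network_config.get("n_shared_experts", None)`, they were no-ops whenever the config omits that key.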
0 commit comments
