Commit eaadb4e

fix
1 parent 8ed97c7 commit eaadb4e

2 files changed: 275 additions & 1 deletion

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 273 additions & 0 deletions
@@ -13,6 +13,7 @@
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
 from functools import partial
 from lightllm.utils.log_utils import init_logger
+from lightllm.utils.dist_utils import get_global_world_size
 
 logger = init_logger(__name__)
 
@@ -27,6 +28,7 @@ def __init__(self, layer_num, network_config, mode=[]):
         )
         self.num_experts_per_tok = network_config["num_experts_per_tok"]
         self.norm_topk_prob = network_config["norm_topk_prob"]
+        self.n_shared_experts = network_config.get("n_shared_experts", None)
         super().__init__(layer_num, network_config, mode)
         self.head_dim_ = network_config["head_dim"]
         self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
@@ -120,3 +122,274 @@ def _moe_ffn_edp(
 
         ep_output = ep_output.view(token_num, hidden_dim)
         return ep_output
+
+    def overlap_tpsp_token_forward(
+        self,
+        input_embdings: torch.Tensor,
+        input_embdings1: torch.Tensor,
+        infer_state: LlamaInferStateInfo,
+        infer_state1: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ):
+        if not self.is_moe:
+            return super().overlap_tpsp_token_forward(
+                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
+            )
+        # 0 attention
+        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
+        _0_input1 = None
+        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
+        _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
+        _0_q = None
+        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
+        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        _0_o = None
+        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+        # wait 1 hook
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        # 0 shared expert
+        if self.n_shared_experts is not None:
+            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
+
+        # 0 dispatch
+        (
+            _0_recv_x,
+            _0_masked_m,
+            _0_topk_idx,
+            _0_topk_weight,
+            _0_handle,
+            _0_hook,
+        ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
+        infer_state.hook = _0_hook
+
+        # 1 attention
+        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
+        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
+        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
+        _1_input1 = None
+        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
+        _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
+        _1_q = None
+        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
+        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        _1_o = None
+        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
+        # 1 gate and dispatch
+
+        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+        # wait 0 dispatch
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        # 1 shared expert
+        if self.n_shared_experts is not None:
+            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
+
+        # 1 dispatch
+        (
+            _1_recv_x,
+            _1_masked_m,
+            _1_topk_idx,
+            _1_topk_weight,
+            _1_handle,
+            _1_hook,
+        ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
+        infer_state1.hook = _1_hook
+
+        # 0 moe calc
+        expected_m = triton.cdiv(
+            input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
+        )
+        _0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)
+
+        # wait 1 dispatch
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        # 0 combine
+        _0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
+            _0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
+        )
+
+        infer_state.hook = _0_hook
+
+        # 1 moe calc
+        _1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)
+
+        # wait 0 combine
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            # _0_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _0_ffn_out.add_(_0_shared_output)
+            input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
+            infer_state.hook = None
+
+        # 1 combine
+        _1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
+            _1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
+        )
+
+        def _1_hook_post():
+            _1_hook()
+            nonlocal _1_ffn_out
+            # _1_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _1_ffn_out.add_(_1_shared_output)
+            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
+            return
+
+        infer_state1.hook = _1_hook_post
+
+        return input_embdings, input_embdings1
+
+    def overlap_tpsp_context_forward(
+        self,
+        input_embdings: torch.Tensor,
+        input_embdings1: torch.Tensor,
+        infer_state: LlamaInferStateInfo,
+        infer_state1: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ):
+        if not self.is_moe:
+            return super().overlap_tpsp_context_forward(
+                input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
+            )
+        # 0 attention
+        _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+        _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
+        _0_input1 = None
+        self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
+        _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
+        _0_q = None
+        _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
+        input_embdings.add_(_0_o.view(-1, self.embed_dim_))
+        _0_o = None
+        _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
+        _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
+
+        # wait last 1 combine
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
+            _0_input1, _0_router_logits
+        )
+        from deep_ep import Buffer
+
+        _0_overlap_event = Buffer.capture()
+
+        # 1 attention
+        _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
+        _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
+        _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
+        _1_input1 = None
+        self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
+        _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
+        _1_q = None
+        _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
+        input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
+        _1_o = None
+        _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
+        # 1 gate and dispatch
+
+        _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
+
+        # 0 dispatch execute
+        (
+            _0_recv_x,
+            _0_recv_topk_idx,
+            _0_recv_topk_weight,
+            _0_num_recv_tokens_per_expert_list,
+            _0_handle,
+            _0_hook,
+        ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
+        infer_state.hook = _0_hook
+
+        # wait 0 dispatch
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
+            _1_input1, _1_router_logits
+        )
+
+        _1_overlap_event = Buffer.capture()
+
+        # 0 shared expert
+        if self.n_shared_experts is not None:
+            _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)
+
+        # 1 shared expert
+        if self.n_shared_experts is not None:
+            _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)
+
+        # 0 moe calc
+        _0_moe_out = layer_weight.experts.prefilled_group_gemm(
+            _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
+        )
+
+        # 1 dispatch execute
+        (
+            _1_recv_x,
+            _1_recv_topk_idx,
+            _1_recv_topk_weight,
+            _1_num_recv_tokens_per_expert_list,
+            _1_handle,
+            _1_hook,
+        ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
+        infer_state1.hook = _1_hook
+
+        # wait 1 dispatch
+        if getattr(infer_state1, "hook", None) is not None:
+            infer_state1.hook()
+            infer_state1.hook = None
+
+        _0_combine_event = Buffer.capture()
+        # 0 combine execute
+        _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
+        infer_state.hook = _0_hook
+
+        # 1 moe calc
+        _1_moe_out = layer_weight.experts.prefilled_group_gemm(
+            _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
+        )
+
+        # wait 0 combine
+        if getattr(infer_state, "hook", None) is not None:
+            infer_state.hook()
+            infer_state.hook = None
+
+        _1_combine_event = Buffer.capture()
+
+        # _0_ffn_out *= self.routed_scaling_factor
+        if self.n_shared_experts is not None:
+            _0_ffn_out.add_(_0_shared_output)
+        input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
+
+        # 1 combine execute
+        _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)
+
+        def _1_hook_post():
+            _1_hook()
+            nonlocal _1_ffn_out
+            # _1_ffn_out *= self.routed_scaling_factor
+            if self.n_shared_experts is not None:
+                _1_ffn_out.add_(_1_shared_output)
+            input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
+            return
+
+        infer_state1.hook = _1_hook_post
+
+        return input_embdings, input_embdings1
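
Note: the two methods added above run two microbatches through each MoE layer in lockstep; while one microbatch's expert dispatch/combine (the DeepEP all-to-all) is in flight, the other microbatch's attention, gating, and grouped GEMMs execute, and `infer_state.hook` carries the "wait for my pending communication" callback from the point an async op is launched to the point its result is consumed. The following is a minimal runnable sketch of that hook hand-off only; `fake_async_dispatch` and `dense_compute` are hypothetical stand-ins, not the lightllm or DeepEP APIs.

# Minimal sketch of the hook-based two-microbatch overlap pattern, assuming
# made-up stand-ins for the real communication and compute calls.
from types import SimpleNamespace


def fake_async_dispatch(x):
    # hypothetical stand-in: pretend to launch an async all-to-all and
    # return the payload plus a "wait for completion" callback
    return x, lambda: None


def dense_compute(x):
    # hypothetical stand-in for the attention / grouped-GEMM work
    return [v * 2 for v in x]


def overlapped_layer(x0, x1, state0, state1):
    recv0, hook0 = fake_async_dispatch(x0)  # 0 dispatch starts
    state0.hook = hook0

    y1 = dense_compute(x1)  # 1 compute overlaps 0's dispatch

    if getattr(state0, "hook", None) is not None:
        state0.hook()  # wait 0 dispatch before consuming its result
        state0.hook = None
    y0 = dense_compute(recv0)

    recv1, hook1 = fake_async_dispatch(y1)  # 1 dispatch is issued last
    state1.hook = hook1  # the next layer pays this wait, overlapped with its own compute
    return y0, recv1


s0, s1 = SimpleNamespace(hook=None), SimpleNamespace(hook=None)
print(overlapped_layer([1, 2], [3, 4], s0, s1))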

test/benchmark/static_inference/model_infer.py

Lines changed: 2 additions & 1 deletion
@@ -369,7 +369,8 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an
     enable_decode_overlap = args.enable_decode_microbatch_overlap
     group_size = 1
     if enable_decode_overlap or args.enable_prefill_microbatch_overlap:
-        assert batch_size % 2 == 0, "batch size must be even number"
+        for bs in batch_size:
+            assert bs % 2 == 0, "batch size must be even number"
         group_size = 2
     init_distributed_env(model_kvargs)
     dist_group_manager.create_groups(group_size=group_size)
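
Note: this benchmark fix reflects that `batch_size` is iterated here as a collection of batch sizes; with microbatch overlap enabled each batch is split into two equal microbatches, so every configured batch size must be even. A small illustrative check under that assumption (the list values below are made up):

# Illustrative only: with microbatch overlap each batch is split into two
# equal halves, so every configured batch size must be even.
batch_size = [8, 16, 32]  # hypothetical benchmark configuration
for bs in batch_size:
    assert bs % 2 == 0, "batch size must be even number"
    print(f"batch {bs} -> two microbatches of {bs // 2}")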
