Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 244 additions & 0 deletions lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
from functools import partial
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_global_world_size

logger = init_logger(__name__)

Expand Down Expand Up @@ -120,3 +121,246 @@ def _moe_ffn_edp(

ep_output = ep_output.view(token_num, hidden_dim)
return ep_output

def overlap_tpsp_token_forward(
    self,
    input_embdings: torch.Tensor,
    input_embdings1: torch.Tensor,
    infer_state: LlamaInferStateInfo,
    infer_state1: LlamaInferStateInfo,
    layer_weight: Qwen3MOETransformerLayerWeight,
):
    """Decode-phase forward for two overlapped micro-batches.

    Interleaves micro-batch 0's and micro-batch 1's attention with the
    other micro-batch's MoE all-to-all communication (low-latency
    dispatch/combine), hiding communication latency behind computation.
    Completion of each async communication is signalled via a ``hook``
    callable stashed on the corresponding infer_state.

    Both ``input_embdings`` and ``input_embdings1`` are updated in place
    (residual adds) and returned.

    Args:
        input_embdings: hidden states of micro-batch 0, updated in place.
        input_embdings1: hidden states of micro-batch 1, updated in place.
        infer_state: inference state of micro-batch 0.
        infer_state1: inference state of micro-batch 1.
        layer_weight: this layer's weights, including the MoE gate/experts.

    Returns:
        Tuple ``(input_embdings, input_embdings1)``.
    """
    # Non-MoE layers have no dispatch/combine to overlap; use the base path.
    if not self.is_moe:
        return super().overlap_tpsp_token_forward(
            input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
        )
    # micro-batch 0: attention + residual + ffn-norm + gate
    _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
    _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
    _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
    _0_input1 = None
    self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
    _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
    _0_q = None
    _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
    input_embdings.add_(_0_o.view(-1, self.embed_dim_))
    _0_o = None
    _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    _0_router_logits = layer_weight.moe_gate.mm(_0_input1)

    # wait for micro-batch 1's pending communication from the previous step
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    # micro-batch 0: async dispatch; completion is signalled by _0_hook
    (
        _0_recv_x,
        _0_masked_m,
        _0_topk_idx,
        _0_topk_weight,
        _0_handle,
        _0_hook,
    ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
    infer_state.hook = _0_hook

    # micro-batch 1: attention (overlaps with micro-batch 0's dispatch)
    _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
    _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
    _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
    _1_input1 = None
    self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
    _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
    _1_q = None
    _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
    input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
    _1_o = None
    _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
    _1_router_logits = layer_weight.moe_gate.mm(_1_input1)

    # wait for micro-batch 0's dispatch to land before using its buffers
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        infer_state.hook = None

    # micro-batch 1: async dispatch (overlaps with micro-batch 0's expert GEMM)
    (
        _1_recv_x,
        _1_masked_m,
        _1_topk_idx,
        _1_topk_weight,
        _1_handle,
        _1_hook,
    ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
    infer_state1.hook = _1_hook

    # expected tokens per expert, averaged over all ranks (ceil division)
    expected_m = triton.cdiv(
        input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
    )
    # micro-batch 0: expert computation
    _0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)

    # wait for micro-batch 1's dispatch
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    # micro-batch 0: async combine; completion is signalled by _0_hook
    _0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
        _0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
    )

    infer_state.hook = _0_hook

    # micro-batch 1: expert computation (overlaps with micro-batch 0's combine)
    _1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)

    # wait for micro-batch 0's combine, then apply its FFN residual
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
        infer_state.hook = None

    # micro-batch 1: async combine; residual add is deferred into the hook
    _1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
        _1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
    )

    def _1_hook_post():
        # runs later (next overlap step): finish combine, then residual add
        _1_hook()
        nonlocal _1_ffn_out
        input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
        return

    infer_state1.hook = _1_hook_post

    return input_embdings, input_embdings1

def overlap_tpsp_context_forward(
    self,
    input_embdings: torch.Tensor,
    input_embdings1: torch.Tensor,
    infer_state: LlamaInferStateInfo,
    infer_state1: LlamaInferStateInfo,
    layer_weight: Qwen3MOETransformerLayerWeight,
):
    """Prefill-phase forward for two overlapped micro-batches.

    Mirrors :meth:`overlap_tpsp_token_forward` but uses the context
    (prefill) attention kernel and the high-throughput expert
    dispatch/combine path, capturing DeepEP ``Buffer`` events so each
    micro-batch's communication overlaps the other's computation.

    Both ``input_embdings`` and ``input_embdings1`` are updated in place
    (residual adds) and returned.

    Args:
        input_embdings: hidden states of micro-batch 0, updated in place.
        input_embdings1: hidden states of micro-batch 1, updated in place.
        infer_state: inference state of micro-batch 0.
        infer_state1: inference state of micro-batch 1.
        layer_weight: this layer's weights, including the MoE gate/experts.

    Returns:
        Tuple ``(input_embdings, input_embdings1)``.
    """
    # Non-MoE layers have no dispatch/combine to overlap; use the base path.
    if not self.is_moe:
        return super().overlap_tpsp_context_forward(
            input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
        )
    # micro-batch 0: attention + residual + ffn-norm + gate
    _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
    _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
    _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
    _0_input1 = None
    self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
    _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
    _0_q = None
    _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
    input_embdings.add_(_0_o.view(-1, self.embed_dim_))
    _0_o = None
    _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    _0_router_logits = layer_weight.moe_gate.mm(_0_input1)

    # wait for micro-batch 1's combine from the previous step
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
        _0_input1, _0_router_logits
    )
    # NOTE: deep_ep is an optional heavy dependency only needed on this
    # EP-overlap path, so it is imported lazily here rather than at module
    # scope (importing it at file top would break non-EP deployments).
    from deep_ep import Buffer

    _0_overlap_event = Buffer.capture()

    # micro-batch 1: attention (overlaps with micro-batch 0's dispatch)
    _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
    _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
    _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
    _1_input1 = None
    self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
    _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
    _1_q = None
    _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
    input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
    _1_o = None
    _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)

    _1_router_logits = layer_weight.moe_gate.mm(_1_input1)

    # micro-batch 0: async dispatch; completion is signalled by _0_hook
    (
        _0_recv_x,
        _0_recv_topk_idx,
        _0_recv_topk_weight,
        _0_num_recv_tokens_per_expert_list,
        _0_handle,
        _0_hook,
    ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
    infer_state.hook = _0_hook

    # wait for micro-batch 0's dispatch
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        infer_state.hook = None

    _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
        _1_input1, _1_router_logits
    )

    _1_overlap_event = Buffer.capture()

    # micro-batch 0: expert computation
    _0_moe_out = layer_weight.experts.prefilled_group_gemm(
        _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
    )

    # micro-batch 1: async dispatch (overlaps with micro-batch 0's expert GEMM)
    (
        _1_recv_x,
        _1_recv_topk_idx,
        _1_recv_topk_weight,
        _1_num_recv_tokens_per_expert_list,
        _1_handle,
        _1_hook,
    ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
    infer_state1.hook = _1_hook

    # wait for micro-batch 1's dispatch
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    _0_combine_event = Buffer.capture()
    # micro-batch 0: async combine; completion is signalled by _0_hook
    _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
    infer_state.hook = _0_hook

    # micro-batch 1: expert computation (overlaps with micro-batch 0's combine)
    _1_moe_out = layer_weight.experts.prefilled_group_gemm(
        _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
    )

    # wait for micro-batch 0's combine
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        infer_state.hook = None

    _1_combine_event = Buffer.capture()

    # micro-batch 0: FFN residual add
    input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))

    # micro-batch 1: async combine; residual add is deferred into the hook
    _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)

    def _1_hook_post():
        # runs later (next overlap step): finish combine, then residual add
        _1_hook()
        nonlocal _1_ffn_out
        input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
        return

    infer_state1.hook = _1_hook_post

    return input_embdings, input_embdings1
3 changes: 2 additions & 1 deletion test/benchmark/static_inference/model_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,8 @@ def tppart_model_infer(args, model_kvargs, batch_size, input_len, output_len, an
enable_decode_overlap = args.enable_decode_microbatch_overlap
group_size = 1
if enable_decode_overlap or args.enable_prefill_microbatch_overlap:
assert batch_size % 2 == 0, "batch size must be even number"
for bs in batch_size:
assert bs % 2 == 0, "batch size must be even number"
Comment on lines +372 to +373
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The assertion message is not clear. Clarify the assertion message to indicate which batch size is not an even number.

assert bs % 2 == 0, f"Batch size {bs} must be an even number"

group_size = 2
init_distributed_env(model_kvargs)
dist_group_manager.create_groups(group_size=group_size)
Expand Down