support qwen3moe overlap mode#974

Merged
shihaobai merged 2 commits into main from qwen_overlap
Jul 22, 2025
Conversation

@sufubao (Collaborator) commented Jul 16, 2025

No description provided.

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @sufubao, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the Qwen3-MoE model's inference capabilities by introducing an 'overlap mode'. This mode optimizes performance by executing attention and MoE FFN computations concurrently across different inputs, effectively reducing overall latency. The changes also include necessary configuration updates for shared experts and a corresponding adjustment to the benchmarking suite to properly test this new, more efficient inference pathway.

Highlights

  • New Overlap Inference Mode for Qwen3-MoE: Introduced overlap_tpsp_token_forward and overlap_tpsp_context_forward methods in the Qwen3-MoE transformer layer. These methods enable overlapping computation stages for token and context generation, respectively, aiming to improve inference performance by hiding latency in Mixture-of-Experts (MoE) operations.
  • Support for Shared Experts: Added n_shared_experts as a configurable parameter to the Qwen3-MoE transformer layer. This allows the new overlap mode to properly handle models that incorporate shared experts within their MoE architecture.
  • Benchmark Update for Overlap Mode: Modified the static inference benchmark script to validate that all specified batch sizes are even when overlap modes are enabled. This ensures compatibility with the dual-stream processing nature of the new overlap inference implementation.
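The dual-stream design above implies a simple invariant: each benchmarked batch must split evenly into two micro-batches. A minimal sketch of that guard and split, with hypothetical function names (the PR's actual benchmark code is not reproduced here):

```python
def validate_overlap_batch_sizes(batch_sizes, enable_overlap):
    """Overlap mode processes two micro-batches concurrently, so every
    benchmarked batch size must be divisible by 2 (hypothetical helper)."""
    if not enable_overlap:
        return
    for bs in batch_sizes:
        assert bs % 2 == 0, f"Batch size {bs} must be an even number in overlap mode"


def split_for_overlap(batch):
    """Split one batch into the two halves fed to the alternating streams."""
    half = len(batch) // 2
    return batch[:half], batch[half:]
```

Under this assumption, a batch of 4 requests becomes two micro-batches of 2, one per stream, so attention for one can overlap MoE FFN for the other.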
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature               Command               Description
Code Review           /gemini review        Performs a code review for the current pull request in its current state.
Pull Request Summary  /gemini summary       Provides a summary of the current pull request in its current state.
Comment               @gemini-code-assist   Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                  /gemini help          Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

The pull request introduces an overlap mode for qwen3_moe models. The new methods in transformer_layer_infer.py need refactoring to reduce complexity and code duplication. Also, the local import of deep_ep needs to be addressed, and leftover comments/code should be removed. Finally, the assertion message in model_infer.py should be clarified.

Comment on lines 126 to 395
def overlap_tpsp_token_forward(
    self,
    input_embdings: torch.Tensor,
    input_embdings1: torch.Tensor,
    infer_state: LlamaInferStateInfo,
    infer_state1: LlamaInferStateInfo,
    layer_weight: Qwen3MOETransformerLayerWeight,
):
    if not self.is_moe:
        return super().overlap_tpsp_token_forward(
            input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
        )
    # 0 attention
    _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
    _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
    _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
    _0_input1 = None
    self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
    _0_o = self._token_attention_kernel(_0_q, infer_state, layer_weight)
    _0_q = None
    _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
    input_embdings.add_(_0_o.view(-1, self.embed_dim_))
    _0_o = None
    _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    _0_router_logits = layer_weight.moe_gate.mm(_0_input1)
    # 1 hook
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    # 0 shared expert
    if self.n_shared_experts is not None:
        _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

    # 0 dispatch
    (
        _0_recv_x,
        _0_masked_m,
        _0_topk_idx,
        _0_topk_weight,
        _0_handle,
        _0_hook,
    ) = layer_weight.experts.low_latency_dispatch(_0_input1, _0_router_logits)
    infer_state.hook = _0_hook

    # 1 attention
    _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
    _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
    _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
    _1_input1 = None
    self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
    _1_o = self._token_attention_kernel(_1_q, infer_state1, layer_weight)
    _1_q = None
    _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
    input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
    _1_o = None
    _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
    # to do gate and disptatch

    _1_router_logits = layer_weight.moe_gate.mm(_1_input1)
    # 0 hook
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        infer_state.hook = None

    # 1 shared expert
    if self.n_shared_experts is not None:
        _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

    # 1 dispatch
    (
        _1_recv_x,
        _1_masked_m,
        _1_topk_idx,
        _1_topk_weight,
        _1_handle,
        _1_hook,
    ) = layer_weight.experts.low_latency_dispatch(_1_input1, _1_router_logits)
    infer_state1.hook = _1_hook

    # moe calu
    expected_m = triton.cdiv(
        input_embdings.shape[0] * get_global_world_size() * self.num_experts_per_tok, self.n_routed_experts
    )
    _0_moe_out = layer_weight.experts.masked_group_gemm(_0_recv_x, _0_masked_m, input_embdings.dtype, expected_m)

    # 1 hook
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    # 0 combine
    _0_ffn_out, _0_hook = layer_weight.experts.low_latency_combine(
        _0_moe_out, _0_topk_idx, _0_topk_weight, _0_handle
    )

    infer_state.hook = _0_hook

    # to do moe caclue
    _1_moe_out = layer_weight.experts.masked_group_gemm(_1_recv_x, _1_masked_m, input_embdings1.dtype, expected_m)

    # 0 hook
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        # _0_ffn_out *= self.routed_scaling_factor
        if self.n_shared_experts is not None:
            _0_ffn_out.add_(_0_shared_output)
        input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))
        infer_state.hook = None

    # 1 combine
    _1_ffn_out, _1_hook = layer_weight.experts.low_latency_combine(
        _1_moe_out, _1_topk_idx, _1_topk_weight, _1_handle
    )

    def _1_hook_post():
        _1_hook()
        nonlocal _1_ffn_out
        # _1_ffn_out *= self.routed_scaling_factor
        if self.n_shared_experts is not None:
            _1_ffn_out.add_(_1_shared_output)
        input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
        return

    infer_state1.hook = _1_hook_post

    return input_embdings, input_embdings1

def overlap_tpsp_context_forward(
    self,
    input_embdings: torch.Tensor,
    input_embdings1: torch.Tensor,
    infer_state: LlamaInferStateInfo,
    infer_state1: LlamaInferStateInfo,
    layer_weight: Qwen3MOETransformerLayerWeight,
):
    if not self.is_moe:
        return super().overlap_tpsp_context_forward(
            input_embdings, input_embdings1, infer_state, infer_state1, layer_weight
        )
    # 0 attention
    _0_input1 = self._att_norm(input_embdings, infer_state, layer_weight)
    _0_cache_kv = self._pre_cache_kv(infer_state, layer_weight)
    _0_q, _0_cache_kv = self._tpsp_get_qkv(_0_input1, _0_cache_kv, infer_state, layer_weight)
    _0_input1 = None
    self._post_cache_kv(_0_cache_kv, infer_state, layer_weight)
    _0_o = self._context_attention_kernel(_0_q, _0_cache_kv, infer_state, layer_weight)
    _0_q = None
    _0_o = self._tpsp_get_o(_0_o, infer_state, layer_weight)
    input_embdings.add_(_0_o.view(-1, self.embed_dim_))
    _0_o = None
    _0_input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
    _0_router_logits = layer_weight.moe_gate.mm(_0_input1)

    # wait last 1 combine
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    _0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
        _0_input1, _0_router_logits
    )
    from deep_ep import Buffer

    _0_overlap_event = Buffer.capture()

    # 1 attention
    _1_input1 = self._att_norm(input_embdings1, infer_state1, layer_weight)
    _1_cache_kv = self._pre_cache_kv(infer_state1, layer_weight)
    _1_q, _1_cache_kv = self._tpsp_get_qkv(_1_input1, _1_cache_kv, infer_state1, layer_weight)
    _1_input1 = None
    self._post_cache_kv(_1_cache_kv, infer_state1, layer_weight)
    _1_o = self._context_attention_kernel(_1_q, _1_cache_kv, infer_state1, layer_weight)
    _1_q = None
    _1_o = self._tpsp_get_o(_1_o, infer_state1, layer_weight)
    input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
    _1_o = None
    _1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
    # to do gate and disptatch

    _1_router_logits = layer_weight.moe_gate.mm(_1_input1)

    # 0 dispatch execute
    (
        _0_recv_x,
        _0_recv_topk_idx,
        _0_recv_topk_weight,
        _0_num_recv_tokens_per_expert_list,
        _0_handle,
        _0_hook,
    ) = layer_weight.experts.dispatch(_0_qinput_tensor, _0_topk_idx, _0_topk_weight, overlap_event=_0_overlap_event)
    infer_state.hook = _0_hook

    # wait 0 dispatch
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        infer_state.hook = None

    _1_topk_weight, _1_topk_idx, _1_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
        _1_input1, _1_router_logits
    )

    _1_overlap_event = Buffer.capture()

    # 0 shared expert
    if self.n_shared_experts is not None:
        _0_shared_output = LlamaTransformerLayerInfer._ffn(self, _0_input1, infer_state, layer_weight)

    # 1 shared expert
    if self.n_shared_experts is not None:
        _1_shared_output = LlamaTransformerLayerInfer._ffn(self, _1_input1, infer_state1, layer_weight)

    # 0 moe calu
    _0_moe_out = layer_weight.experts.prefilled_group_gemm(
        _0_num_recv_tokens_per_expert_list, _0_recv_x, _0_recv_topk_idx, _0_recv_topk_weight
    )

    # 1 dispatch execute
    (
        _1_recv_x,
        _1_recv_topk_idx,
        _1_recv_topk_weight,
        _1_num_recv_tokens_per_expert_list,
        _1_handle,
        _1_hook,
    ) = layer_weight.experts.dispatch(_1_qinput_tensor, _1_topk_idx, _1_topk_weight, overlap_event=_1_overlap_event)
    infer_state1.hook = _1_hook

    # wait 1 dispatch
    if getattr(infer_state1, "hook", None) is not None:
        infer_state1.hook()
        infer_state1.hook = None

    _0_combine_event = Buffer.capture()
    # 0 combine execute
    _0_ffn_out, _0_hook = layer_weight.experts.combine(_0_moe_out, _0_handle, _0_combine_event)
    infer_state.hook = _0_hook

    # 1 moe calc
    _1_moe_out = layer_weight.experts.prefilled_group_gemm(
        _1_num_recv_tokens_per_expert_list, _1_recv_x, _1_recv_topk_idx, _1_recv_topk_weight
    )

    # wait 0 combine
    if getattr(infer_state, "hook", None) is not None:
        infer_state.hook()
        infer_state.hook = None

    _1_combine_event = Buffer.capture()

    # _0_ffn_out *= self.routed_scaling_factor
    if self.n_shared_experts is not None:
        _0_ffn_out.add_(_0_shared_output)
    input_embdings.add_(_0_ffn_out.view(-1, self.embed_dim_))

    # 1 combine execute
    _1_ffn_out, _1_hook = layer_weight.experts.combine(_1_moe_out, _1_handle, _1_combine_event)

    def _1_hook_post():
        _1_hook()
        nonlocal _1_ffn_out
        # _1_ffn_out *= self.routed_scaling_factor
        if self.n_shared_experts is not None:
            _1_ffn_out.add_(_1_shared_output)
        input_embdings1.add_(_1_ffn_out.view(-1, self.embed_dim_))
        return

    infer_state1.hook = _1_hook_post

    return input_embdings, input_embdings1

high

The functions overlap_tpsp_token_forward and overlap_tpsp_context_forward are lengthy and contain duplicated logic. Refactor to improve readability and maintainability.

Consider these steps:

  • Extract common logic into helper methods.
  • Create a generic overlap function that accepts attention kernels as parameters.
  • Encapsulate hook management in a helper function like _execute_hook_if_exists(infer_state).

_0_topk_weight, _0_topk_idx, _0_qinput_tensor = layer_weight.experts.select_experts_and_quant_input(
    _0_input1, _0_router_logits
)
from deep_ep import Buffer

high

Move this local import of deep_ep to the top of the file. If deep_ep is optional, wrap the import in a try...except ImportError block.
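One common pattern for this, assuming deep_ep is an optional dependency (a sketch, not the PR's code; `require_deep_ep` is a hypothetical helper name):

```python
# Module-level guarded import: the module stays importable without deep_ep,
# and overlap mode fails with a clear message instead of an ImportError
# deep inside the forward pass.
try:
    from deep_ep import Buffer

    HAS_DEEP_EP = True
except ImportError:
    Buffer = None
    HAS_DEEP_EP = False


def require_deep_ep():
    """Return the Buffer class, or raise a descriptive error if deep_ep is absent."""
    if not HAS_DEEP_EP:
        raise RuntimeError("deep_ep is required for overlap mode; please install it")
    return Buffer
```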

input_embdings1.add_(_1_o.view(-1, self.embed_dim_))
_1_o = None
_1_input1 = self._ffn_norm(input_embdings1, infer_state1, layer_weight)
# to do gate and disptatch

medium

Remove this comment, as it seems to be a leftover from development and contains a typo ('disptatch' should be 'dispatch').

# 0 hook
if getattr(infer_state, "hook", None) is not None:
    infer_state.hook()
    # _0_ffn_out *= self.routed_scaling_factor

medium

This line is commented out. Either uncomment it if routed_scaling_factor should be applied, or remove the line if it's obsolete.

Comment on lines +372 to +373
for bs in batch_size:
    assert bs % 2 == 0, "batch size must be even number"

medium

The assertion message is unclear. Clarify it to indicate which batch size is not an even number:

assert bs % 2 == 0, f"Batch size {bs} must be an even number"

@shihaobai shihaobai merged commit 74ed914 into main Jul 22, 2025
1 check passed
@shihaobai shihaobai deleted the qwen_overlap branch July 22, 2025 11:57
