
Commit 9c2f746

gcanlin and Isotr0py authored
[Perf] Use vLLM's SharedFusedMoE in Qwen3-Omni (vllm-project#560)
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 65b55d1 commit 9c2f746

File tree

1 file changed: 168 additions, 98 deletions


vllm_omni/model_executor/models/qwen3_omni/qwen3_omni_moe_talker.py

Lines changed: 168 additions & 98 deletions
@@ -16,9 +16,12 @@
     Qwen3OmniMoeAudioEncoder,
 )
 from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.interfaces import (
     MultiModalEmbeddings,
     SupportsMultiModal,
@@ -27,13 +30,12 @@
 from vllm.model_executor.models.qwen2_5_omni_thinker import (
     Qwen2_5OmniThinkerDummyInputsBuilder,
 )
-from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP
+from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock
 from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
     WeightsMapper,
     maybe_prefix,
-    sequence_parallel_chunk,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors
@@ -531,130 +533,198 @@ def forward(self, hidden_state):
         return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
 
 
+class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module):
+    """
+    Wrapper that combines the shared_expert MLP with its sigmoid gate.
+
+    This matches the HuggingFace weight structure where:
+    - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
+    - mlp.shared_expert_gate.weight (sibling, not child)
+
+    The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x)
+    """
+
+    def __init__(
+        self,
+        shared_expert: Qwen3MoeMLP,
+        shared_expert_gate: nn.Linear,
+    ):
+        super().__init__()
+        self._shared_expert = shared_expert
+        self._shared_expert_gate = shared_expert_gate
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = self._shared_expert(x)
+        gate_values = F.sigmoid(self._shared_expert_gate(x))  # [batch, 1]
+        return gate_values * out  # Broadcasting: [batch, 1] * [batch, hidden]
+
+
+class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module):
+    """
+    Sparse MoE block for the Qwen3 Omni MoE Talker with shared expert support.
+
+    This block uses SharedFusedMoE to efficiently compute both the routed experts
+    and the shared expert, potentially overlapping computation with communication.
+
+    Weight structure matches HuggingFace:
+    - mlp.gate.weight (router)
+    - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
+    - mlp.shared_expert_gate.weight
+    - mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight
+    """
+
+    def __init__(
+        self,
+        config: Qwen3OmniMoeTalkerConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        text_config = config.text_config
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        if self.tp_size > text_config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than the number of experts {text_config.num_experts}."
+            )
+
+        # Router gate for selecting top-k experts
+        self.gate = ReplicatedLinear(
+            text_config.hidden_size,
+            text_config.num_experts,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate",
+        )
+
+        # Shared expert MLP (matches HF: mlp.shared_expert.*)
+        if text_config.shared_expert_intermediate_size > 0:
+            self.shared_expert = Qwen3MoeMLP(
+                hidden_size=text_config.hidden_size,
+                intermediate_size=text_config.shared_expert_intermediate_size,
+                hidden_act=text_config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,  # Don't reduce, we'll handle it
+                prefix=f"{prefix}.shared_expert",
+            )
+            # Shared expert gate (matches HF: mlp.shared_expert_gate.weight)
+            # This is a sibling of shared_expert, not a child
+            self.shared_expert_gate = torch.nn.Linear(text_config.hidden_size, 1, bias=False)
+            # Create wrapper for SharedFusedMoE
+            self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper(
+                self.shared_expert, self.shared_expert_gate
+            )
+        else:
+            self.shared_expert = None
+            self.shared_expert_gate = None
+            self._shared_expert_wrapper = None
+
+        # Fused MoE with shared expert support
+        self.experts = SharedFusedMoE(
+            shared_experts=self._shared_expert_wrapper,
+            num_experts=text_config.num_experts,
+            top_k=text_config.num_experts_per_tok,
+            hidden_size=text_config.hidden_size,
+            intermediate_size=text_config.moe_intermediate_size,
+            reduce_results=False,  # We'll reduce manually after combining
+            renormalize=text_config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+
+        # Compute router logits
+        router_logits, _ = self.gate(hidden_states)
+
+        # Forward through SharedFusedMoE
+        # Returns (shared_out, fused_out) when shared_expert is present
+        final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
+
+        # Combine shared and routed expert outputs
+        if self._shared_expert_wrapper is not None:
+            # SharedFusedMoE returns tuple: (shared_out, fused_out)
+            final_hidden_states = final_hidden_states[0] + final_hidden_states[1]
+
+        # Apply tensor parallel reduction if needed
+        if self.tp_size > 1:
+            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
+
+        return final_hidden_states.view(orig_shape)
+
+
 class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM):
-    def __init__(self, vllm_config, talker_config, prefix):
+    """
+    Qwen3 Omni MoE Talker language model.
+
+    This model extends Qwen3MoeLLMForCausalLM with:
+    - Shared expert support via SharedFusedMoE
+    - Codec embedding instead of text embedding
+    - No LM head (codec head is separate in the parent class)
+    """
+
+    def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str):
+        # Create a vllm_config for the talker's text model
         talker_vllm_config = vllm_config.with_hf_config(
             talker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
         )
         talker_vllm_config.model_config.hf_text_config = talker_vllm_config.model_config.hf_config
+
         super().__init__(
             vllm_config=talker_vllm_config,
             prefix=prefix,
         )
 
         self.config = talker_config
+        self.talker_vllm_config = talker_vllm_config
 
         # Remove the inherited LM head so the talker only exposes codec outputs.
         if hasattr(self, "lm_head"):
             del self.lm_head
 
-        # Replace the base embed tokens with codec embedding (defined below).
+        # Replace the base embed tokens with codec embedding.
         if hasattr(self.model, "embed_tokens"):
             del self.model.embed_tokens
 
         # Codec embedding for RVQ code generation
         self.model.codec_embedding = nn.Embedding(
-            talker_config.text_config.vocab_size, talker_config.text_config.hidden_size
+            talker_config.text_config.vocab_size,
+            talker_config.text_config.hidden_size,
         )
 
-        # Add shared expert to each MoE layer and patch the forward method
-        layer_idx = 0
-        for layer in self.model.layers:
-            # add shared expert to Qwen3OmniMoeSparseMoeBlock layers
-            if hasattr(layer.mlp, "experts"):  # Check if it's a SparseMoeBlock
-                # Shared expert is a regular gated MLP (SwiGLU)
-                layer.mlp.shared_expert = Qwen3MoeMLP(
-                    hidden_size=self.config.text_config.hidden_size,
-                    intermediate_size=self.config.text_config.shared_expert_intermediate_size,
-                    hidden_act=self.config.text_config.hidden_act,
-                    quant_config=talker_vllm_config.quant_config,
-                    reduce_results=False,  # Don't reduce since we'll add it manually
-                    prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert",
-                )
+        # Replace MoE blocks with shared expert versions
+        self._replace_moe_blocks_with_shared_expert(prefix)
 
-                # Shared expert gate outputs a single scalar per token
-                layer.mlp.shared_expert_gate = ReplicatedLinear(
-                    self.config.text_config.hidden_size,
-                    1,  # Output single scalar per token
-                    bias=False,
-                    quant_config=None,
-                    prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert_gate",
-                )
-
-                # Store MoE config values for router computation
-                layer.mlp.top_k = self.config.text_config.num_experts_per_tok
-                layer.mlp.norm_topk_prob = self.config.text_config.norm_topk_prob
-                layer.mlp.num_experts = self.config.text_config.num_experts
-
-                # Monkey-patch the forward method to use shared expert
-                layer.mlp.forward = self._create_moe_forward_with_shared_expert(layer.mlp)
-
-            layer_idx += 1
-
-    def _create_moe_forward_with_shared_expert(self, moe_layer):
-        """Create a forward method that includes shared expert computation.
-
-        This matches the Transformers implementation where:
-        1. Compute shared expert output (regular MLP)
-        2. Gate it with sigmoid(shared_expert_gate(x))
-        3. Apply softmax BEFORE top-k selection (matches Transformers router)
-        4. Add to routed expert outputs
+    def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None:
         """
-
-        def forward_with_shared_expert(hidden_states: torch.Tensor, layer_idx: int = 0) -> torch.Tensor:
-            # Save original shape
-            orig_shape = hidden_states.shape
-            hidden_dim = hidden_states.shape[-1]
-            hidden_states = hidden_states.view(-1, hidden_dim)
-
-            # handle sequence parallel if needed
-            if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel:
-                hidden_states = sequence_parallel_chunk(hidden_states)
-
-            # Compute shared expert output
-            # The shared expert is a regular MLP, not a routed MoE
-            shared_output = None
-            if hasattr(moe_layer, "shared_expert") and moe_layer.shared_expert is not None:
-                # Forward through shared expert MLP
-                shared_output = moe_layer.shared_expert(hidden_states)
-
-                # Apply gating with sigmoid: sigmoid(gate(x)) * shared_expert(x)
-                if hasattr(moe_layer, "shared_expert_gate") and moe_layer.shared_expert_gate is not None:
-                    gate_logits, _ = moe_layer.shared_expert_gate(hidden_states)
-                    gate_values = F.sigmoid(gate_logits)  # [batch, 1]
-                    shared_output = gate_values * shared_output  # Broadcasting: [batch, 1] * [batch, hidden]
-
-            # Compute experts results
-            # router_logits: (num_tokens, n_experts)
-            router_logits, _ = moe_layer.gate(hidden_states)
-            experts_output = moe_layer.experts(hidden_states=hidden_states, router_logits=router_logits)
-
-            # combine experts and shared expert results
-            if shared_output is not None:
-                final_hidden_states = experts_output + shared_output
-
-            # Handle sequence parallel if needed
-            if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel:
-                from vllm.distributed import tensor_model_parallel_all_gather
-
-                num_tokens = orig_shape[0] if len(orig_shape) > 1 else 1
-                final_hidden_states = tensor_model_parallel_all_gather(final_hidden_states, 0)
-                final_hidden_states = final_hidden_states[:num_tokens]
-            try:
-                final_hidden_states.view(orig_shape)
-            except Exception as e:
-                print(f"Error viewing final hidden states: {e}")
-                print(f"final_hidden_states.shape: {final_hidden_states.shape}")
-                print(f"orig_shape: {orig_shape}")
-                raise e
-            # Return with original shape
-            return final_hidden_states.view(orig_shape)
-
-        return forward_with_shared_expert
+        Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock
+        that includes shared expert support via SharedFusedMoE.
+        """
+        # Get compilation config to clean up registered layer names
+        compilation_config = self.talker_vllm_config.compilation_config
+
+        for layer_idx, layer in enumerate(self.model.layers):
+            # Check if this layer has a MoE block
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                # Remove old layer registration from static_forward_context
+                old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts"
+                if old_experts_prefix in compilation_config.static_forward_context:
+                    del compilation_config.static_forward_context[old_experts_prefix]
+
+                # Create new MoE block with shared expert support
+                layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock(
+                    config=self.config,
+                    quant_config=self.talker_vllm_config.quant_config,
+                    prefix=f"{prefix}.model.layers.{layer_idx}.mlp",
+                )
 
     def embed_input_ids(
         self,
         input_ids: torch.Tensor,
-        **kwargs: object,
     ) -> torch.Tensor:
+        """Embed codec input IDs."""
        return self.model.codec_embedding(input_ids)
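
For readers unfamiliar with the shared-expert gating used in the new block, below is a minimal standalone sketch of the combination step. It is illustrative only: it uses plain PyTorch with a toy MLP in place of Qwen3MoeMLP, a random tensor in place of SharedFusedMoE's routed-expert output, and omits routing, tensor parallelism, and quantization; all names here are hypothetical.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySharedExpert(nn.Module):
    """Toy stand-in for the gated shared expert: sigmoid(gate(x)) * mlp(x)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # Gated MLP (SwiGLU-style), analogous to the shared_expert Qwen3MoeMLP
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # Scalar-per-token gate, analogous to mlp.shared_expert_gate
        self.shared_expert_gate = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mlp_out = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
        gate = torch.sigmoid(self.shared_expert_gate(x))  # [num_tokens, 1]
        return gate * mlp_out                             # broadcast over hidden dim

# Usage: the routed-expert output would come from SharedFusedMoE; here it is random.
tokens = torch.randn(4, 8)              # [num_tokens, hidden_size]
shared = ToySharedExpert(8, 16)
routed_out = torch.randn(4, 8)          # placeholder for the fused routed-expert output
combined = routed_out + shared(tokens)  # same combination as in the new forward()
print(combined.shape)                   # torch.Size([4, 8])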
