|
16 | 16 | Qwen3OmniMoeAudioEncoder, |
17 | 17 | ) |
18 | 18 | from vllm.config import VllmConfig |
| 19 | +from vllm.distributed import get_tensor_model_parallel_world_size |
19 | 20 | from vllm.logger import init_logger |
20 | 21 | from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY |
| 22 | +from vllm.model_executor.layers.fused_moe import SharedFusedMoE |
21 | 23 | from vllm.model_executor.layers.linear import ReplicatedLinear |
| 24 | +from vllm.model_executor.layers.quantization import QuantizationConfig |
22 | 25 | from vllm.model_executor.models.interfaces import ( |
23 | 26 | MultiModalEmbeddings, |
24 | 27 | SupportsMultiModal, |
|
27 | 30 | from vllm.model_executor.models.qwen2_5_omni_thinker import ( |
28 | 31 | Qwen2_5OmniThinkerDummyInputsBuilder, |
29 | 32 | ) |
30 | | -from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP |
| 33 | +from vllm.model_executor.models.qwen3_moe import Qwen3MoeMLP, Qwen3MoeSparseMoeBlock |
31 | 34 | from vllm.model_executor.models.qwen3_omni_moe_thinker import Qwen3Omni_VisionTransformer |
32 | 35 | from vllm.model_executor.models.utils import ( |
33 | 36 | AutoWeightsLoader, |
34 | 37 | WeightsMapper, |
35 | 38 | maybe_prefix, |
36 | | - sequence_parallel_chunk, |
37 | 39 | ) |
38 | 40 | from vllm.multimodal import MULTIMODAL_REGISTRY |
39 | 41 | from vllm.sequence import IntermediateTensors |
@@ -531,130 +533,198 @@ def forward(self, hidden_state): |
531 | 533 | return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state))) |
532 | 534 |
|
533 | 535 |
|
class Qwen3OmniMoeTalkerSharedExpertWrapper(nn.Module):
    """Combine the shared-expert MLP with its scalar sigmoid gate.

    This matches the HuggingFace weight structure where:
      - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
      - mlp.shared_expert_gate.weight (sibling, not child)

    The wrapper applies: sigmoid(shared_expert_gate(x)) * shared_expert(x)
    """

    def __init__(
        self,
        shared_expert: nn.Module,
        shared_expert_gate: nn.Linear,
    ):
        """
        Args:
            shared_expert: The shared-expert MLP (a ``Qwen3MoeMLP`` in this
                model; any module mapping ``[*, hidden] -> [*, hidden]``
                works, so the annotation is kept general).
            shared_expert_gate: Linear projection ``hidden -> 1`` whose
                sigmoid output scales the shared expert's result per token.
        """
        super().__init__()
        self._shared_expert = shared_expert
        self._shared_expert_gate = shared_expert_gate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self._shared_expert(x)
        # torch.sigmoid instead of the deprecated F.sigmoid alias.
        gate_values = torch.sigmoid(self._shared_expert_gate(x))  # [num_tokens, 1]
        # Broadcasting: [num_tokens, 1] * [num_tokens, hidden]
        return gate_values * out
| 561 | + |
class Qwen3OmniMoeTalkerSparseMoeBlock(nn.Module):
    """Sparse MoE block for the Qwen3 Omni MoE talker with a shared expert.

    Routed experts and the shared expert are both evaluated through
    SharedFusedMoE, which may overlap their computation with communication.

    Weight structure matches HuggingFace:
      - mlp.gate.weight (router)
      - mlp.shared_expert.{gate_proj, up_proj, down_proj}.weight
      - mlp.shared_expert_gate.weight
      - mlp.experts.{0..n}.{gate_proj, up_proj, down_proj}.weight
    """

    def __init__(
        self,
        config: Qwen3OmniMoeTalkerConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        txt_cfg = config.text_config
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > txt_cfg.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than the number of experts {txt_cfg.num_experts}."
            )

        # Router that scores every expert per token.
        self.gate = ReplicatedLinear(
            txt_cfg.hidden_size,
            txt_cfg.num_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        # Optional shared expert (HF: mlp.shared_expert.* / mlp.shared_expert_gate).
        if txt_cfg.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen3MoeMLP(
                hidden_size=txt_cfg.hidden_size,
                intermediate_size=txt_cfg.shared_expert_intermediate_size,
                hidden_act=txt_cfg.hidden_act,
                quant_config=quant_config,
                reduce_results=False,  # reduction is deferred to the MoE layer
                prefix=f"{prefix}.shared_expert",
            )
            # The gate lives beside shared_expert in the checkpoint, not inside it.
            self.shared_expert_gate = torch.nn.Linear(txt_cfg.hidden_size, 1, bias=False)
            self._shared_expert_wrapper = Qwen3OmniMoeTalkerSharedExpertWrapper(
                self.shared_expert, self.shared_expert_gate
            )
        else:
            self.shared_expert = None
            self.shared_expert_gate = None
            self._shared_expert_wrapper = None

        # Fused routed experts; also runs the shared expert when one exists.
        self.experts = SharedFusedMoE(
            shared_experts=self._shared_expert_wrapper,
            num_experts=txt_cfg.num_experts,
            top_k=txt_cfg.num_experts_per_tok,
            hidden_size=txt_cfg.hidden_size,
            intermediate_size=txt_cfg.moe_intermediate_size,
            reduce_results=False,  # combined output is reduced below
            renormalize=txt_cfg.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states may arrive 1D or 2D; flatten to (num_tokens, hidden).
        input_shape = hidden_states.shape
        flat = hidden_states.view(-1, input_shape[-1])

        logits, _ = self.gate(flat)

        # With a shared expert, SharedFusedMoE yields (shared_out, routed_out).
        moe_out = self.experts(hidden_states=flat, router_logits=logits)
        if self._shared_expert_wrapper is not None:
            shared_out, routed_out = moe_out
            moe_out = shared_out + routed_out

        # Cross-rank reduction was skipped above; do it once here.
        if self.tp_size > 1:
            moe_out = self.experts.maybe_all_reduce_tensor_model_parallel(moe_out)

        return moe_out.view(input_shape)
| 658 | + |
| 659 | + |
534 | 660 | class Qwen3OmniMoeModel(Qwen3MoeLLMForCausalLM): |
535 | | - def __init__(self, vllm_config, talker_config, prefix): |
| 661 | + """ |
| 662 | + Qwen3 Omni MoE Talker language model. |
| 663 | +
|
| 664 | + This model extends Qwen3MoeLLMForCausalLM with: |
| 665 | + - Shared expert support via SharedFusedMoE |
| 666 | + - Codec embedding instead of text embedding |
| 667 | + - No LM head (codec head is separate in the parent class) |
| 668 | + """ |
| 669 | + |
| 670 | + def __init__(self, vllm_config: VllmConfig, talker_config: Qwen3OmniMoeTalkerConfig, prefix: str): |
| 671 | + # Create a vllm_config for the talker's text model |
536 | 672 | talker_vllm_config = vllm_config.with_hf_config( |
537 | 673 | talker_config.text_config, architectures=["Qwen3MoeForCausalLM"] |
538 | 674 | ) |
539 | 675 | talker_vllm_config.model_config.hf_text_config = talker_vllm_config.model_config.hf_config |
| 676 | + |
540 | 677 | super().__init__( |
541 | 678 | vllm_config=talker_vllm_config, |
542 | 679 | prefix=prefix, |
543 | 680 | ) |
544 | 681 |
|
545 | 682 | self.config = talker_config |
| 683 | + self.talker_vllm_config = talker_vllm_config |
546 | 684 |
|
547 | 685 | # Remove the inherited LM head so the talker only exposes codec outputs. |
548 | 686 | if hasattr(self, "lm_head"): |
549 | 687 | del self.lm_head |
550 | 688 |
|
551 | | - # Replace the base embed tokens with codec embedding (defined below). |
| 689 | + # Replace the base embed tokens with codec embedding. |
552 | 690 | if hasattr(self.model, "embed_tokens"): |
553 | 691 | del self.model.embed_tokens |
554 | 692 |
|
555 | 693 | # Codec embedding for RVQ code generation |
556 | 694 | self.model.codec_embedding = nn.Embedding( |
557 | | - talker_config.text_config.vocab_size, talker_config.text_config.hidden_size |
| 695 | + talker_config.text_config.vocab_size, |
| 696 | + talker_config.text_config.hidden_size, |
558 | 697 | ) |
559 | 698 |
|
560 | | - # Add shared expert to each MoE layer and patch the forward method |
561 | | - layer_idx = 0 |
562 | | - for layer in self.model.layers: |
563 | | - # add shared expert to Qwen3OmniMoeSparseMoeBlock layers |
564 | | - if hasattr(layer.mlp, "experts"): # Check if it's a SparseMoeBlock |
565 | | - # Shared expert is a regular gated MLP (SwiGLU) |
566 | | - layer.mlp.shared_expert = Qwen3MoeMLP( |
567 | | - hidden_size=self.config.text_config.hidden_size, |
568 | | - intermediate_size=self.config.text_config.shared_expert_intermediate_size, |
569 | | - hidden_act=self.config.text_config.hidden_act, |
570 | | - quant_config=talker_vllm_config.quant_config, |
571 | | - reduce_results=False, # Don't reduce since we'll add it manually |
572 | | - prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert", |
573 | | - ) |
| 699 | + # Replace MoE blocks with shared expert versions |
| 700 | + self._replace_moe_blocks_with_shared_expert(prefix) |
574 | 701 |
|
575 | | - # Shared expert gate outputs a single scalar per token |
576 | | - layer.mlp.shared_expert_gate = ReplicatedLinear( |
577 | | - self.config.text_config.hidden_size, |
578 | | - 1, # Output single scalar per token |
579 | | - bias=False, |
580 | | - quant_config=None, |
581 | | - prefix=f"{prefix}.layers.{layer_idx}.mlp.shared_expert_gate", |
582 | | - ) |
583 | | - |
584 | | - # Store MoE config values for router computation |
585 | | - layer.mlp.top_k = self.config.text_config.num_experts_per_tok |
586 | | - layer.mlp.norm_topk_prob = self.config.text_config.norm_topk_prob |
587 | | - layer.mlp.num_experts = self.config.text_config.num_experts |
588 | | - |
589 | | - # Monkey-patch the forward method to use shared expert |
590 | | - layer.mlp.forward = self._create_moe_forward_with_shared_expert(layer.mlp) |
591 | | - |
592 | | - layer_idx += 1 |
593 | | - |
594 | | - def _create_moe_forward_with_shared_expert(self, moe_layer): |
595 | | - """Create a forward method that includes shared expert computation. |
596 | | -
|
597 | | - This matches the Transformers implementation where: |
598 | | - 1. Compute shared expert output (regular MLP) |
599 | | - 2. Gate it with sigmoid(shared_expert_gate(x)) |
600 | | - 3. Apply softmax BEFORE top-k selection (matches Transformers router) |
601 | | - 4. Add to routed expert outputs |
| 702 | + def _replace_moe_blocks_with_shared_expert(self, prefix: str) -> None: |
602 | 703 | """ |
603 | | - |
604 | | - def forward_with_shared_expert(hidden_states: torch.Tensor, layer_idx: int = 0) -> torch.Tensor: |
605 | | - # Save original shape |
606 | | - orig_shape = hidden_states.shape |
607 | | - hidden_dim = hidden_states.shape[-1] |
608 | | - hidden_states = hidden_states.view(-1, hidden_dim) |
609 | | - |
610 | | - # handle sequence parallel if needed |
611 | | - if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel: |
612 | | - hidden_states = sequence_parallel_chunk(hidden_states) |
613 | | - |
614 | | - # Compute shared expert output |
615 | | - # The shared expert is a regular MLP, not a routed MoE |
616 | | - shared_output = None |
617 | | - if hasattr(moe_layer, "shared_expert") and moe_layer.shared_expert is not None: |
618 | | - # Forward through shared expert MLP |
619 | | - shared_output = moe_layer.shared_expert(hidden_states) |
620 | | - |
621 | | - # Apply gating with sigmoid: sigmoid(gate(x)) * shared_expert(x) |
622 | | - if hasattr(moe_layer, "shared_expert_gate") and moe_layer.shared_expert_gate is not None: |
623 | | - gate_logits, _ = moe_layer.shared_expert_gate(hidden_states) |
624 | | - gate_values = F.sigmoid(gate_logits) # [batch, 1] |
625 | | - shared_output = gate_values * shared_output # Broadcasting: [batch, 1] * [batch, hidden] |
626 | | - |
627 | | - # Compute experts results |
628 | | - # router_logits: (num_tokens, n_experts) |
629 | | - router_logits, _ = moe_layer.gate(hidden_states) |
630 | | - experts_output = moe_layer.experts(hidden_states=hidden_states, router_logits=router_logits) |
631 | | - |
632 | | - # combine experts and shared expert results |
633 | | - if shared_output is not None: |
634 | | - final_hidden_states = experts_output + shared_output |
635 | | - |
636 | | - # Handle sequence parallel if needed |
637 | | - if hasattr(moe_layer, "is_sequence_parallel") and moe_layer.is_sequence_parallel: |
638 | | - from vllm.distributed import tensor_model_parallel_all_gather |
639 | | - |
640 | | - num_tokens = orig_shape[0] if len(orig_shape) > 1 else 1 |
641 | | - final_hidden_states = tensor_model_parallel_all_gather(final_hidden_states, 0) |
642 | | - final_hidden_states = final_hidden_states[:num_tokens] |
643 | | - try: |
644 | | - final_hidden_states.view(orig_shape) |
645 | | - except Exception as e: |
646 | | - print(f"Error viewing final hidden states: {e}") |
647 | | - print(f"final_hidden_states.shape: {final_hidden_states.shape}") |
648 | | - print(f"orig_shape: {orig_shape}") |
649 | | - raise e |
650 | | - # Return with original shape |
651 | | - return final_hidden_states.view(orig_shape) |
652 | | - |
653 | | - return forward_with_shared_expert |
| 704 | + Replace Qwen3MoeSparseMoeBlock layers with Qwen3OmniMoeTalkerSparseMoeBlock |
| 705 | + that includes shared expert support via SharedFusedMoE. |
| 706 | + """ |
| 707 | + # Get compilation config to clean up registered layer names |
| 708 | + compilation_config = self.talker_vllm_config.compilation_config |
| 709 | + |
| 710 | + for layer_idx, layer in enumerate(self.model.layers): |
| 711 | + # Check if this layer has a MoE block (has experts attribute) |
| 712 | + if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock): |
| 713 | + # Remove old layer registration from static_forward_context |
| 714 | + old_experts_prefix = f"{prefix}.model.layers.{layer_idx}.mlp.experts" |
| 715 | + if old_experts_prefix in compilation_config.static_forward_context: |
| 716 | + del compilation_config.static_forward_context[old_experts_prefix] |
| 717 | + |
| 718 | + # Create new MoE block with shared expert support |
| 719 | + layer.mlp = Qwen3OmniMoeTalkerSparseMoeBlock( |
| 720 | + config=self.config, |
| 721 | + quant_config=self.talker_vllm_config.quant_config, |
| 722 | + prefix=f"{prefix}.model.layers.{layer_idx}.mlp", |
| 723 | + ) |
654 | 724 |
|
655 | 725 | def embed_input_ids( |
656 | 726 | self, |
657 | 727 | input_ids: torch.Tensor, |
658 | | - **kwargs: object, |
659 | 728 | ) -> torch.Tensor: |
| 729 | + """Embed codec input IDs.""" |
660 | 730 | return self.model.codec_embedding(input_ids) |
0 commit comments