Skip to content

Commit 2ad0ca5

Browse files
authored
Qwen3.5 MoE supports flashcomm v1 (vllm-project#7644)
cherry pick from vllm-project#7486 <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> Multimodal models like Qwen3.5 MoE does embedding in model_runner, so when flash comm is enabled, the first AllGather operation should be skipped. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> No. ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> - vLLM version: v0.18.0 - vLLM main: vllm-project/vllm@8b63257 --------- Signed-off-by: Wangbingjie <wangbj1207@126.com> Signed-off-by: wangbj127 <256472688+wangbj127@users.noreply.github.com>
1 parent ff1860b commit 2ad0ca5

File tree

7 files changed

+182
-8
lines changed

7 files changed

+182
-8
lines changed

tests/e2e/multicard/4-cards/test_qwen3_5.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
# This file is a part of the vllm-ascend project.
1717
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
1818
#
19+
import os
1920
from tests.e2e.conftest import VllmRunner
21+
from unittest.mock import patch
2022

2123

2224
def test_qwen3_5_27b_distributed_mp_tp4():
@@ -72,4 +74,32 @@ def test_qwen3_5_35b_distributed_mp_tp4_full_decode_only_mtp3():
7274
"num_speculative_tokens": 3,
7375
}) as vllm_model:
7476
vllm_model.generate_greedy(example_prompts, max_tokens)
75-
del vllm_model
77+
del vllm_model
78+
79+
80+
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
81+
def test_qwen3_5_35b_distributed_mp_tp4_full_decode_only_mtp3_flashcomm():
82+
example_prompts = [
83+
"Hello, my name is",
84+
"The president of the United States is",
85+
"The capital of France is",
86+
"The future of AI is",
87+
]
88+
89+
max_tokens = 20
90+
with VllmRunner("Qwen/Qwen3.5-35B-A3B",
91+
tensor_parallel_size=4,
92+
enable_expert_parallel=True,
93+
max_model_len=4096,
94+
gpu_memory_utilization=0.90,
95+
distributed_executor_backend="mp",
96+
compilation_config={
97+
"cudagraph_mode": "FULL_DECODE_ONLY",
98+
"cudagraph_capture_sizes": [4, 8, 12, 16],
99+
},
100+
speculative_config={
101+
"method": "qwen3_5_mtp",
102+
"num_speculative_tokens": 3,
103+
}) as vllm_model:
104+
vllm_model.generate_greedy(example_prompts, max_tokens)
105+
del vllm_model

vllm_ascend/ops/layernorm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def forward_oot(
9797
import torch_npu
9898

9999
if residual is not None:
100+
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
100101
if enable_custom_op():
101102
x, _, residual = torch.ops._C_ascend.npu_add_rms_norm_bias(
102103
x, residual, 1.0 + self.weight, None, self.variance_epsilon

vllm_ascend/ops/linear_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
tensor_model_parallel_reduce_scatter,
5858
)
5959
from vllm.distributed.parallel_state import get_tp_group
60+
from vllm.model_executor.models.utils import extract_layer_index
6061

6162
from vllm_ascend.ascend_config import get_ascend_config
6263
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
@@ -74,6 +75,7 @@
7475
flashcomm2_enable,
7576
get_flashcomm2_reorgnized_batch_ids,
7677
get_weight_prefetch_method,
78+
is_vl_model,
7779
matmul_allreduce_enable,
7880
mlp_tp_enable,
7981
oproj_tp_enable,
@@ -430,8 +432,8 @@ def apply_impl(self, input_: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor,
430432

431433
# Matrix multiply.
432434
assert self.quant_method is not None
433-
434-
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, True)
435+
need_all_gather = not (extract_layer_index(self.layer.prefix) == 0 and is_vl_model() and "attn" in self.prefix)
436+
input_ = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(input_, label=need_all_gather)
435437
output_parallel = self.quant_method.apply(self.layer, input_, bias)
436438

437439
if self.gather_output:

vllm_ascend/ops/register_custom_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from vllm_ascend.ops.rotary_embedding import rope_forward_oot
1818
from vllm_ascend.ops.triton.muls_add import muls_add_triton
1919
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
20-
from vllm_ascend.utils import enable_sp_by_pass, npu_stream_switch, prefetch_stream
20+
from vllm_ascend.utils import enable_sp_by_pass, is_vl_model, npu_stream_switch, prefetch_stream
2121

2222

2323
def _maybe_chunk_residual_impl(x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
@@ -80,7 +80,7 @@ def _maybe_pad_and_reduce_impl(x: torch.Tensor, is_ep_comm: bool = False) -> tor
8080
enable_sp_by_pass() and is_ep_comm
8181
)
8282

83-
if not flash_comm_v1_enabled:
83+
if not flash_comm_v1_enabled or (forward_context.is_draft_model and is_vl_model()):
8484
return tensor_model_parallel_all_reduce(x)
8585

8686
dp_metadata = forward_context.dp_metadata

vllm_ascend/patch/worker/patch_qwen3_5.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,18 @@
1818

1919

2020
import torch
21+
from einops import rearrange
22+
from vllm.distributed import get_tensor_model_parallel_world_size
2123
from vllm.forward_context import get_forward_context
2224
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
2325
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
24-
from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
26+
from vllm.model_executor.models.qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
2527
from vllm.model_executor.models.qwen3_next import Qwen3NextAttention
2628
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
2729
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
2830
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
2931

32+
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
3033
from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
3134
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
3235
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
@@ -41,6 +44,64 @@ def to_int64_tuple(t):
4144

4245

4346
class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
47+
def forward(
48+
self,
49+
hidden_states: torch.Tensor,
50+
output: torch.Tensor,
51+
):
52+
"""
53+
Forward pass with three parts:
54+
1. Input projection
55+
2. Core attention (custom op)
56+
3. Output projection
57+
"""
58+
59+
# ============================================================
60+
# Part 1: Input Projection
61+
# ============================================================
62+
mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
63+
num_tokens = mixed_qkvz.size(0)
64+
qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size
65+
z_size = self.value_dim // self.tp_size
66+
mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1)
67+
z = z.reshape(z.size(0), -1, self.head_v_dim)
68+
ba, _ = self.in_proj_ba(hidden_states)
69+
b, a = ba.chunk(2, dim=-1)
70+
71+
b = b.contiguous()
72+
a = a.contiguous()
73+
74+
# ============================================================
75+
# Part 2: Core Attention (Custom Op)
76+
# ============================================================
77+
# Note: we should not use torch.empty here like other attention backends,
78+
# see discussions in https://github.com/vllm-project/vllm/pull/28182
79+
core_attn_out = torch.zeros(
80+
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
81+
dtype=hidden_states.dtype,
82+
device=hidden_states.device,
83+
)
84+
torch.ops.vllm.gdn_attention_core(
85+
mixed_qkv,
86+
b,
87+
a,
88+
core_attn_out,
89+
self.prefix,
90+
)
91+
# ============================================================
92+
# Part 3: Output Projection
93+
# ============================================================
94+
z_shape_og = z.shape
95+
# Reshape input data into 2D tensor
96+
core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
97+
z = z.reshape(-1, z.shape[-1])
98+
core_attn_out = self.norm(core_attn_out, z)
99+
core_attn_out = core_attn_out.reshape(z_shape_og)
100+
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
101+
o_out, _ = self.out_proj(core_attn_out)
102+
actual_num_tokens = o_out.shape[0]
103+
output[:actual_num_tokens] = o_out
104+
44105
def _forward_core(
45106
self,
46107
mixed_qkv: torch.Tensor,
@@ -320,5 +381,68 @@ def forward(self, positions: torch.Tensor, output: torch.Tensor, hidden_states:
320381
output[:], _ = self.o_proj(attn_output)
321382

322383

384+
class AscendQwen3_5DecoderLayer(Qwen3_5DecoderLayer):
385+
def forward(
386+
self,
387+
hidden_states: torch.Tensor,
388+
residual: torch.Tensor | None,
389+
positions: torch.Tensor = None,
390+
**kwargs: object,
391+
):
392+
if residual is None:
393+
residual = hidden_states
394+
hidden_states = self.input_layernorm(hidden_states)
395+
else:
396+
hidden_states, residual = self.input_layernorm(hidden_states, residual)
397+
398+
if self.layer_idx == 0 and _EXTRA_CTX.flash_comm_v1_enabled:
399+
tp_size = get_tensor_model_parallel_world_size()
400+
n_out = (hidden_states.shape[0] + tp_size - 1) // tp_size
401+
hidden_dim = hidden_states.shape[-1]
402+
self_attention_output = torch.empty(
403+
(n_out, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
404+
)
405+
else:
406+
self_attention_output = torch.empty_like(hidden_states)
407+
408+
if self.layer_type == "linear_attention":
409+
self.linear_attn(
410+
hidden_states=hidden_states,
411+
output=self_attention_output,
412+
)
413+
elif self.layer_type == "full_attention":
414+
self.self_attn(
415+
hidden_states=hidden_states,
416+
output=self_attention_output,
417+
positions=positions,
418+
)
419+
else:
420+
raise ValueError("Invalid layer_type")
421+
hidden_states = self_attention_output
422+
423+
if self.layer_scale:
424+
if len(hidden_states.shape) == 2:
425+
hidden_states = hidden_states * (self.attn_layer_scale.to(hidden_states.dtype)[0] + 1)
426+
else:
427+
hidden_states = hidden_states * (self.attn_layer_scale.to(hidden_states.dtype) + 1)
428+
429+
# Fully Connected
430+
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
431+
hidden_states = self.mlp(hidden_states)
432+
433+
if self.layer_scale:
434+
if len(hidden_states.shape) == 2:
435+
hidden_states = hidden_states * (self.ffn_layer_scale.to(hidden_states.dtype)[0] + 1)
436+
else:
437+
assert len(hidden_states.shape) == len(self.ffn_layer_scale.shape), (
438+
f"shape must be the same {len(hidden_states.shape)}, {len(self.ffn_layer_scale.shape)}"
439+
)
440+
hidden_states = hidden_states * (self.ffn_layer_scale.to(hidden_states.dtype) + 1)
441+
442+
return hidden_states, residual
443+
444+
445+
Qwen3_5GatedDeltaNet.forward = AscendQwen3_5GatedDeltaNet.forward
323446
Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
324447
Qwen3NextAttention.forward = AscendQwen3NextAttention.forward
448+
Qwen3_5DecoderLayer.forward = AscendQwen3_5DecoderLayer.forward

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_st
155155
]
156156

157157
self._runnable = self._run_merged_draft
158+
self.is_multimodal_model = self.vllm_config.model_config.is_multimodal_model
158159
if self.uses_mrope:
159160
self.mrope_positions = torch.zeros((3, self.max_num_tokens + 1), dtype=torch.int32, device=device)
160161
elif self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim > 0:
@@ -414,6 +415,16 @@ def dummy_run(
414415
if is_profile:
415416
batch_size = min(batch_size, self.runner.max_num_reqs)
416417

418+
if self.supports_mm_inputs:
419+
mm_embeds, is_mm_embed = (None, None)
420+
inputs_embeds = self.model.embed_input_ids(
421+
self.input_ids[:num_tokens], multimodal_embeddings=mm_embeds, is_multimodal=is_mm_embed
422+
)
423+
self.inputs_embeds[:num_tokens] = inputs_embeds
424+
inputs_embeds = self.inputs_embeds[:num_tokens]
425+
else:
426+
inputs_embeds = None
427+
417428
with set_ascend_forward_context(
418429
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
419430
self.vllm_config,
@@ -437,7 +448,7 @@ def dummy_run(
437448
token_indices_to_sample=self.token_indices_to_sample[: batch_size * self.extra_slots_per_request],
438449
# The target_position's address is same as the model_positions's
439450
target_positions=model_positions,
440-
inputs_embeds=None,
451+
inputs_embeds=inputs_embeds,
441452
multi_steps_attn_metadata=multi_steps_attn_metadata,
442453
num_tokens=num_tokens,
443454
)
@@ -1581,6 +1592,8 @@ def maybe_pad_and_reduce(
15811592
hidden_states: torch.Tensor,
15821593
positions: torch.Tensor,
15831594
) -> tuple[torch.Tensor, torch.Tensor]:
1595+
if self.is_multimodal_model and _EXTRA_CTX.flash_comm_v1_enabled:
1596+
return hidden_states, positions
15841597
if self.method == "mtp":
15851598
if _EXTRA_CTX.flash_comm_v1_enabled:
15861599
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states)

vllm_ascend/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ def _is_contain_expert(config: Any):
849849
return False
850850

851851

852-
def is_vl_model(vllm_config: VllmConfig):
852+
def is_vl_model(vllm_config: VllmConfig = None):
853853
"""Checks if the model is a VL model by config.
854854
855855
Uses the same criterion as vllm itself (model_config.py): a model is
@@ -859,6 +859,10 @@ def is_vl_model(vllm_config: VllmConfig):
859859
self (rare but possible).
860860
"""
861861
global _IS_VL_MODEL
862+
if vllm_config is None:
863+
from vllm.config import get_current_vllm_config_or_none
864+
865+
vllm_config = get_current_vllm_config_or_none()
862866
if _IS_VL_MODEL is None and vllm_config and vllm_config.model_config:
863867
model_config = vllm_config.model_config
864868
# Primary: vllm's own VL detection — hf_config is the top-level

0 commit comments

Comments
 (0)