Commit ba3dfbd

[main][refactor] Refactoring forward_context and model_runner_v1 (vllm-project#1979)
### What this PR does / why we need it?

A refactoring of forward_context and model_runner_v1: it adds the context needed during model inference to forward_context and reworks the dummy_run logic to make it more reasonable.

Details of this PR:
- Add `ascend_forward_context`.
- Update the mc2_v2 op and support the `active_mask` param.
- Update the scripts in the examples dir.
- Refactor the `dummy_run` logic.
- Add soc_version for A2 and A3.

### Does this PR introduce _any_ user-facing change?

No user-facing change.

### How was this patch tested?

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@57c22e5

Signed-off-by: zzzzwwjj <[email protected]>
1 parent e3a2443 commit ba3dfbd

22 files changed: +628 −346 lines changed
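
For orientation before the per-file diffs, the sketch below lists the per-step fields this commit attaches to vLLM's forward context. The field names come from the new ascend_forward_context module shown later on this page; the literal values and the SimpleNamespace wrapper are illustrative only, since a real context is built by set_ascend_forward_context rather than by hand.

# Illustrative only: the extra fields added to the forward context each step.
from types import SimpleNamespace

import torch

ascend_context_fields = SimpleNamespace(
    with_prefill=False,         # whether this step contains prefill tokens
    fused_moe_state=None,       # FusedMoEState chosen from ep_size / prefill / model type
    in_profile_run=False,       # True only during the memory-profiling dummy run
    capturing=False,            # set separately because of repeated warmups before graph capture
    max_tokens_across_dp=16,    # max token count over all data-parallel ranks
    padded_num_tokens=16,       # token count rounded up to a multiple of the TP size
    mc2_mask=torch.ones(16, dtype=torch.bool),  # marks real (non-padding) tokens for MC2
)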

examples/offline_dualbatch_overlap_npu.py

Lines changed: 1 addition & 1 deletion
@@ -21,14 +21,14 @@ def main():
         tensor_parallel_size=2,
         max_model_len=4096,
         trust_remote_code=True,
+        enable_expert_parallel=True,
         additional_config={
             "torchair_graph_config": {
                 "enabled": False
             },
             "ascend_scheduler_config": {
                 "enabled": True
             },
-            "expert_tensor_parallel_size": 1
         })

     # Generate texts from the prompts. The output is a list of RequestOutput

examples/run_dp_server.sh

Lines changed: 20 additions & 17 deletions
@@ -1,3 +1,7 @@
+rm -rf ./.torchair_cache/
+rm -rf ./dynamo_*
+rm -rf /root/ascend/log/debug/plog/*
+
 export HCCL_IF_IP=2.0.0.0
 export GLOO_SOCKET_IFNAME="enp189s0f0"
 export TP_SOCKET_IFNAME="enp189s0f0"
@@ -6,25 +10,24 @@ export HCCL_SOCKET_IFNAME="enp189s0f0"
 export OMP_PROC_BIND=false
 export OMP_NUM_THREADS=100

-export VLLM_USE_V1=0
-
-export ASCEND_RT_VISIBLE_DEVICES=0,1
-export VLLM_DP_SIZE=2
-export VLLM_DP_RANK=0
-export VLLM_DP_MASTER_IP="2.0.0.0"
-export VLLM_DP_MASTER_PORT=40001
-export VLLM_DP_PROXY_IP="2.0.0.0"
-export VLLM_DP_PROXY_PORT=30002
-export VLLM_DP_MONITOR_PORT=30003
-export VLLM_HTTP_PORT=20001
+export VLLM_USE_V1=1
+export ASCEND_LAUNCH_BLOCKING=0

 vllm serve /data/weights/Qwen2.5-0.5B-Instruct \
   --host 0.0.0.0 \
-  --port 20001 \
-  --tensor-parallel-size 1 \
-  --seed 1024 \
+  --port 20002 \
   --served-model-name Qwen \
-  --max-model-len 2000 \
-  --max-num-batched-tokens 2000 \
+  --data-parallel-size 4 \
+  --data-parallel-size-local 4 \
+  --data-parallel-address 2.0.0.0 \
+  --data-parallel-rpc-port 13389 \
+  --tensor-parallel-size 4 \
+  --enable-expert-parallel \
+  --no-enable-prefix-caching \
+  --max-num-seqs 16 \
+  --max-model-len 4096 \
+  --max-num-batched-tokens 4096 \
+  --gpu-memory-utilization 0.9 \
   --trust-remote-code \
-  --gpu-memory-utilization 0.9 \
+  --enforce-eager \
+  --additional-config '{"ascend_scheduler_config":{"enabled":true},"torchair_graph_config":{"enabled":false, "enable_multistream_moe":false, "use_cached_graph":false}}'
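
A quick way to check the relaunched server is a request against vLLM's OpenAI-compatible endpoint. The snippet below is a hedged example, assuming the service above is reachable on the same host at the --port value 20002 and is registered under --served-model-name Qwen.

# Hedged usage example for the server started by run_dp_server.sh.
import json
from urllib.request import Request, urlopen

payload = {"model": "Qwen", "prompt": "Hello, my name is", "max_tokens": 16}
req = Request("http://localhost:20002/v1/completions",
              data=json.dumps(payload).encode("utf-8"),
              headers={"Content-Type": "application/json"})
with urlopen(req) as resp:
    print(json.loads(resp.read())["choices"][0]["text"])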

tests/ut/models/test_deepseek_v2.py

Lines changed: 12 additions & 2 deletions
@@ -114,7 +114,16 @@ def mock_distributed():
                   return_value=Mock(is_first_rank=False, is_last_rank=False)), \
         patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
         patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
-                   _PP=pp_group):
+                   _PP=pp_group), \
+        patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
+        yield
+
+
+@pytest.fixture
+def mock_forward_context():
+    forward_context = Mock(in_profile_run=False, with_prefill=False)
+    with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
+               return_value=forward_context):
         yield


@@ -205,7 +214,8 @@ def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
                               quant_config=None)


-def test_custom_deepseek_v2_moe(mock_distributed, base_config):
+def test_custom_deepseek_v2_moe(mock_distributed, base_config,
+                                mock_forward_context):
     base_config.n_shared_experts = 1
     moe = CustomDeepseekV2MoE(config=base_config,
                               quant_config=None,

tests/ut/ops/test_fused_ops.py

Lines changed: 40 additions & 15 deletions
@@ -18,16 +18,18 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 from pytest_mock import MockerFixture

+from vllm_ascend.ascend_forward_context import get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
-from vllm_ascend.utils import adapt_patch  # noqa E402
+from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402

 adapt_patch(True)


-def mock_ep_group(mocker):
+def mock_ep_and_mc2_group(mocker):
     mock_group = mocker.MagicMock()
     mock_group.rank_in_group = 0
     mock_group.rank = 0
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):

     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.fused_moe.get_forward_context',
               return_value=MagicMock(
-                  attn_metadata=MagicMock(max_num_tokens_across_dp=10),
+                  max_tokens_across_dp=10,
                   dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
               )), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
         patch("torch_npu.npu_moe_finalize_routing", return_value=(
             torch.randn(16, 2)
         )):
-        yield
+        if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
+            with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
+                    torch.randn(16, 2))), \
+                patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
+                    torch.randn(16, 2))):
+                yield
+        else:
+            yield


 @pytest.fixture
@@ -237,11 +247,16 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
             moe.moe_parallel_config.ep_size = 1

         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        output = moe.forward(inputs,
-                             router_logits,
-                             is_prefill=is_prefill,
-                             top_k=top_k,
-                             shared_experts=shared_experts)
+        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
+                                                         dtype=torch.bool),
+                                    padded_num_tokens=num_tokens)
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+                   return_value=forward_context):
+            output = moe.forward(inputs,
+                                 router_logits,
+                                 is_prefill=is_prefill,
+                                 top_k=top_k,
+                                 shared_experts=shared_experts)

         moe.quant_method.apply.assert_called_once()

@@ -288,15 +303,20 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
     def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                       mock_moe_env, others_param):
         """
-        1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all
+        1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
         2 test use_select_experts and fused_experts
         3 test use select_gating_topk_softmax_experts and fused_experts
         4 test use select_experts and fused_experts_with_all2all_buffer
         """
         global_num_experts, ep_size, select_softmax = others_param
+        is_prefill = False
+        is_deepseek_v3_r1 = global_num_experts == 256
+        forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
+            ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(
                 "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
-                select_softmax):
+                select_softmax), \
+            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             router_logits = torch.randn(8, 8)
@@ -309,7 +329,7 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                 top_k=2,
                 renormalize=True,
                 global_num_experts=global_num_experts,
-                is_prefill=False)
+                is_prefill=is_prefill)

         if ep_size == 1:
             assert result.shape == (16, 2)
@@ -327,8 +347,13 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
         4 test use_select_experts and fused_experts
         """
         ep_size, alltoall_buffer = others_param
+        is_prefill = False
+        forward_context = MagicMock(
+            fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
         with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
-                   alltoall_buffer):
+                   alltoall_buffer), \
+            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
+            patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
@@ -347,7 +372,7 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                 renormalize=True,
                 global_num_experts=128,
                 expert_map=expert_map,
-                is_prefill=False)
+                is_prefill=is_prefill)

         if ep_size == 16 or ep_size == 1:
             assert result.shape == (16, 2)
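
The parametrized tests above derive the expected fused_moe_state from get_fused_moe_state, so it helps to spell out what that function returns for a few representative inputs. The snippet below is a hedged sanity check, assuming vllm-ascend is importable and VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP is left at its default (disabled).

# Hedged sanity check of the MoE-state routing the tests rely on.
from vllm_ascend.ascend_forward_context import FusedMoEState, get_fused_moe_state

# ep_size == 1: AllGather for decode-only steps, NaiveMulticast when prefill is present.
assert get_fused_moe_state(1, with_prefill=False, is_deepseek_v3_r1=False) == FusedMoEState.AllGather
assert get_fused_moe_state(1, with_prefill=True, is_deepseek_v3_r1=False) == FusedMoEState.NaiveMulticast

# Small expert-parallel groups (or any prefill step) fall back to All2All.
assert get_fused_moe_state(4, with_prefill=False, is_deepseek_v3_r1=False) == FusedMoEState.All2All

# Decode steps with ep_size >= 16 take the MC2 path (e.g. the DeepSeek V3/R1 case above).
assert get_fused_moe_state(16, with_prefill=False, is_deepseek_v3_r1=True) == FusedMoEState.MC2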
vllm_ascend/ascend_forward_context.py (new file)

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional

import torch
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.platforms import current_platform

import vllm_ascend.envs as envs


class FusedMoEState(Enum):
    AllGather = 0
    All2All = 1
    MC2 = 2
    AllGatherEP = 3
    NaiveMulticast = 4


# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool,
                        is_deepseek_v3_r1: bool):
    # The fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by AllGatherEP
    # only supports DeepSeek V3/R1.
    if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
            and is_deepseek_v3_r1):
        return FusedMoEState.AllGatherEP
    elif ep_size == 1:
        if with_prefill:
            return FusedMoEState.NaiveMulticast
        else:
            return FusedMoEState.AllGather
    # NOTE: MC2 needs ep_size >= 16, and All2All can't be used in the torchair graph.
    elif ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    else:
        return FusedMoEState.MC2


@contextmanager
def set_ascend_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        with_prefill: bool = True,
        in_profile_run: bool = False,
        num_actual_tokens: Optional[int] = None,
):
    """A context manager that stores the current forward context,
    which can be attention metadata, etc.
    We add some additional params into forward_context.
    """
    with set_forward_context(attn_metadata,
                             vllm_config,
                             virtual_engine=virtual_engine,
                             num_tokens=num_tokens,
                             num_tokens_across_dp=num_tokens_across_dp):
        forward_context = get_forward_context()
        forward_context.with_prefill = with_prefill
        ep_size = (get_ep_group().world_size if
                   vllm_config.parallel_config.enable_expert_parallel else 1)

        is_deepseek_v3_r1 = hasattr(
            vllm_config.model_config.hf_config, 'n_routed_experts'
        ) and vllm_config.model_config.hf_config.n_routed_experts == 256
        fused_moe_state = get_fused_moe_state(ep_size, with_prefill,
                                              is_deepseek_v3_r1)

        forward_context.fused_moe_state = fused_moe_state

        forward_context.in_profile_run = in_profile_run

        # NOTE: This cannot be set using set_forward_context
        # due to multiple warmups before actual capturing
        forward_context.capturing = False

        if num_tokens is None and attn_metadata is not None:
            if hasattr(attn_metadata, 'num_actual_tokens'):
                # for v1 engine
                num_tokens = attn_metadata.num_actual_tokens
            else:
                # for v0 engine
                num_tokens = attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens

        if num_actual_tokens is None:
            num_actual_tokens = num_tokens

        dp_world_size = get_dp_group().world_size
        if dp_world_size > 1 and forward_context.dp_metadata is not None:
            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
            )
        else:
            max_tokens_across_dp = num_tokens

        forward_context.max_tokens_across_dp = max_tokens_across_dp

        if num_tokens is not None:
            tp_world_size = get_tp_group().world_size
            # NOTE: the token count to pad to for MC2
            forward_context.padded_num_tokens = math.ceil(
                max_tokens_across_dp / tp_world_size) * tp_world_size

            mc2_mask = torch.zeros(forward_context.padded_num_tokens,
                                   dtype=torch.bool,
                                   device=current_platform.device_type)
            mc2_mask[:num_actual_tokens] = True
            forward_context.mc2_mask = mc2_mask

        try:
            yield
        finally:
            pass
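
Since the MC2 padding logic above is the part most ops depend on, here is a self-contained sketch of that arithmetic with the distributed lookups replaced by made-up example numbers (max_tokens_across_dp=10, tp_world_size=4, num_actual_tokens=7). It mirrors the padded_num_tokens / mc2_mask computation in set_ascend_forward_context, which, per the PR description, the model runner uses around model execution and dummy runs.

# Standalone sketch of the MC2 padding and active-token mask computed above.
import math

import torch

max_tokens_across_dp = 10   # example: largest batch among the DP ranks
tp_world_size = 4           # example: tensor-parallel group size
num_actual_tokens = 7       # example: real tokens on this rank

# Pad the cross-DP max token count up to a multiple of the TP size.
padded_num_tokens = math.ceil(max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == 12

# The mask tells MC2 dispatch/combine which rows are real tokens vs. padding.
mc2_mask = torch.zeros(padded_num_tokens, dtype=torch.bool)
mc2_mask[:num_actual_tokens] = True
assert int(mc2_mask.sum()) == num_actual_tokens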
