@@ -18,16 +18,18 @@
 import pytest
 import torch
 import torch.nn as nn
+import torch_npu
 from pytest_mock import MockerFixture
 
+from vllm_ascend.ascend_forward_context import get_fused_moe_state
 from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
                                        AscendUnquantizedFusedMoEMethod)
-from vllm_ascend.utils import adapt_patch  # noqa E402
+from vllm_ascend.utils import AscendSocVersion, adapt_patch  # noqa E402
 
 adapt_patch(True)
 
 
-def mock_ep_group(mocker):
+def mock_ep_and_mc2_group(mocker):
     mock_group = mocker.MagicMock()
     mock_group.rank_in_group = 0
     mock_group.rank = 0
@@ -52,7 +54,8 @@ def mock_dist_env(mocker: MockerFixture):
 
     with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
-        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
+        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
@@ -73,7 +76,7 @@ def mock_dist_env(mocker: MockerFixture):
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.fused_moe.get_forward_context',
               return_value=MagicMock(
-                  attn_metadata=MagicMock(max_num_tokens_across_dp=10),
+                  max_tokens_across_dp=10,
                   dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
               )), \
         patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
@@ -122,7 +125,14 @@ def mock_moe_env(mocker: MockerFixture):
         patch("torch_npu.npu_moe_finalize_routing", return_value=(
             torch.randn(16, 2)
         )):
-        yield
+        if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
+            with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
+                    torch.randn(16, 2))), \
+                patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
+                    torch.randn(16, 2))):
+                yield
+        else:
+            yield
 
 
 @pytest.fixture
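The `hasattr` guard above lets the `mock_moe_env` fixture patch the v2 dispatch/combine kernels only when the installed `torch_npu` build actually exposes them, so older builds still reach the plain `yield`. Below is a minimal, self-contained sketch of the same pattern using `contextlib.ExitStack`; `patched_optional_npu_ops` and `fake_out` are illustrative names, not part of the test suite, and `torch_npu` is assumed to be importable (i.e. an Ascend software stack is installed).

```python
from contextlib import ExitStack
from unittest.mock import patch

import torch
import torch_npu  # assumption: Ascend/NPU stack installed


def patched_optional_npu_ops() -> ExitStack:
    """Patch the v2 MoE dispatch/combine ops only if this torch_npu build has them."""
    stack = ExitStack()
    fake_out = torch.randn(16, 2)  # stand-in kernel output, mirroring the fixture above
    if hasattr(torch_npu, "npu_moe_distribute_dispatch_v2"):
        stack.enter_context(
            patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=fake_out))
        stack.enter_context(
            patch("torch_npu.npu_moe_distribute_combine_v2", return_value=fake_out))
    return stack


# Usage: `with patched_optional_npu_ops(): ...` behaves the same on old and new builds.
```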
@@ -237,11 +247,16 @@ def test_forward(self, mock_dist_env, default_moe_config, others_param):
             moe.moe_parallel_config.ep_size = 1
 
         moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
-        output = moe.forward(inputs,
-                             router_logits,
-                             is_prefill=is_prefill,
-                             top_k=top_k,
-                             shared_experts=shared_experts)
+        forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
+                                                         dtype=torch.bool),
+                                    padded_num_tokens=num_tokens)
+        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
+                   return_value=forward_context):
+            output = moe.forward(inputs,
+                                 router_logits,
+                                 is_prefill=is_prefill,
+                                 top_k=top_k,
+                                 shared_experts=shared_experts)
 
         moe.quant_method.apply.assert_called_once()
 
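In the reworked `test_forward`, the test builds its own `MagicMock` forward context carrying just the attributes the MoE forward path reads (`mc2_mask`, `padded_num_tokens`) and patches `get_forward_context` around the call. The sketch below strips that pattern to its core, assuming `vllm_ascend` is importable; the attribute comments are reasonable readings of the diff, not documented semantics.

```python
from unittest.mock import MagicMock, patch

import torch

num_tokens = 16
forward_context = MagicMock(
    # per-token boolean mask consumed by the MC2 path (all False in this sketch)
    mc2_mask=torch.zeros(num_tokens, dtype=torch.bool),
    # no padding applied: padded size equals the real token count
    padded_num_tokens=num_tokens,
)

with patch("vllm_ascend.ops.fused_moe.get_forward_context",
           return_value=forward_context):
    pass  # invoke the forward path under test here, as test_forward does
```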
@@ -288,15 +303,20 @@ def test_process_weights_after_loading(self, moe_method, mock_dist_env):
     def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                       mock_moe_env, others_param):
         """
-        1 test is_deepseek_v3_r1=true and use fused_expters_with_all2all
+        1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
         2 test use_select_experts and fused_experts
         3 test use select_gating_topk_softmax_experts and fused_experts
         4 test use select_experts and fused_experts_with_all2all_buffer
         """
         global_num_experts, ep_size, select_softmax = others_param
+        is_prefill = False
+        is_deepseek_v3_r1 = global_num_experts == 256
+        forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
+            ep_size, is_prefill, is_deepseek_v3_r1))
         with patch(
                 "vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
-                select_softmax):
+                select_softmax), \
+            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
             router_logits = torch.randn(8, 8)
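The new setup computes the expected `fused_moe_state` up front from the same inputs the production code would see (`ep_size`, `is_prefill`, and a DeepSeek-V3/R1 flag inferred from `global_num_experts == 256`) and hands it to the mocked forward context, so `apply()` follows the same dispatch branch it would at runtime. Here is a small helper capturing that construction, assuming `vllm_ascend` is importable; `make_moe_forward_context` is an illustrative name, not a helper from the test suite.

```python
from unittest.mock import MagicMock

from vllm_ascend.ascend_forward_context import get_fused_moe_state


def make_moe_forward_context(ep_size: int,
                             global_num_experts: int,
                             is_prefill: bool = False) -> MagicMock:
    """Mirror the tests above: derive fused_moe_state and wrap it in a mock context."""
    is_deepseek_v3_r1 = global_num_experts == 256  # heuristic used by the tests
    return MagicMock(fused_moe_state=get_fused_moe_state(
        ep_size, is_prefill, is_deepseek_v3_r1))


# Example: a 256-expert, decode-phase configuration with expert parallelism of 4.
ctx = make_moe_forward_context(ep_size=4, global_num_experts=256)
```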
@@ -309,7 +329,7 @@ def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                 top_k=2,
                 renormalize=True,
                 global_num_experts=global_num_experts,
-                is_prefill=False)
+                is_prefill=is_prefill)
 
             if ep_size == 1:
                 assert result.shape == (16, 2)
@@ -327,8 +347,13 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
         4 test use_select_experts and fused_experts
         """
         ep_size, alltoall_buffer = others_param
+        is_prefill = False
+        forward_context = MagicMock(
+            fused_moe_state=get_fused_moe_state(ep_size, is_prefill, True))
         with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER",
-                   alltoall_buffer):
+                   alltoall_buffer), \
+            patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
+            patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
             expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
             moe_method.ep_size = ep_size
             x = torch.randn(8, 2, 2)
@@ -347,7 +372,7 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                 renormalize=True,
                 global_num_experts=128,
                 expert_map=expert_map,
-                is_prefill=False)
+                is_prefill=is_prefill)
 
             if ep_size == 16 or ep_size == 1:
                 assert result.shape == (16, 2)
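A note on the extra patch in `test_apply_with_expert_map`: pinning `get_ascend_soc_version()` to `AscendSocVersion.A3` presumably keeps the expert-map / all2all branch selection independent of whichever SoC the CI host reports. A minimal sketch of that pin, assuming `vllm_ascend` is importable:

```python
from unittest.mock import patch

from vllm_ascend.utils import AscendSocVersion

# Force the SoC check to report A3 for the duration of the block; any code in
# vllm_ascend.ops.fused_moe that branches on get_ascend_soc_version() sees A3.
with patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version",
           return_value=AscendSocVersion.A3):
    pass  # call moe_method.apply(...) here, as the test does
```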