@@ -3,6 +3,8 @@
 
 Run `pytest tests/kernels/test_moe.py`.
 """
+import unittest.mock as mock
+
 import pytest
 import torch
 from torch.nn import Parameter
@@ -40,6 +42,7 @@
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -48,20 +51,20 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
+    if padding:
+        padding_size = 128
+        envs.VLLM_MOE_PADDING = True
+    else:
+        padding_size = 0
+        envs.VLLM_MOE_PADDING = False
+
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
 
     score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-    # Pad the input if use padding
-    if envs.VLLM_MOE_PADDING:
-        w1 = F.pad(w1, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-        w2 = F.pad(w2, (0, 128), "constant", 0)
-        torch.cuda.empty_cache()
-
     if ep_size > 1:
         local_e = e // ep_size
         e_ids = torch.randint(0,
@@ -75,16 +78,7 @@ def test_fused_moe(
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -93,6 +87,26 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+    # Pad the weights if padding is enabled
+    if envs.VLLM_MOE_PADDING:
+        w1 = F.pad(w1, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)
+        torch.cuda.empty_cache()
+
+    with mock.patch(
+            'vllm.model_executor.layers.fused_moe.fused_moe.padding_size',
+            padding_size):
+        triton_output = fused_moe(a,
+                                  w1,
+                                  w2,
+                                  score,
+                                  topk,
+                                  global_num_experts=e,
+                                  expert_map=e_map,
+                                  renormalize=False)
+
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=1e-2,
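A note on the new `padding` axis: stacked `pytest.mark.parametrize` decorators compose multiplicatively, so every existing parameter combination now runs twice, once with padding enabled and once without. A standalone sketch of the stacking behavior (toy parameters, not the test's real grid):

```python
import pytest


# Stacked parametrize decorators multiply: 2 dtypes x 2 padding
# values generate four test cases from a single function.
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("padding", [True, False])
def test_param_grid(dtype: str, padding: bool):
    assert dtype in ("float16", "bfloat16")
    assert isinstance(padding, bool)
```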
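On the padding itself: with a 2-tuple, `torch.nn.functional.pad` pads only the last dimension, so `(0, 128)` appends 128 trailing zero columns along the `k` axis of each expert weight while leaving the original values untouched. A quick standalone check of those shape semantics (toy sizes, hypothetical names):

```python
import torch
import torch.nn.functional as F

# Toy tensor shaped like the test's w1: (experts, 2*n, k).
w1 = torch.randn(8, 24, 16)

# With a 2-tuple, F.pad pads only the last dimension:
# (pad_left, pad_right) -> 128 trailing zeros on the k axis.
w1_padded = F.pad(w1, (0, 128), "constant", 0)

assert w1_padded.shape == (8, 24, 16 + 128)
assert torch.equal(w1_padded[..., :16], w1)  # original data intact
assert torch.equal(w1_padded[..., 16:], torch.zeros(8, 24, 128))
```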
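Finally, `mock.patch` as a context manager temporarily rebinds the named module attribute (here the module-level `padding_size` that the kernel code reads) and restores the original value when the `with` block exits, even if the body raises. A minimal sketch of the mechanism, using `math.pi` as a stand-in target:

```python
import math
import unittest.mock as mock

# mock.patch rebinds the named attribute for the duration of the
# with block, then restores the original value on exit.
with mock.patch("math.pi", 3.0):
    assert math.pi == 3.0

assert math.pi != 3.0  # original value restored
```

Note that `mock.patch` rebinds the name where it is looked up: code that captured the value earlier via `from module import padding_size` would not see the patch, which is presumably why the test patches the attribute inside the `fused_moe` module itself.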