@@ -1,15 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import torch
-import triton
-import triton.language as tl
+from dataclasses import dataclass
 
 import pytest
-from dataclasses import dataclass
+import torch
+import triton.language as tl
 
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    invoke_moe_batched_triton_kernel,
-    invoke_batched_silu_and_mul)
+    invoke_batched_silu_and_mul, invoke_moe_batched_triton_kernel)
 
 
 @dataclass
@@ -20,25 +18,36 @@ class BatchedMMConfig:
     K: int
     N: int
 
+
 @dataclass
 class BatchedMMTensors:
     A: torch.Tensor  # [E, max_tokens, K]
     B: torch.Tensor  # [E, K, N] - column major
     C: torch.Tensor  # [E, max_tokens, N]
-    num_expert_tokens: torch.Tensor  # [E]
+    num_expert_tokens: torch.Tensor  # [E]
 
     @staticmethod
     def make_tensors(config: BatchedMMConfig):
-        A = torch.randn((config.num_experts, config.max_tokens_per_expert, config.K), device="cuda", dtype=config.dtype) / 50.0
-        B = torch.randn((config.num_experts, config.N, config.K), device="cuda", dtype=config.dtype) / 50.0
-        C = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.N), device="cuda", dtype=config.dtype)
-        num_expert_tokens = torch.randint(low=0, high=config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32)
-        return BatchedMMTensors(A,B,C, num_expert_tokens)
-
-
-def ref_impl(A: torch.Tensor,
-             B: torch.Tensor,
-             C: torch.Tensor,
+        A = torch.randn(
+            (config.num_experts, config.max_tokens_per_expert, config.K),
+            device="cuda",
+            dtype=config.dtype) / 50.0
+        B = torch.randn((config.num_experts, config.N, config.K),
+                        device="cuda",
+                        dtype=config.dtype) / 50.0
+        C = torch.zeros(
+            (config.num_experts, config.max_tokens_per_expert, config.N),
+            device="cuda",
+            dtype=config.dtype)
+        num_expert_tokens = torch.randint(low=0,
+                                          high=config.max_tokens_per_expert,
+                                          size=(config.num_experts, ),
+                                          device="cuda",
+                                          dtype=torch.int32)
+        return BatchedMMTensors(A, B, C, num_expert_tokens)
+
+
+def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
             num_expert_tokens: torch.Tensor) -> torch.Tensor:
 
     num_expert_tokens_cpu = num_expert_tokens.clone()
@@ -49,49 +58,50 @@ def ref_impl(A: torch.Tensor,
         num_tokens = num_expert_tokens_cpu[e]
         C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
 
-
     return C
 
+
 @pytest.mark.parametrize("num_experts", [16, 32])
 @pytest.mark.parametrize("max_tokens_per_expert", [512])
 @pytest.mark.parametrize("K", [256])
 @pytest.mark.parametrize("N", [512])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_batched_mm(num_experts: int,
-                    max_tokens_per_expert: int,
-                    K: int,
-                    N: int,
-                    dtype: torch.dtype):
+def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
+                    N: int, dtype: torch.dtype):
 
     config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
     tensors = BatchedMMTensors.make_tensors(config)
 
     test_output = tensors.C
     ref_output = test_output.clone()
 
-
-    compute_tl_dtype = {torch.float16: tl.float16,
-                        torch.bfloat16: tl.bfloat16,
-                        torch.float32: tl.float32}[test_output.dtype]
-    invoke_moe_batched_triton_kernel(tensors.A,
-                                     tensors.B,
-                                     test_output,
-                                     tensors.num_expert_tokens,
-                                     compute_tl_dtype,
-                                     # Quantization data
-                                     None,
-                                     None,
-                                     None,
-                                     # Quantization schemes
-                                     False,
-                                     False,
-                                     False,
-                                     config={"BLOCK_SIZE_M": 16,
-                                             "BLOCK_SIZE_N": 16,
-                                             "BLOCK_SIZE_K": 16})
-
-
-    ref_output = ref_impl(tensors.A, tensors.B, ref_output, tensors.num_expert_tokens)
+    compute_tl_dtype = {
+        torch.float16: tl.float16,
+        torch.bfloat16: tl.bfloat16,
+        torch.float32: tl.float32
+    }[test_output.dtype]
+    invoke_moe_batched_triton_kernel(
+        tensors.A,
+        tensors.B,
+        test_output,
+        tensors.num_expert_tokens,
+        compute_tl_dtype,
+        # Quantization data
+        None,
+        None,
+        None,
+        # Quantization schemes
+        False,
+        False,
+        False,
+        config={
+            "BLOCK_SIZE_M": 16,
+            "BLOCK_SIZE_N": 16,
+            "BLOCK_SIZE_K": 16
+        })
+
+    ref_output = ref_impl(tensors.A, tensors.B, ref_output,
+                          tensors.num_expert_tokens)
     #torch.cuda.synchronize()
     #print(f"ref output {ref_output}")
     #print(f"test output {test_output}")
@@ -106,6 +116,7 @@ class BatchedSiluMulConfig:
     max_tokens_per_expert: int
     D: int
 
+
 @dataclass
 class BatchedSiluMulTensors:
     input: torch.Tensor
@@ -114,16 +125,24 @@ class BatchedSiluMulTensors:
 
     @staticmethod
     def make_tensors(config: BatchedSiluMulConfig):
-        input = torch.randn((config.num_experts, config.max_tokens_per_expert, config.D * 2), device="cuda", dtype=config.dtype) / 50.0
-        output = torch.zeros((config.num_experts, config.max_tokens_per_expert, config.D), device="cuda", dtype=config.dtype)
-        num_expert_tokens = torch.randint(low=0, high=config.max_tokens_per_expert, size=(config.num_experts,), device="cuda", dtype=torch.int32)
+        input = torch.randn(
+            (config.num_experts, config.max_tokens_per_expert, config.D * 2),
+            device="cuda",
+            dtype=config.dtype) / 50.0
+        output = torch.zeros(
+            (config.num_experts, config.max_tokens_per_expert, config.D),
+            device="cuda",
+            dtype=config.dtype)
+        num_expert_tokens = torch.randint(low=0,
+                                          high=config.max_tokens_per_expert,
+                                          size=(config.num_experts, ),
+                                          device="cuda",
+                                          dtype=torch.int32)
         return BatchedSiluMulTensors(input, output, num_expert_tokens)
 
 
-def ref_batched_silu_mul(
-        output: torch.Tensor,
-        input: torch.Tensor,
-        num_expert_tokens: torch.Tensor) -> torch.Tensor:
+def ref_batched_silu_mul(output: torch.Tensor, input: torch.Tensor,
+                         num_expert_tokens: torch.Tensor) -> torch.Tensor:
 
     num_expert_tokens_cpu = num_expert_tokens.clone()
     num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
@@ -140,10 +159,8 @@ def ref_batched_silu_mul(
 @pytest.mark.parametrize("max_tokens_per_expert", [128])
 @pytest.mark.parametrize("D", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-def test_batched_silu_mul(num_experts: int,
-                          max_tokens_per_expert: int,
-                          D: int,
-                          dtype: torch.dtype):
+def test_batched_silu_mul(num_experts: int, max_tokens_per_expert: int, D: int,
+                          dtype: torch.dtype):
 
     config = BatchedSiluMulConfig(dtype, num_experts, max_tokens_per_expert, D)
     tensors = BatchedSiluMulTensors.make_tensors(config)
@@ -153,6 +170,7 @@ def test_batched_silu_mul(num_experts: int,
 
     ref_batched_silu_mul(ref_out, tensors.input, tensors.expert_num_tokens)
 
-    invoke_batched_silu_and_mul(test_out, tensors.input, tensors.expert_num_tokens)
+    invoke_batched_silu_and_mul(test_out, tensors.input,
+                                tensors.expert_num_tokens)
 
     torch.testing.assert_close(test_out, ref_out)