1+ import torch
2+ import pytest
3+ from typing import Type , Optional
4+ import numpy as np
5+
6+ import vllm ._fp8gemm_C # noqa: F401
7+ from vllm import _custom_ops as ops
8+ from vllm .platforms import current_platform
9+
10+ from tests .kernels .utils import opcheck
11+
12+
device = "cuda"

# The CK-based f8f8bf16_rowwise kernel only exists in ROCm builds; skip the
# entire module on every other platform.
if not current_platform.is_rocm():
    pytest.skip(
        reason="FBGEMM Kernel currently only supported on ROCm through CK kernel.",
        allow_module_level=True)
17+
def get_cond_label(M, N, K):
    """Classify a GEMM problem shape (M, N, K) into the heuristic bucket
    ("COND_1" .. "COND_10") used to label the parametrized test cases.
    """
    if M < 64:
        # Small-M regime, sub-bucketed by whether N / K are also small.
        if N < 2048 and K < 2048:
            return "COND_1"
        if K < 2048:
            return "COND_2"
        if N < 2048:
            return "COND_3"
        # (A separate COND_4 bucket for N > 2048 and K > 2048 was folded
        # into COND_5 in the original.)
        return "COND_5"

    small_mk = M < 512 and K < 8192
    mid_nk = N <= 2048 and K <= 8192
    mid_kn = K <= 2048 and N <= 8192
    if (small_mk or mid_nk or mid_kn) and K >= 1024:
        return "COND_6"
    if K < 1024:
        return "COND_7"
    if M < 1024:
        return "COND_8"
    if M >= 1024 and N >= 1024 and K >= 1024:
        return "COND_9"
    # Reached when M, K >= 1024 but N < 1024, e.g. (1024, 1, 16384) — see
    # the COND_10-labelled cases in the parametrize lists below.
    return "COND_10"
41+
def vec_scaled_mm_torch(a: torch.Tensor,
                        b: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor,
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference row-wise scaled matmul, computed in float32.

    Args:
        a: (M, K) input matrix (any dtype castable to float32, e.g. fp8).
        b: (K, N) weight matrix (note: transposed relative to the kernel's
            (N, K) layout — callers pass ``WQ.transpose(1, 0)``).
        scale_a: (M, 1) per-row scales applied to the output rows.
        scale_b: (N, 1) per-row weight scales; transposed to broadcast over
            the N output columns.
        out_dtype: dtype of the result, e.g. ``torch.bfloat16``. (Fixed
            annotation: a dtype *instance* is passed, not ``Type[...]``.)
        bias: optional additive bias, broadcast over rows.

    Returns:
        (M, N) scaled product. NOTE(review): ``bias`` is added *after* the
        cast to ``out_dtype``, so type promotion can widen the result dtype
        beyond ``out_dtype`` (e.g. bf16 + fp32 bias -> fp32) — confirm this
        matches the kernel's bias semantics.
    """
    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out
    out = scale_b.T * out
    out = out.to(out_dtype)
    if bias is not None:
        out = out + bias

    return out
56+
def generate_random_fp8_tensor(shape):
    """Uniform-random tensor of `shape`, cast to float8_e4m3fnuz on `device`."""
    fp32 = torch.rand(shape, dtype=torch.float32, device=device)
    return fp32.to(torch.float8_e4m3fnuz)
61+
def generate_random_scale_tensor(shape):
    """Uniform-random float32 scale tensor of `shape` on `device`."""
    scales = torch.rand(shape, dtype=torch.float32, device=device)
    return scales
64+
65+ def _get_opcheck_params ():
66+
67+ MNK_list = []
68+ for M in range (17 ):
69+ for N in range (17 ):
70+ for K in range (17 ):
71+ # cond_str = get_cond_label(2**M, 2**N, 2**K)
72+ # if cond_str in ["COND_10", "COND_2"]:
73+ # print(cond_str, "M, N, K ", 2**M, 2**N, 2**K)
74+ MNK_list .append ((2 ** M , 2 ** N , 2 ** K ))
75+
76+ return MNK_list
77+
@pytest.mark.parametrize("M, N, K", [
    (56, 8192, 7392),  # Qwen Fail Case
    (16, 1280, 1024),  # Case 0 COND_1
    (16, 1280, 8192),  # Case 1 COND_3
    (32, 1280, 8192),  # Case 2 COND_3
    (64, 1280, 8192),  # Case 3 COND_6
    (128, 1280, 8192),  # Case 4 COND_6
    (16, 8192, 1024),  # Case 5 COND_2
    (32, 8192, 1024),  # Case 6 COND_2
    (64, 8192, 1024),  # Case 7 COND_6
    (128, 8192, 1024),  # Case 8 COND_6
    (16, 7168, 8192),  # Case 9 COND_5
    (32, 7168, 8192),  # Case 10 COND_5
    (64, 7168, 8192),  # Case 11 COND_8
    (128, 7168, 8192),  # Case 12 COND_8
    (1024, 7168, 8192),  # Case 13 COND_9
    (2048, 7168, 8192),  # Case 14 COND_9
    (4096, 7168, 8192),  # Case 15 COND_9
    (8192, 7168, 8192),  # Case 16 COND_9
    (16, 8192, 3584),  # Case 17 COND_5
    (32, 8192, 3584),  # Case 18 COND_5
    (64, 8192, 3584),  # Case 19 COND_6
    (128, 8192, 3584),  # Case 20 COND_6
    (1024, 8192, 3584),  # Case 21 COND_9
    (2048, 8192, 3584),  # Case 22 COND_9
    (4096, 8192, 3584),  # Case 23 COND_9
    (8192, 8192, 3584),  # Case 24 COND_9
    (32, 13312, 6656),  # Case 25 COND_5
    (64, 13312, 6656),  # Case 26 COND_6
    (128, 13312, 6656),  # Case 27 COND_6
    (16, 13312, 16384),  # Case 28 COND_5
    (32, 13312, 16384),  # Case 29 COND_5
    (64, 13312, 16384),  # Case 30 COND_8
    (128, 13312, 16384),  # Case 31 COND_8
    (1024, 13312, 16384),  # Case 32 COND_9
    (2048, 13312, 16384),  # Case 33 COND_9
    (4096, 13312, 16384),  # Case 34 COND_9
    (8192, 13312, 16384),  # Case 35 COND_9
    (32, 16384, 6656),  # Case 36 COND_5
    (64, 16384, 6656),  # Case 37 COND_6
    (128, 16384, 6656),  # Case 38 COND_6
    (1024, 16384, 6656),  # Case 39 COND_9
    (2048, 16384, 6656),  # Case 40 COND_9
    (4096, 16384, 6656),  # Case 41 COND_9
    (8192, 16384, 6656),  # Case 42 COND_9
    (16, 16384, 16384),  # Case 43 COND_5
    (32, 16384, 16384),  # Case 44 COND_5
    (64, 16384, 16384),  # Case 45 COND_8
    (128, 16384, 16384),  # Case 46 COND_8
    (1536, 3584, 3584),  # Case 47 COND_9
    (8192, 9728, 3584),  # Case 48 COND_9
    (8192, 3584, 9728),  # Case 49 COND_9
    (8192, 3584, 3584),  # Case 50 COND_9
    (4096, 3584, 3584),  # Case 51 COND_9
    (768, 3584, 3584),  # Case 52 COND_8
    (4096, 9728, 3584),  # Case 53 COND_9
    (4096, 3584, 9728),  # Case 54 COND_9
    (7200, 3584, 3584),  # Case 55 COND_9
    (7200, 9728, 3584),  # Case 56 COND_9
    (7200, 3584, 9728),  # Case 57 COND_9
    (3600, 3584, 3584),  # Case 58 COND_9
    (3600, 9728, 3584),  # Case 59 COND_9
    (3600, 3584, 9728),  # Case 60 COND_9
    (1536, 4096, 4096),  # Case 61 COND_9
    (3600, 4096, 4096),  # Case 62 COND_9
    (3600, 11008, 4096),  # Case 63 COND_9
    (3600, 4096, 11008),  # Case 64 COND_9
    (4096, 4096, 4096),  # Case 65 COND_9
    (4096, 11008, 4096),  # Case 66 COND_9
    (4096, 4096, 11008),  # Case 67 COND_9
    (32768, 128, 8192),  # Case 68 COND_6
    (32768, 8192, 1024),  # Case 69 COND_6
    (32768, 8192, 3072),  # Case 70 COND_9
    (32768, 3072, 8192),  # Case 71 COND_9
    (32768, 1024, 8192),  # Case 72 COND_6
    (512, 2048, 1000),  # COND_7 FAILED
    (1024, 512, 512),  # COND_7
    (512, 204, 512),  # COND_7 FAILED
    (512, 512, 2048),  # COND_6
    (4, 2048, 1024),  # COND_2
    (2, 16384, 1024),  # COND_2
    (1, 32768, 1),  # COND_2 FAILED
    (32, 16384, 1024),  # COND_2
    (1024, 1, 16384),  # COND_10 FAILED
    (1024, 512, 32768),  # COND_10
    (32768, 512, 32768),  # COND_10
])
def test_f8f8bf16_rowwise_opcheck(M, N, K):
    """Schema-check the `_fp8gemm_C.f8f8bf16_rowwise` custom op registration.

    The op expects X = (M, K) and W = (N, K) fp8 inputs with per-row scales
    x_scale = (M, 1) and w_scale = (N, 1).
    """
    # Generate random input tensors
    XQ = generate_random_fp8_tensor((M, K))
    WQ = generate_random_fp8_tensor((N, K))
    x_scale = generate_random_scale_tensor((M, 1))
    w_scale = generate_random_scale_tensor((N, 1))

    # opcheck validates the registered schema; its return value was bound to
    # an unused local in the original — dropped.
    opcheck(
        torch.ops._fp8gemm_C.f8f8bf16_rowwise,
        (XQ, WQ, x_scale, w_scale, None, True),
        test_utils="test_schema"
    )
182+
183+
@pytest.mark.parametrize("M, N, K", [
    (56, 8192, 7392),  # Qwen Fail Case
    (16, 1280, 1024),  # Case 0 COND_1
    (16, 1280, 8192),  # Case 1 COND_3
    (32, 1280, 8192),  # Case 2 COND_3
    (64, 1280, 8192),  # Case 3 COND_6
    (128, 1280, 8192),  # Case 4 COND_6
    (16, 8192, 1024),  # Case 5 COND_2
    (32, 8192, 1024),  # Case 6 COND_2
    (64, 8192, 1024),  # Case 7 COND_6
    (128, 8192, 1024),  # Case 8 COND_6
    (16, 7168, 8192),  # Case 9 COND_5
    (32, 7168, 8192),  # Case 10 COND_5
    (64, 7168, 8192),  # Case 11 COND_8
    (128, 7168, 8192),  # Case 12 COND_8
    (1024, 7168, 8192),  # Case 13 COND_9
    (2048, 7168, 8192),  # Case 14 COND_9
    (4096, 7168, 8192),  # Case 15 COND_9
    (8192, 7168, 8192),  # Case 16 COND_9
    (16, 8192, 3584),  # Case 17 COND_5
    (32, 8192, 3584),  # Case 18 COND_5
    (64, 8192, 3584),  # Case 19 COND_6
    (128, 8192, 3584),  # Case 20 COND_6
    (1024, 8192, 3584),  # Case 21 COND_9
    (2048, 8192, 3584),  # Case 22 COND_9
    (4096, 8192, 3584),  # Case 23 COND_9
    (8192, 8192, 3584),  # Case 24 COND_9
    (32, 13312, 6656),  # Case 25 COND_5
    (64, 13312, 6656),  # Case 26 COND_6
    (128, 13312, 6656),  # Case 27 COND_6
    (16, 13312, 16384),  # Case 28 COND_5
    (32, 13312, 16384),  # Case 29 COND_5
    (64, 13312, 16384),  # Case 30 COND_8
    (128, 13312, 16384),  # Case 31 COND_8
    (1024, 13312, 16384),  # Case 32 COND_9
    (2048, 13312, 16384),  # Case 33 COND_9
    (4096, 13312, 16384),  # Case 34 COND_9
    (8192, 13312, 16384),  # Case 35 COND_9
    (32, 16384, 6656),  # Case 36 COND_5
    (64, 16384, 6656),  # Case 37 COND_6
    (128, 16384, 6656),  # Case 38 COND_6
    (1024, 16384, 6656),  # Case 39 COND_9
    (2048, 16384, 6656),  # Case 40 COND_9
    (4096, 16384, 6656),  # Case 41 COND_9
    (8192, 16384, 6656),  # Case 42 COND_9
    (16, 16384, 16384),  # Case 43 COND_5
    (32, 16384, 16384),  # Case 44 COND_5
    (64, 16384, 16384),  # Case 45 COND_8
    (128, 16384, 16384),  # Case 46 COND_8
    (1536, 3584, 3584),  # Case 47 COND_9
    (8192, 9728, 3584),  # Case 48 COND_9
    (8192, 3584, 9728),  # Case 49 COND_9
    (8192, 3584, 3584),  # Case 50 COND_9
    (4096, 3584, 3584),  # Case 51 COND_9
    (768, 3584, 3584),  # Case 52 COND_8
    (4096, 9728, 3584),  # Case 53 COND_9
    (4096, 3584, 9728),  # Case 54 COND_9
    (7200, 3584, 3584),  # Case 55 COND_9
    (7200, 9728, 3584),  # Case 56 COND_9
    (7200, 3584, 9728),  # Case 57 COND_9
    (3600, 3584, 3584),  # Case 58 COND_9
    (3600, 9728, 3584),  # Case 59 COND_9
    (3600, 3584, 9728),  # Case 60 COND_9
    (1536, 4096, 4096),  # Case 61 COND_9
    (3600, 4096, 4096),  # Case 62 COND_9
    (3600, 11008, 4096),  # Case 63 COND_9
    (3600, 4096, 11008),  # Case 64 COND_9
    (4096, 4096, 4096),  # Case 65 COND_9
    (4096, 11008, 4096),  # Case 66 COND_9
    (4096, 4096, 11008),  # Case 67 COND_9
    (32768, 128, 8192),  # Case 68 COND_6
    (32768, 8192, 1024),  # Case 69 COND_6
    (32768, 8192, 3072),  # Case 70 COND_9
    (32768, 3072, 8192),  # Case 71 COND_9
    (32768, 1024, 8192),  # Case 72 COND_6
    (512, 2048, 1000),  # COND_7 FAILED
    (1024, 512, 512),  # COND_7
    (512, 204, 512),  # COND_7 FAILED
    (512, 512, 2048),  # COND_6
    (4, 2048, 1024),  # COND_2
    (2, 16384, 1024),  # COND_2
    (1, 32768, 1),  # COND_2 FAILED
    (32, 16384, 1024),  # COND_2
    (1024, 1, 16384),  # COND_10 FAILED
    (1024, 512, 32768),  # COND_10
    (32768, 512, 32768),  # COND_10
])
def test_f8f8bf16_rowwise(M, N, K):
    """Numerically compare f8f8bf16_rowwise against the float32 torch ref."""
    print(M, N, K, get_cond_label(M, N, K))

    # Random fp8 operands with per-row scales.
    XQ = generate_random_fp8_tensor((M, K))
    WQ = generate_random_fp8_tensor((N, K))
    x_scale = generate_random_scale_tensor((M, 1))
    w_scale = generate_random_scale_tensor((N, 1))

    # vec_scaled_mm_torch takes W as (K, N); the kernel takes W as (N, K),
    # hence the transpose for the reference only.
    ref_output = vec_scaled_mm_torch(XQ, WQ.transpose(1, 0), x_scale,
                                     w_scale, out_dtype=torch.bfloat16)
    output = torch.ops._fp8gemm_C.f8f8bf16_rowwise(XQ, WQ, x_scale, w_scale,
                                                   None, True)

    # Shape / dtype sanity on both results before comparing values.
    for result in (output, ref_output):
        assert result.shape == (M, N)
        assert result.dtype == torch.bfloat16

    rtol, atol = (3e-2, 1e-3)
    assert torch.allclose(output, ref_output, rtol=rtol, atol=atol)
300+
301+ # def test_f8f8bf16_rowwise_out(M, N, K):
302+ # # Generate random input tensors
303+ # XQ = generate_random_fp8_tensor((M, K))
304+ # WQ = generate_random_fp8_tensor((N, K))
305+ # x_scale = generate_random_scale_tensor((M, 1))
306+ # w_scale = generate_random_scale_tensor((N, 1))
307+ # output = torch.empty((M, N), dtype=torch.bfloat16, device=device)
308+
309+ # # Call the rowwise_out function
310+ # f8f8bf16_rowwise_out(XQ, WQ, x_scale, w_scale, output, None, True)
311+
312+ # # Verify the output shape and dtype
313+ # assert output.shape == (M, N)
314+ # assert output.dtype == torch.bfloat16
315+
if __name__ == "__main__":
    # pytest.main([__file__])
    print("RUN test")

    # Enumerate power-of-two shapes (exponents 0..15) and report the ones
    # that land in the COND_10 / COND_2 buckets.
    for M in range(16):
        for N in range(16):
            for K in range(16):
                label = get_cond_label(2**M, 2**N, 2**K)
                if label in ("COND_10", "COND_2"):
                    print(label, "M, N, K ", 2**M, 2**N, 2**K)