
Commit 03abe17

Author: vllmellm
Commit message: add unittests for fbgemm fp8 ck kernel
Parent: d21dbc9

File tree: 2 files changed (+335, -10 lines)

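At a high level, the new unit tests compare the CK-backed rowwise FP8 GEMM op against a plain float32 torch reference with per-row scales. Below is a minimal sketch of that check; the op name, argument order, tensor shapes, and tolerances are taken from the test file added in this commit, and it assumes a ROCm build of vLLM where the _fp8gemm_C extension is importable.

import torch
# Importing the extension registers torch.ops._fp8gemm_C (ROCm build assumed).
import vllm._fp8gemm_C  # noqa: F401

M, N, K = 56, 8192, 7392  # the "Qwen Fail Case" from the test shape list
XQ = torch.rand((M, K), device="cuda").to(torch.float8_e4m3fnuz)  # activations, (M, K)
WQ = torch.rand((N, K), device="cuda").to(torch.float8_e4m3fnuz)  # weights, (N, K)
x_scale = torch.rand((M, 1), device="cuda")  # per-row activation scales
w_scale = torch.rand((N, 1), device="cuda")  # per-row weight scales

# Reference path: upcast to fp32, matmul, apply rowwise scales, cast to bf16
# (this mirrors vec_scaled_mm_torch in the test file below).
ref = (x_scale * torch.mm(XQ.to(torch.float32), WQ.to(torch.float32).T) * w_scale.T).to(torch.bfloat16)

# Kernel under test; the trailing None/True mirror what the tests pass (bias and a bool flag).
out = torch.ops._fp8gemm_C.f8f8bf16_rowwise(XQ, WQ, x_scale, w_scale, None, True)

assert torch.allclose(out, ref, rtol=3e-2, atol=1e-3)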

csrc/fbgemm_fp8_rowwise/fp8_rowwise_gemm.cu (10 additions, 10 deletions)
@@ -213,38 +213,38 @@ RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
   // if(!((N % 8 == 0) && (K % 16 == 0)))
   //   return fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
 
-  if (M < 64 && N < 2048 && K < 2048) {
+  if (M < 64 && N < 2048 && K < 2048) { // COND_1
     // Kernel that generally works well on small shapes.
     return fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2;
-  } else if (M < 64 && K < 2048) {
+  } else if (M < 64 && K < 2048) { // COND_2
     // Kernel that works well for small batch size and small K.
     return fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2;
-  } else if (M < 64 && N < 2048) {
+  } else if (M < 64 && N < 2048) { // COND_3
     // Kernel that works well for small batch size and small N.
     return fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2;
-  // } else if (M < 64 && N > 2048 && K > 2048) {
+  // } else if (M < 64 && N > 2048 && K > 2048) { // COND_4
   //   // Kernel that works well for small M but larger N and K.
   //   return fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
-  } else if (M < 64) {
+  } else if (M < 64) { // COND_5
     // Fallback to generic small batch kernel if we cant find a good match.
     return fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2;
   } else if (
       ((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) ||
        (K <= 2048 && N <= 8192)) &&
-      K >= 1024) {
+      K >= 1024) { // COND_6
     // Kernel that is optimized for larger batch sizes but otherwise small
     // tensors.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5;
-  } else if (K < 1024) {
+  } else if (K < 1024) { // COND_7
     // Special case for small K.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
-  } else if (M < 1024) {
+  } else if (M < 1024) { // COND_8
     // Kernel for generic medium batch sizes.
     return fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
-  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
+  } else if (M >= 1024 && N >= 1024 && K >= 1024) { // COND_9
     // Kernel for very large gemm
     return fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
-  } else {
+  } else { // COND_10
     // Fallback large kernel.
     return fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
   }
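The COND_* labels added above give each branch of rowwise_heuristic_dispatch a stable name that the new tests refer to. As a quick orientation aid, here is a sketch that evaluates the same boundary conditions in Python for a few of the tested shapes; the test file below defines the full equivalent as get_cond_label, and the helper name and shape selection here are only illustrative.

def branch(M: int, N: int, K: int) -> str:
    # Same boundaries as rowwise_heuristic_dispatch above; COND_4 is commented out upstream
    # and therefore omitted here.
    if M < 64 and N < 2048 and K < 2048:
        return "COND_1"
    if M < 64 and K < 2048:
        return "COND_2"
    if M < 64 and N < 2048:
        return "COND_3"
    if M < 64:
        return "COND_5"
    if ((M < 512 and K < 8192) or (N <= 2048 and K <= 8192)
            or (K <= 2048 and N <= 8192)) and K >= 1024:
        return "COND_6"
    if K < 1024:
        return "COND_7"
    if M < 1024:
        return "COND_8"
    if M >= 1024 and N >= 1024 and K >= 1024:
        return "COND_9"
    return "COND_10"

# A few shapes from the test list below and the branch they select:
assert branch(16, 1280, 1024) == "COND_1"     # small M, N, and K
assert branch(56, 8192, 7392) == "COND_5"     # the "Qwen Fail Case": small M, large N and K
assert branch(4096, 4096, 4096) == "COND_9"   # large square GEMM
assert branch(1024, 512, 32768) == "COND_10"  # falls through to the fallback kernel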
New test file (325 additions, 0 deletions)

@@ -0,0 +1,325 @@
import torch
import pytest
from typing import Type, Optional
import numpy as np

import vllm._fp8gemm_C  # noqa: F401
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

from tests.kernels.utils import opcheck


device = "cuda"
if not current_platform.is_rocm():
    pytest.skip(reason="FBGEMM Kernel currently only supported on ROCm through CK kernel.",
                allow_module_level=True)

def get_cond_label(M, N, K):
    if M < 64 and N < 2048 and K < 2048:  # COND_1
        return "COND_1"
    elif M < 64 and K < 2048:  # COND_2
        return "COND_2"
    elif M < 64 and N < 2048:  # COND_3
        return "COND_3"
    # elif M < 64 and N > 2048 and K > 2048:  # COND_4
    #     return "COND_4"
    elif M < 64:  # COND_5
        return "COND_5"
    elif (
        (M < 512 and K < 8192) or (N <= 2048 and K <= 8192) or (K <= 2048 and N <= 8192)
    ) and K >= 1024:  # COND_6
        return "COND_6"
    elif K < 1024:  # COND_7
        return "COND_7"
    elif M < 1024:  # COND_8
        return "COND_8"
    elif M >= 1024 and N >= 1024 and K >= 1024:  # COND_9
        return "COND_9"
    else:  # COND_10 Will never be triggered
        return "COND_10"

def vec_scaled_mm_torch(a: torch.Tensor,
                        b: torch.Tensor,
                        scale_a: torch.Tensor,
                        scale_b: torch.Tensor,
                        out_dtype: Type[torch.dtype],
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out
    out = scale_b.T * out
    out = out.to(out_dtype)
    if bias is not None:
        out = out + bias

    return out

def generate_random_fp8_tensor(shape):
    # Generate random float32 tensor and convert to float8_e4m3fnuz
    tensor = torch.rand(shape, dtype=torch.float32, device=device)
    return tensor.to(torch.float8_e4m3fnuz)

def generate_random_scale_tensor(shape):
    return torch.rand(shape, dtype=torch.float32, device=device)

def _get_opcheck_params():

    MNK_list = []
    for M in range(17):
        for N in range(17):
            for K in range(17):
                # cond_str = get_cond_label(2**M, 2**N, 2**K)
                # if cond_str in ["COND_10", "COND_2"]:
                #     print(cond_str, "M, N, K ", 2**M, 2**N, 2**K)
                MNK_list.append((2**M, 2**N, 2**K))

    return MNK_list

# Shared (M, N, K) GEMM shapes exercised by both tests below.
MNK_SHAPES = [
    (56, 8192, 7392),  # Qwen Fail Case
    (16, 1280, 1024),  # Case 0 COND_1
    (16, 1280, 8192),  # Case 1 COND_3
    (32, 1280, 8192),  # Case 2 COND_3
    (64, 1280, 8192),  # Case 3 COND_6
    (128, 1280, 8192),  # Case 4 COND_6
    (16, 8192, 1024),  # Case 5 COND_2
    (32, 8192, 1024),  # Case 6 COND_2
    (64, 8192, 1024),  # Case 7 COND_6
    (128, 8192, 1024),  # Case 8 COND_6
    (16, 7168, 8192),  # Case 9 COND_5
    (32, 7168, 8192),  # Case 10 COND_5
    (64, 7168, 8192),  # Case 11 COND_8
    (128, 7168, 8192),  # Case 12 COND_8
    (1024, 7168, 8192),  # Case 13 COND_9
    (2048, 7168, 8192),  # Case 14 COND_9
    (4096, 7168, 8192),  # Case 15 COND_9
    (8192, 7168, 8192),  # Case 16 COND_9
    (16, 8192, 3584),  # Case 17 COND_5
    (32, 8192, 3584),  # Case 18 COND_5
    (64, 8192, 3584),  # Case 19 COND_6
    (128, 8192, 3584),  # Case 20 COND_6
    (1024, 8192, 3584),  # Case 21 COND_9
    (2048, 8192, 3584),  # Case 22 COND_9
    (4096, 8192, 3584),  # Case 23 COND_9
    (8192, 8192, 3584),  # Case 24 COND_9
    (32, 13312, 6656),  # Case 25 COND_5
    (64, 13312, 6656),  # Case 26 COND_6
    (128, 13312, 6656),  # Case 27 COND_6
    (16, 13312, 16384),  # Case 28 COND_5
    (32, 13312, 16384),  # Case 29 COND_5
    (64, 13312, 16384),  # Case 30 COND_8
    (128, 13312, 16384),  # Case 31 COND_8
    (1024, 13312, 16384),  # Case 32 COND_9
    (2048, 13312, 16384),  # Case 33 COND_9
    (4096, 13312, 16384),  # Case 34 COND_9
    (8192, 13312, 16384),  # Case 35 COND_9
    (32, 16384, 6656),  # Case 36 COND_5
    (64, 16384, 6656),  # Case 37 COND_6
    (128, 16384, 6656),  # Case 38 COND_6
    (1024, 16384, 6656),  # Case 39 COND_9
    (2048, 16384, 6656),  # Case 40 COND_9
    (4096, 16384, 6656),  # Case 41 COND_9
    (8192, 16384, 6656),  # Case 42 COND_9
    (16, 16384, 16384),  # Case 43 COND_5
    (32, 16384, 16384),  # Case 44 COND_5
    (64, 16384, 16384),  # Case 45 COND_8
    (128, 16384, 16384),  # Case 46 COND_8
    (1536, 3584, 3584),  # Case 47 COND_9
    (8192, 9728, 3584),  # Case 48 COND_9
    (8192, 3584, 9728),  # Case 49 COND_9
    (8192, 3584, 3584),  # Case 50 COND_9
    (4096, 3584, 3584),  # Case 51 COND_9
    (768, 3584, 3584),  # Case 52 COND_8
    (4096, 9728, 3584),  # Case 53 COND_9
    (4096, 3584, 9728),  # Case 54 COND_9
    (7200, 3584, 3584),  # Case 55 COND_9
    (7200, 9728, 3584),  # Case 56 COND_9
    (7200, 3584, 9728),  # Case 57 COND_9
    (3600, 3584, 3584),  # Case 58 COND_9
    (3600, 9728, 3584),  # Case 59 COND_9
    (3600, 3584, 9728),  # Case 60 COND_9
    (1536, 4096, 4096),  # Case 61 COND_9
    (3600, 4096, 4096),  # Case 62 COND_9
    (3600, 11008, 4096),  # Case 63 COND_9
    (3600, 4096, 11008),  # Case 64 COND_9
    (4096, 4096, 4096),  # Case 65 COND_9
    (4096, 11008, 4096),  # Case 66 COND_9
    (4096, 4096, 11008),  # Case 67 COND_9
    (32768, 128, 8192),  # Case 68 COND_6
    (32768, 8192, 1024),  # Case 69 COND_6
    (32768, 8192, 3072),  # Case 70 COND_9
    (32768, 3072, 8192),  # Case 71 COND_9
    (32768, 1024, 8192),  # Case 72 COND_6
    (512, 2048, 1000),  # COND_7 FAILED
    (1024, 512, 512),  # COND_7
    (512, 204, 512),  # COND_7 FAILED
    (512, 512, 2048),  # COND_6
    (4, 2048, 1024),  # COND_2
    (2, 16384, 1024),  # COND_2
    (1, 32768, 1),  # COND_2 FAILED
    (32, 16384, 1024),  # COND_2
    (1024, 1, 16384),  # COND_10 FAILED
    (1024, 512, 32768),  # COND_10
    (32768, 512, 32768),  # COND_10
]


@pytest.mark.parametrize("M, N, K", MNK_SHAPES)
def test_f8f8bf16_rowwise_opcheck(M, N, K):
    # Generate random input tensors
    XQ = generate_random_fp8_tensor((M, K))
    WQ = generate_random_fp8_tensor((N, K))
    x_scale = generate_random_scale_tensor((M, 1))
    w_scale = generate_random_scale_tensor((N, 1))
    # print(M, N, K, get_cond_label(M, N, K))

    # f8f8bf16_rowwise expects
    # X = (M, K)
    # W = (N, K)
    output = opcheck(
        torch.ops._fp8gemm_C.f8f8bf16_rowwise,
        (XQ, WQ, x_scale, w_scale, None, True),
        test_utils="test_schema"
    )

@pytest.mark.parametrize("M, N, K", MNK_SHAPES)
def test_f8f8bf16_rowwise(M, N, K):

    print(M, N, K, get_cond_label(M, N, K))

    # Generate random input tensors
    XQ = generate_random_fp8_tensor((M, K))
    WQ = generate_random_fp8_tensor((N, K))
    x_scale = generate_random_scale_tensor((M, 1))
    w_scale = generate_random_scale_tensor((N, 1))

    # Call the rowwise function
    # vec_scaled_mm_torch expects
    # X = (M, K)
    # W = (K, N)
    ref_output = vec_scaled_mm_torch(XQ, WQ.transpose(1, 0), x_scale, w_scale, out_dtype=torch.bfloat16)

    # f8f8bf16_rowwise expects
    # X = (M, K)
    # W = (N, K)
    output = torch.ops._fp8gemm_C.f8f8bf16_rowwise(XQ, WQ, x_scale, w_scale, None, True)

    # Verify the output shape and dtype
    assert output.shape == (M, N)
    assert ref_output.shape == (M, N)
    assert output.dtype == torch.bfloat16
    assert ref_output.dtype == torch.bfloat16
    rtol, atol = (3e-2, 1e-3)

    assert torch.allclose(output, ref_output, rtol=rtol, atol=atol)

# def test_f8f8bf16_rowwise_out(M, N, K):
#     # Generate random input tensors
#     XQ = generate_random_fp8_tensor((M, K))
#     WQ = generate_random_fp8_tensor((N, K))
#     x_scale = generate_random_scale_tensor((M, 1))
#     w_scale = generate_random_scale_tensor((N, 1))
#     output = torch.empty((M, N), dtype=torch.bfloat16, device=device)

#     # Call the rowwise_out function
#     f8f8bf16_rowwise_out(XQ, WQ, x_scale, w_scale, output, None, True)

#     # Verify the output shape and dtype
#     assert output.shape == (M, N)
#     assert output.dtype == torch.bfloat16

if __name__ == "__main__":
    # pytest.main([__file__])
    print("RUN test")

    for M in range(16):
        for N in range(16):
            for K in range(16):
                cond_str = get_cond_label(2**M, 2**N, 2**K)
                if cond_str in ["COND_10", "COND_2"]:
                    print(cond_str, "M, N, K ", 2**M, 2**N, 2**K)
