Commit 0e3403b

bugfix: fix flashinfer_benchmark.py IMA when running a test list (#1625)
## 📌 Description

The current `flashinfer_benchmark.py` script can trigger an IMA (illegal memory access) when a testlist is provided to batch-benchmark multiple test cases.

This PR:
* Fixes the bug by clearing torch's CUDA cache and synchronizing the device at the beginning of each test.
  - The IMA occurs between test cases and can only be reproduced when running a testlist.
  - The fix is to call `torch.cuda.empty_cache()` and `torch.cuda.synchronize()` at the beginning of each test.
* Makes miscellaneous improvements to `flashinfer_benchmark.py`:
  - Attention benchmarks:
    - Reduces unnecessary reference calculations.
    - Prints mismatch statistics when a reference check fails.
  - GEMM benchmarks:
    - Allows testing the `trtllm` backend in the `testGemmFp8NtGroupwise` routine.

*No changes to the library code.*

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
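The device-reset pattern, as a minimal standalone sketch (the actual change lives at the top of `get_device()` in `benchmarks/routines/flashinfer_benchmark_utils.py`, shown in the diff below; `run_one_test` here is a hypothetical stand-in for the per-test entry point):

```python
import torch

def run_one_test(test_args):
    # Hypothetical per-test entry point: release cached allocations left over
    # from the previous test case and wait for any in-flight kernels to finish
    # before building new inputs, so stale device state cannot trigger an IMA.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # ... set seeds, build inputs, and benchmark the requested backends ...
```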
1 parent 75df649 · commit 0e3403b

File tree

3 files changed: +86 −70 lines

benchmarks/routines/attention.py

Lines changed: 57 additions & 44 deletions
```diff
@@ -15,6 +15,7 @@
     dtype_str_to_torch_dtype,
     get_device,
     print_perf_metrics,
+    is_close_stats,
 )
 
 
@@ -485,7 +486,7 @@ def run_backend_wrapper(backend):
             )
         elif backend == "trtllm-gen-native":
             return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
-                query=q,
+                query=q.contiguous(),
                 kv_cache=kv_cache,
                 workspace_buffer=workspace_buffer,
                 block_tables=block_tables,
@@ -498,19 +499,14 @@ def run_backend_wrapper(backend):
         raise ValueError(f"Backend {backend} not supported")
 
     has_reference_output = False
-    if run_refcheck and "fa2" in backends:
-        reference_output = (
-            backend_wrappers["fa2"]
-            .run(q, kv_cache, k_scale=k_scale, v_scale=v_scale)
-            .detach()
-        )
-        has_reference_output = True
-
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach()
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -550,8 +546,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
@@ -721,9 +723,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
 
     # Check for layer-specific constraints
     layer_not_supported = False
-    if not ((head_dim_qk == 128 and head_dim_qk == head_dim_vo) or head_dim_qk == 192):
-        print("[ERROR] Head dimension must be 128 or 192")
-        layer_not_supported = True
     if layer_not_supported:
         print("[ERROR] Layer not supported. Exiting.")
         return
@@ -882,7 +881,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
             workspace_buffer,
             "HND",
-            use_cuda_graph=is_cuda_graph_compatible,
+            use_cuda_graph=is_cuda_graph_compatible
+            if backend != "fa2"
+            else False,
             qo_indptr_buf=qo_indptr,
             paged_kv_indptr_buf=kv_indptr,
             paged_kv_indices_buf=kv_indices,
@@ -958,17 +959,14 @@ def run_backend_wrapper(backend):
         raise ValueError(f"Backend {backend} not supported")
 
     has_reference_output = False
-    if run_refcheck and "fa2" in backends:
-        reference_output = backend_wrappers["fa2"].run(
-            q, kv_cache, k_scale=k_scale, v_scale=v_scale
-        )
-        has_reference_output = True
-
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend)
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -1008,8 +1006,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
@@ -1295,7 +1299,9 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
             workspace_buffer,
             "NHD",
-            use_cuda_graph=is_cuda_graph_compatible,
+            use_cuda_graph=is_cuda_graph_compatible
+            if backend != "fa2"
+            else False,
             qo_indptr_buf=qo_indptr,
             kv_indptr_buf=kv_indptr,
             backend=backend,
@@ -1350,15 +1356,14 @@ def run_backend_wrapper(backend):
         raise ValueError(f"Backend {backend} not supported")
 
     has_reference_output = False
-    if run_refcheck and "fa2" in backends:
-        reference_output = backend_wrappers["fa2"].run_return_lse(q, k, v)[0]
-        has_reference_output = True
-
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend)
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -1398,8 +1403,14 @@ def run_backend_wrapper(backend):
                    reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
@@ -1693,19 +1704,15 @@ def run_backend_wrapper(backend):
        else:
            raise ValueError(f"Unsupported backend: {backend}")
 
-    if run_refcheck and "fa2" in backends:
-        reference_output = fi_fa2_mla_wrapper.run(
-            q_nope, q_pe, ckv_cache, kpe_cache, return_lse=False
-        )
-        has_reference_output = True
-    else:
-        has_reference_output = False
-
+    has_reference_output = False
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach()
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -1741,8 +1748,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
```
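Condensed from the hunks above, a sketch of the restructured reference-check flow: each backend is run exactly once, its output is cloned, and the `fa2` output doubles as the reference, so no separate reference pass is needed (variable names follow the diff; the surrounding setup is elided):

```python
outputs = {}
reference_output = None
has_reference_output = False
for cur_backend in backends:
    if run_refcheck:
        # One run per backend; clone so later benchmark iterations cannot
        # overwrite the tensor used for the reference comparison.
        outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
        if cur_backend == "fa2":
            has_reference_output = True
            reference_output = outputs[cur_backend]
```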

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -117,6 +117,9 @@ def print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec):
 
 
 def get_device(args):
+    # Synchronize to ensure that the device is ready after previous tests
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
     set_seed(args.random_seed)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()).replace(" ", "_")
@@ -125,6 +128,17 @@ def get_device(args):
     return device
 
 
+def is_close_stats(input, other, rtol=1e-5, atol=1e-8):
+    close_tensor = torch.isclose(input, other, rtol=rtol, atol=atol)
+    num_elements = close_tensor.numel()
+    num_different_elements = num_elements - close_tensor.sum().item()
+    return (
+        num_different_elements,  # number of different elements
+        num_elements,  # total number of elements in tensor
+        num_different_elements / num_elements * 100.0,
+    )
+
+
 def dtype_str_to_torch_dtype(dtype_str):
     if dtype_str == "bfloat16":
         return torch.bfloat16
```
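For reference, a small usage sketch of the new `is_close_stats` helper; the import path is an assumption (inside the benchmark routines it would come from `flashinfer_benchmark_utils`, alongside `get_device`):

```python
import torch

from flashinfer_benchmark_utils import is_close_stats  # assumed import path

ref = torch.randn(1024, device="cuda")
out = ref + 1e-3 * torch.randn(1024, device="cuda")

# Returns (number of mismatched elements, total elements, mismatch percentage).
num_diff, num_total, pct = is_close_stats(ref, out, rtol=1e-2, atol=1e-2)
print(f"{num_diff} / {num_total} ({pct:.2f}%) elements are different")
```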

benchmarks/routines/gemm.py

Lines changed: 15 additions & 26 deletions
```diff
@@ -197,11 +197,22 @@ def testGemmFp8NtGroupwise(args):
     ## Done parsing input arguments
 
     if "trtllm" in backends:
-        remove_trtllm = True
-        print("[INFO] trtllm backend testing not supported yet")
+        remove_trtllm = False
+        if scale_major_mode != "MN":
+            print(
+                "[INFO] trtllm only supports MN scale_major_mode, removing trtllm from backends"
+            )
+            remove_trtllm = True
+        if k < 256:
+            print("[INFO] trtllm only supports k >= 256, removing trtllm from backends")
+            remove_trtllm = True
         if remove_trtllm:
             backends.remove("trtllm")
 
+    if len(backends) == 0:
+        print("[ERROR] No backends to test. Exiting.")
+        return
+
     ## Prepare input tensors
     a_val = torch.randn((m, k), dtype=torch.float, device=device)
     b_val = torch.randn((n, k), dtype=torch.float, device=device) / np.sqrt(k)
@@ -223,17 +234,6 @@ def testGemmFp8NtGroupwise(args):
     a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode)
     b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode)
 
-    if "trtllm" in backends:
-        a_scale_shape_trtllm = (m, k // tile_size)
-        b_scale_shape_trtllm = (k // tile_size, n // tile_size)
-
-        a_fp8_trtllm, a_scale_trtllm = quantize_fp8(
-            a_val, a_scale_shape_trtllm, a_tile_shape, "K"
-        )
-        b_fp8_trtllm, b_scale_trtllm = quantize_fp8(
-            b_val, b_scale_shape_trtllm, b_tile_shape, "MN"
-        )
-
     if args.verbose >= 2:
         print(f"[VVERBOSE] {a_fp8.shape = }")
         print(f"[VVERBOSE] {b_fp8.shape = }")
@@ -244,7 +244,7 @@
     b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode)
 
     def run_backend(backend):
-        if backend == "cutlass":
+        if backend in ["cutlass", "trtllm"]:
             return flashinfer.gemm.gemm_fp8_nt_groupwise(
                 a=a_fp8,
                 b=b_fp8,
@@ -253,18 +253,7 @@ def run_backend(backend):
                 scale_major_mode=scale_major_mode,
                 out_dtype=out_dtype,
                 mma_sm=mma_sm,
-                backend="cutlass",
-            )
-        elif backend == "trtllm":
-            return flashinfer.gemm.gemm_fp8_nt_groupwise(
-                a=a_fp8,
-                b=b_fp8,
-                a_scale=a_scale,
-                b_scale=b_scale,
-                scale_major_mode=None,
-                out_dtype=out_dtype,
-                mma_sm=mma_sm,
-                backend="trtllm",
+                backend=backend,
             )
         else:
             raise ValueError(f"Unsupported backend: {backend}")
```
