@@ -15,6 +15,7 @@
     dtype_str_to_torch_dtype,
     get_device,
     print_perf_metrics,
+    is_close_stats,
 )
 
 
@@ -485,7 +486,7 @@ def run_backend_wrapper(backend):
             )
         elif backend == "trtllm-gen-native":
             return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
-                query=q,
+                query=q.contiguous(),
                 kv_cache=kv_cache,
                 workspace_buffer=workspace_buffer,
                 block_tables=block_tables,
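
Side note on the `query=q.contiguous()` fix: the trtllm-gen-native decode path evidently expects a densely laid-out query tensor, and `.contiguous()` is a no-op when `q` already is one. A minimal standalone sketch of that behavior (illustrative shapes only, not the benchmark's tensors):

```python
import torch

# A strided view, e.g. after a transpose; same data, non-contiguous layout.
q = torch.randn(8, 2, 4, 128).transpose(1, 2)
assert not q.is_contiguous()

# .contiguous() materializes a dense copy; it returns q itself when q is
# already contiguous, so the call is cheap in the common case.
q_c = q.contiguous()
assert q_c.is_contiguous() and torch.equal(q, q_c)
```
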
@@ -498,19 +499,14 @@ def run_backend_wrapper(backend):
         raise ValueError(f"Backend {backend} not supported")
 
     has_reference_output = False
-    if run_refcheck and "fa2" in backends:
-        reference_output = (
-            backend_wrappers["fa2"]
-            .run(q, kv_cache, k_scale=k_scale, v_scale=v_scale)
-            .detach()
-        )
-        has_reference_output = True
-
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach()
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
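
The switch from `.detach()` to `.detach().clone()` matters because `.detach()` still aliases the backend's storage: if a backend reuses its output buffer across the timed (and possibly CUDA-graph-replayed) iterations, the saved reference would be silently overwritten. A minimal sketch of that aliasing hazard (hypothetical buffer, not the benchmark's actual tensors):

```python
import torch

out_buf = torch.zeros(4)               # stands in for a backend's reused output buffer
snap_view = out_buf.detach()           # detached, but still aliases out_buf's storage
snap_copy = out_buf.detach().clone()   # owns independent storage

out_buf.fill_(1.0)                     # e.g. clobbered by a later timing iteration
print(snap_view)   # tensor([1., 1., 1., 1.]) -- the "saved" result changed
print(snap_copy)   # tensor([0., 0., 0., 0.]) -- safe snapshot for the refcheck
```
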
@@ -550,8 +546,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
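
The `is_close_stats` helper imported at the top of this patch is not shown in the diff; from the call sites it takes `(reference, tested, rtol, atol)` and returns a `(num_different_elements, num_elements, percentage)` triple. A plausible sketch, assuming it mirrors `torch.isclose` semantics (hypothetical reimplementation, not the actual helper):

```python
import torch

def is_close_stats(ref: torch.Tensor, out: torch.Tensor, rtol: float, atol: float):
    """Hypothetical sketch: count elements failing the torch.isclose test."""
    close = torch.isclose(out, ref, rtol=rtol, atol=atol)
    num_elements = close.numel()
    num_different = num_elements - int(close.sum().item())
    return num_different, num_elements, 100.0 * num_different / num_elements
```
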
@@ -721,9 +723,6 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
 
     # Check for layer-specific constraints
     layer_not_supported = False
-    if not ((head_dim_qk == 128 and head_dim_qk == head_dim_vo) or head_dim_qk == 192):
-        print("[ERROR] Head dimension must be 128 or 192")
-        layer_not_supported = True
     if layer_not_supported:
         print("[ERROR] Layer not supported. Exiting.")
         return
@@ -882,7 +881,9 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
             workspace_buffer,
             "HND",
-            use_cuda_graph=is_cuda_graph_compatible,
+            use_cuda_graph=is_cuda_graph_compatible
+            if backend != "fa2"
+            else False,
             qo_indptr_buf=qo_indptr,
             paged_kv_indptr_buf=kv_indptr,
             paged_kv_indices_buf=kv_indices,
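
The `use_cuda_graph=... if backend != "fa2" else False` argument keeps the fa2 wrapper in eager mode, consistent with the timing loops elsewhere in this patch that skip CUDA-graph capture for fa2 now that it doubles as the reference backend. Since `is_cuda_graph_compatible` is a bool, the conditional expression reduces to a plain conjunction, sketched here for clarity:

```python
# Equivalent, flatter spelling of the flag passed to the wrapper:
# fa2 stays eager so its output can serve as the refcheck reference.
def wants_cuda_graph(backend: str, is_cuda_graph_compatible: bool) -> bool:
    return is_cuda_graph_compatible and backend != "fa2"
```
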
@@ -958,17 +959,14 @@ def run_backend_wrapper(backend):
         raise ValueError(f"Backend {backend} not supported")
 
     has_reference_output = False
-    if run_refcheck and "fa2" in backends:
-        reference_output = backend_wrappers["fa2"].run(
-            q, kv_cache, k_scale=k_scale, v_scale=v_scale
-        )
-        has_reference_output = True
-
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend)
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -1008,8 +1006,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
@@ -1295,7 +1299,9 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args):
         flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
             workspace_buffer,
             "NHD",
-            use_cuda_graph=is_cuda_graph_compatible,
+            use_cuda_graph=is_cuda_graph_compatible
+            if backend != "fa2"
+            else False,
             qo_indptr_buf=qo_indptr,
             kv_indptr_buf=kv_indptr,
             backend=backend,
@@ -1350,15 +1356,14 @@ def run_backend_wrapper(backend):
         raise ValueError(f"Backend {backend} not supported")
 
     has_reference_output = False
-    if run_refcheck and "fa2" in backends:
-        reference_output = backend_wrappers["fa2"].run_return_lse(q, k, v)[0]
-        has_reference_output = True
-
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend)
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -1398,8 +1403,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)
@@ -1693,19 +1704,15 @@ def run_backend_wrapper(backend):
         else:
             raise ValueError(f"Unsupported backend: {backend}")
 
-    if run_refcheck and "fa2" in backends:
-        reference_output = fi_fa2_mla_wrapper.run(
-            q_nope, q_pe, ckv_cache, kpe_cache, return_lse=False
-        )
-        has_reference_output = True
-    else:
-        has_reference_output = False
-
+    has_reference_output = False
     # Iterate over each backend:
     for cur_backend in backends:
         if run_refcheck:
-            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach()
-        if is_cuda_graph_compatible:
+            outputs[cur_backend] = run_backend_wrapper(cur_backend).detach().clone()
+            if cur_backend == "fa2":
+                has_reference_output = True
+                reference_output = outputs[cur_backend]
+        if is_cuda_graph_compatible and cur_backend != "fa2":
             backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
                 fn=lambda: run_backend_wrapper(cur_backend),
                 dry_run_iters=args.dry_run_iters,
@@ -1741,8 +1748,14 @@ def run_backend_wrapper(backend):
                     reference_output, tested_outputs[i], rtol=rtol, atol=atol
                 )
             except AssertionError as e:
+                (
+                    num_different_elements,
+                    num_elements,
+                    num_different_elements_percentage,
+                ) = is_close_stats(reference_output, tested_outputs[i], rtol, atol)
                 print(
-                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
+                    f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}: "
+                    f"{num_different_elements} / {num_elements} ({num_different_elements_percentage:.2f}%) elements are different"
                 )
                 if not args.allow_output_mismatch:
                     print(e)