17
17
18
18
GPU_DEVICE = "cuda:0"
19
19
20
- global_workspace_buffer = None
20
+ global_workspace_buffer = None # can.be empty initialized
21
+ global_trtllm_gen_fmha_workspace_buffer = None # must be zero initialized
21
22
workspace_size = 128 * 1024 * 1024
22
23
23
24
@@ -320,16 +321,21 @@ def test_trtllm_batch_prefill(
320
321
else None
321
322
)
322
323
323
- global global_workspace_buffer
324
+ global global_workspace_buffer , global_trtllm_gen_fmha_workspace_buffer
324
325
if global_workspace_buffer is None :
325
- global_workspace_buffer = torch .zeros (
326
+ global_workspace_buffer = torch .empty (
326
327
workspace_size , dtype = torch .int8 , device = GPU_DEVICE
327
328
)
328
- workspace_buffer = global_workspace_buffer
329
+ if global_trtllm_gen_fmha_workspace_buffer is None :
330
+ global_trtllm_gen_fmha_workspace_buffer = torch .zeros (
331
+ workspace_size , dtype = torch .int8 , device = GPU_DEVICE
332
+ )
333
+ workspace_buffer_ref = global_workspace_buffer
334
+ workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
329
335
330
336
# Run reference wrapper
331
337
wrapper_ref = flashinfer .prefill .BatchPrefillWithPagedKVCacheWrapper (
332
- workspace_buffer , kv_layout
338
+ workspace_buffer_ref , kv_layout
333
339
)
334
340
plan_params = {
335
341
"qo_indptr" : q_indptr ,
@@ -372,6 +378,9 @@ def test_trtllm_batch_prefill(
372
378
o_sf_vec_size = o_sf_vec_size ,
373
379
enable_pdl = enable_pdl ,
374
380
)
381
+ # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
382
+ # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
383
+ assert (workspace_buffer [: 8192 * 256 * 4 ].cpu ().numpy () == 0 ).all ()
375
384
376
385
if o_dtype == "nvfp4" :
377
386
output , output_ref = unpack_compare_nvfp4 (
@@ -414,6 +423,9 @@ def test_trtllm_batch_prefill(
414
423
torch .testing .assert_close (
415
424
output .float (), output_wrapper .float (), rtol = 1e-1 , atol = 1e-1
416
425
)
426
+ # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
427
+ # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
428
+ assert (workspace_buffer [: 8192 * 256 * 4 ].cpu ().numpy () == 0 ).all ()
417
429
418
430
419
431
@pytest .mark .parametrize ("kv_layout" , ["HND" ]) # trtllm-gen only support HND
@@ -505,16 +517,21 @@ def test_trtllm_batch_decode(
505
517
else None
506
518
)
507
519
508
- global global_workspace_buffer
520
+ global global_workspace_buffer , global_trtllm_gen_fmha_workspace_buffer
509
521
if global_workspace_buffer is None :
510
- global_workspace_buffer = torch .zeros (
522
+ global_workspace_buffer = torch .empty (
523
+ workspace_size , dtype = torch .int8 , device = GPU_DEVICE
524
+ )
525
+ if global_trtllm_gen_fmha_workspace_buffer is None :
526
+ global_trtllm_gen_fmha_workspace_buffer = torch .zeros (
511
527
workspace_size , dtype = torch .int8 , device = GPU_DEVICE
512
528
)
513
- workspace_buffer = global_workspace_buffer
529
+ workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
530
+ workspace_buffer_ref = global_workspace_buffer
514
531
515
532
# Run reference wrapper
516
533
wrapper_ref = flashinfer .decode .BatchDecodeWithPagedKVCacheWrapper (
517
- workspace_buffer , kv_layout , use_tensor_cores = True
534
+ workspace_buffer_ref , kv_layout , use_tensor_cores = True
518
535
)
519
536
plan_params = {
520
537
"indptr" : kv_indptr ,
@@ -535,7 +552,7 @@ def test_trtllm_batch_decode(
535
552
if q_len_per_req > 1 :
536
553
# hide the output_ref from decode wrapper for speculative decoding test
537
554
wrapper_ref = flashinfer .prefill .BatchPrefillWithPagedKVCacheWrapper (
538
- workspace_buffer , kv_layout
555
+ workspace_buffer_ref , kv_layout
539
556
)
540
557
plan_params_prefill = {
541
558
"qo_indptr" : q_indptr ,
@@ -576,6 +593,9 @@ def test_trtllm_batch_decode(
576
593
enable_pdl = enable_pdl ,
577
594
q_len_per_req = q_len_per_req ,
578
595
)
596
+ # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
597
+ # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
598
+ assert (workspace_buffer [: 8192 * 256 * 4 ].cpu ().numpy () == 0 ).all ()
579
599
580
600
if o_dtype == "nvfp4" :
581
601
output , output_ref = unpack_compare_nvfp4 (
@@ -648,6 +668,9 @@ def test_trtllm_batch_decode(
648
668
atol = 1e-1 ,
649
669
max_mismatched_elements = 5 ,
650
670
)
671
+ # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
672
+ # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
673
+ assert (workspace_buffer [: 8192 * 256 * 4 ].cpu ().numpy () == 0 ).all ()
651
674
652
675
653
676
@pytest .mark .parametrize ("batch_size" , [4 , 128 , 256 ])
@@ -699,7 +722,17 @@ def test_trtllm_gen_prefill_deepseek(
699
722
# Initialize scale
700
723
scale = float (1.0 / (head_dim_qk ** 0.5 ))
701
724
702
- workspace_buffer = torch .empty (workspace_size , dtype = torch .int8 , device = device )
725
+ global global_workspace_buffer , global_trtllm_gen_fmha_workspace_buffer
726
+ if global_workspace_buffer is None :
727
+ global_workspace_buffer = torch .empty (
728
+ workspace_size , dtype = torch .int8 , device = device
729
+ )
730
+ if global_trtllm_gen_fmha_workspace_buffer is None :
731
+ global_trtllm_gen_fmha_workspace_buffer = torch .zeros (
732
+ workspace_size , dtype = torch .int8 , device = device
733
+ )
734
+ workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
735
+ workspace_buffer_ref = global_workspace_buffer
703
736
704
737
qo_indptr = torch .cat (
705
738
[
@@ -722,7 +755,7 @@ def test_trtllm_gen_prefill_deepseek(
722
755
).int ()
723
756
724
757
wrapper = flashinfer .prefill .BatchPrefillWithRaggedKVCacheWrapper (
725
- torch . zeros ( workspace_size , device = "cuda" , dtype = torch . uint8 ) ,
758
+ workspace_buffer_ref ,
726
759
kv_layout = "NHD" ,
727
760
backend = "cutlass" ,
728
761
)
@@ -775,6 +808,9 @@ def test_trtllm_gen_prefill_deepseek(
775
808
atol = 1e-3 ,
776
809
rtol = 1e-3 ,
777
810
)
811
+ # check if the first 8192 * 256 * 4 bytes of workspace_buffer is zero
812
+ # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer is the counter workspace, size might change in the future
813
+ assert (workspace_buffer [: 8192 * 256 * 4 ].cpu ().numpy () == 0 ).all ()
778
814
779
815
780
816
if __name__ == "__main__" :
0 commit comments