Commit e462997

fix: zero-init workspace buffer for trtllm-gen fmha (#1643)
Parent: 5ad2323 · Commit: e462997
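The trtllm-gen fmha kernels use the leading region of their workspace buffer as a counter workspace (per the in-test notes, the first 8192 * 256 * 4 bytes), and those counters must start at zero. A buffer from torch.empty holds arbitrary bytes, so the tests now keep two module-level buffers: an empty-initialized one for the reference wrappers, which tolerate uninitialized workspace, and a zero-initialized one reserved for the trtllm-gen fmha path. A minimal sketch of the allocation pattern, reusing the tests' names and the counter-region size from their comments:

import torch

workspace_size = 128 * 1024 * 1024  # 128 MB, as in the tests

# Reference-path workspace: contents may be arbitrary, so empty init is enough.
global_workspace_buffer = torch.empty(
    workspace_size, dtype=torch.int8, device="cuda:0"
)

# trtllm-gen fmha workspace: the kernel read-modify-writes the leading counter
# region and expects it to start at zero, so the buffer is zero-initialized.
global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
    workspace_size, dtype=torch.int8, device="cuda:0"
)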

2 files changed (+61 −17)

tests/test_trtllm_gen_attention.py

Lines changed: 48 additions & 12 deletions
@@ -17,7 +17,8 @@
 
 GPU_DEVICE = "cuda:0"
 
-global_workspace_buffer = None
+global_workspace_buffer = None  # can be empty initialized
+global_trtllm_gen_fmha_workspace_buffer = None  # must be zero initialized
 workspace_size = 128 * 1024 * 1024
 
 
@@ -320,16 +321,21 @@ def test_trtllm_batch_prefill(
         else None
     )
 
-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
             workspace_size, dtype=torch.int8, device=GPU_DEVICE
         )
-    workspace_buffer = global_workspace_buffer
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=GPU_DEVICE
+        )
+    workspace_buffer_ref = global_workspace_buffer
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
 
     # Run reference wrapper
     wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout
+        workspace_buffer_ref, kv_layout
     )
     plan_params = {
         "qo_indptr": q_indptr,
@@ -372,6 +378,9 @@ def test_trtllm_batch_prefill(
         o_sf_vec_size=o_sf_vec_size,
         enable_pdl=enable_pdl,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
     if o_dtype == "nvfp4":
         output, output_ref = unpack_compare_nvfp4(
@@ -414,6 +423,9 @@ def test_trtllm_batch_prefill(
     torch.testing.assert_close(
         output.float(), output_wrapper.float(), rtol=1e-1, atol=1e-1
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
 @pytest.mark.parametrize("kv_layout", ["HND"])  # trtllm-gen only supports HND
@@ -505,16 +517,21 @@ def test_trtllm_batch_decode(
         else None
     )
 
-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=GPU_DEVICE
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
             workspace_size, dtype=torch.int8, device=GPU_DEVICE
         )
-    workspace_buffer = global_workspace_buffer
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer
 
     # Run reference wrapper
     wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True
+        workspace_buffer_ref, kv_layout, use_tensor_cores=True
     )
     plan_params = {
         "indptr": kv_indptr,
@@ -535,7 +552,7 @@ def test_trtllm_batch_decode(
     if q_len_per_req > 1:
         # hide the output_ref from decode wrapper for speculative decoding test
         wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
-            workspace_buffer, kv_layout
+            workspace_buffer_ref, kv_layout
         )
         plan_params_prefill = {
             "qo_indptr": q_indptr,
@@ -576,6 +593,9 @@ def test_trtllm_batch_decode(
         enable_pdl=enable_pdl,
         q_len_per_req=q_len_per_req,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
     if o_dtype == "nvfp4":
         output, output_ref = unpack_compare_nvfp4(
@@ -648,6 +668,9 @@ def test_trtllm_batch_decode(
         atol=1e-1,
         max_mismatched_elements=5,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
 @pytest.mark.parametrize("batch_size", [4, 128, 256])
@@ -699,7 +722,17 @@ def test_trtllm_gen_prefill_deepseek(
     # Initialize scale
     scale = float(1.0 / (head_dim_qk**0.5))
 
-    workspace_buffer = torch.empty(workspace_size, dtype=torch.int8, device=device)
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.empty(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer
 
     qo_indptr = torch.cat(
         [
@@ -722,7 +755,7 @@ def test_trtllm_gen_prefill_deepseek(
     ).int()
 
     wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
-        torch.zeros(workspace_size, device="cuda", dtype=torch.uint8),
+        workspace_buffer_ref,
         kv_layout="NHD",
         backend="cutlass",
     )
@@ -775,6 +808,9 @@ def test_trtllm_gen_prefill_deepseek(
         atol=1e-3,
         rtol=1e-3,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
 
 if __name__ == "__main__":
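The zero-check after each trtllm-gen call above is repeated verbatim; a hypothetical helper (not part of this commit) could factor it out, under the same size assumption as the note(Yingyi) comments:

COUNTER_WORKSPACE_BYTES = 8192 * 256 * 4  # per the test comments; may change

def assert_counter_workspace_zeroed(workspace_buffer) -> None:
    # The kernel is expected to leave its counters back at zero after each
    # run, so the leading region should read back as all zeros.
    assert (
        workspace_buffer[:COUNTER_WORKSPACE_BYTES].cpu().numpy() == 0
    ).all()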

tests/test_trtllm_gen_mla.py

Lines changed: 13 additions & 5 deletions
@@ -5,7 +5,8 @@
 
 import flashinfer
 
-global_workspace_buffer = None
+global_workspace_buffer = None  # can be empty initialized
+global_trtllm_gen_fmha_workspace_buffer = None  # must be zero initialized
 workspace_size = 128 * 1024 * 1024
 
 
@@ -96,12 +97,17 @@ def test_trtllm_batch_decode_mla(
 
     # Allocate workspace buffer
     # todo(Yingyi): calculate the actual size of workspace buffer
-    global global_workspace_buffer
+    global global_workspace_buffer, global_trtllm_gen_fmha_workspace_buffer
     if global_workspace_buffer is None:
-        global_workspace_buffer = torch.zeros(
+        global_workspace_buffer = torch.empty(
             workspace_size, dtype=torch.int8, device=device
         )
-    workspace_buffer = global_workspace_buffer
+    if global_trtllm_gen_fmha_workspace_buffer is None:
+        global_trtllm_gen_fmha_workspace_buffer = torch.zeros(
+            workspace_size, dtype=torch.int8, device=device
+        )
+    workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
+    workspace_buffer_ref = global_workspace_buffer
 
     bmm1_log2_scale_tensor = (
         torch.tensor(
@@ -135,12 +141,14 @@ def test_trtllm_batch_decode_mla(
         bmm2_scale_tensor=bmm2_scale_tensor,
         enable_pdl=enable_pdl,
     )
+    # check that the first 8192 * 256 * 4 bytes of workspace_buffer are zero
+    # note(Yingyi): the first 8192 * 256 * 4 bytes of workspace_buffer are the counter workspace; the size might change in the future
+    assert (workspace_buffer[: 8192 * 256 * 4].cpu().numpy() == 0).all()
 
     # Run reference attention and align output
     sm_scale = scale / (
         (128 + 64) ** 0.5
     )  # use head dimension before matrix absorption
-    workspace_buffer_ref = torch.empty(workspace_size, dtype=torch.int8, device=device)
     wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
         workspace_buffer_ref,
         backend="fa2",
