
Commit 0305341

fix: remove redundant zero_init reverted by #1459 (#1463)
<!-- .github/pull_request_template.md -->

## 📌 Description

The duplicate zero_init should be removed, but crashes were reported from DLFW, so the removal was reverted in #1459 and shipped as 0.2.11.post1. After this fix, **the workspace buffer passed into any trtllm-gen attn interface must be zero-initialized**. This PR re-enables the optimization. It should be merged and released only after these two are tested:

- sgl-project/sglang#9065
- vllm-project/vllm#22603

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 7ce448b commit 0305341
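
With the in-kernel zero_init gone, zero-filling the workspace is now the caller's responsibility. A minimal sketch of the updated calling convention, using only the allocation and wrapper construction shown in the docstrings changed below (buffer size follows the documented 128 MB recommendation):

```python
import torch
import flashinfer

# After this change the trtllm-gen attention kernels no longer clear the semaphore
# region themselves, so allocate the workspace with torch.zeros, not torch.empty.
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")

# The same buffer can be reused across wrappers and calls; only its first use
# requires the zero fill.
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
```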

File tree

6 files changed: +38 -67 lines changed


csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 3 additions & 4 deletions
@@ -20,7 +20,6 @@
 #include <flashinfer/trtllm/fmha/fmhaRunnerParams.h>
 #include <nvrtc.h>
 
-#include <flashinfer/semaphore_utils.cuh>
 #include <flashinfer/trtllm/fmha/fmhaRunner.cuh>
 #include <flashinfer/trtllm/fmha/gen_kernel_launcher.cuh>
 #include <flashinfer/utils.cuh>
@@ -146,13 +145,13 @@ void trtllm_paged_attention_launcher(
         use_multi_block ? TileScheduler::Static : TileScheduler::Persistent;
     runner_params.mMultiCtasKvMode = use_multi_block;
 
+    size_t max_batch_size = 8192;   // todo(Yingyi): get from dlfw
+    size_t max_num_qo_heads = 256;  // todo(Yingyi): get from dlfw, in total 8MB
     size_t num_semaphores =
-        round_up(batch_size * num_qo_heads, 8);  // align multiCtasKvScratchPtr to 16 bytes
+        round_up(max_batch_size * max_num_qo_heads, 8);  // max 8MB, should align to 16 bytes
     runner_params.multiCtasKvScratchPtr = reinterpret_cast<void*>(
         static_cast<char*>(workspace_buffer) + num_semaphores * sizeof(uint32_t));
     runner_params.multiCtasKvCounterPtr = reinterpret_cast<int32_t*>(workspace_buffer);
-    zero_gmem_semaphore_launcher(runner_params.multiCtasKvCounterPtr, num_semaphores,
-                                 /*enable_pdl=*/true, stream);
   }
 
   auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
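
The launcher now sizes the semaphore region for a fixed worst case (`max_batch_size * max_num_qo_heads`) instead of the current batch, with the multi-CTA KV scratch region starting immediately after it. A quick sanity check of the 8 MB figure mentioned in the comments, as a plain-Python sketch (the `round_up` here is a stand-in for the launcher's helper, not the actual C++ function):

```python
def round_up(x: int, multiple: int) -> int:
    # Stand-in for the round_up helper used in the launcher above.
    return (x + multiple - 1) // multiple * multiple

max_batch_size = 8192    # placeholder in the launcher: todo, get from DLFW
max_num_qo_heads = 256   # placeholder in the launcher: todo, get from DLFW
sizeof_uint32 = 4

num_semaphores = round_up(max_batch_size * max_num_qo_heads, 8)  # 2,097,152 counters
counter_bytes = num_semaphores * sizeof_uint32                   # 8,388,608 bytes = 8 MB

# multiCtasKvCounterPtr aliases the start of the workspace buffer;
# multiCtasKvScratchPtr begins right after the counter region.
scratch_offset = counter_bytes
print(num_semaphores, counter_bytes, scratch_offset)  # 2097152 8388608 8388608
```

Since the dropped `zero_gmem_semaphore_launcher` call used to clear exactly this counter region, a workspace that the caller does not zero-initialize would leave those counters holding garbage.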

flashinfer/decode.py

Lines changed: 4 additions & 4 deletions
@@ -593,7 +593,7 @@ class BatchDecodeWithPagedKVCacheWrapper:
     >>> max_num_pages = 128
     >>> page_size = 16
     >>> # allocate 128MB workspace buffer
-    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
     >>> decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
     ...     workspace_buffer, "NHD"
     ... )
@@ -658,7 +658,7 @@ def __init__(
 
         Parameters
         ----------
-        float_workspace_buffer : torch.Tensor
+        float_workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
            The user reserved float workspace buffer used to store intermediate attention results
            in the split-k algorithm. The recommended size is 128MB, the device of the workspace
            buffer should be the same as the device of the input tensors.
@@ -2000,7 +2000,7 @@ def trtllm_batch_decode_with_kv_cache(
        If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]
        If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim]
 
-    workspace_buffer : torch.Tensor
+    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
 
    block_tables : torch.Tensor
@@ -2198,7 +2198,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
    Parameters:
        query: [batch_size, q_len_per_request, num_heads, head_dim_qk], head_dim_qk = qk_nope_head_dim (kv_lora_rank) + qk_rope_head_dim, should be concated q_nope + q_rope; q_len_per_request is the MTP query length.
        kv_cache: [num_pages, page_size, head_dim_ckv + head_dim_kpe], should be concated ckv_cache + kpe_cache
-       workspace_buffer: [num_semaphores, 4], used for multi_block mode
+       workspace_buffer: [num_semaphores, 4], used for multi_block mode. Must be initialized to 0 for its first use.
        qk_nope_head_dim: qk_nope_head_dim, must be 128
        kv_lora_rank: kv_lora_rank, must be 512
        qk_rope_head_dim: qk_rope_head_dim, must be 64

flashinfer/prefill.py

Lines changed: 2 additions & 2 deletions
@@ -1215,7 +1215,7 @@ class BatchPrefillWithPagedKVCacheWrapper:
     >>> max_num_pages = 128
     >>> page_size = 16
     >>> # allocate 128MB workspace buffer
-    >>> workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
+    >>> workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
     >>> prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
     ...     workspace_buffer, "NHD"
     ... )
@@ -3144,7 +3144,7 @@ def trtllm_batch_context_with_kv_cache(
    kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        If kv_cache is a single tensor, it should be a tensor with shape [num_pages, 1 or 2, num_kv_heads, page_size, head_dim]
        If kv_cache is a tuple of two tensors, it should be a tuple of two tensors with shape [num_pages, num_kv_heads, page_size, head_dim]
-    workspace_buffer : torch.Tensor
+    workspace_buffer : torch.Tensor. Must be initialized to 0 for its first use.
        workspace
    block_tables : torch.Tensor
        page_table of kv cache, [batch_size, num_pages]

include/flashinfer/semaphore_utils.cuh

Lines changed: 0 additions & 53 deletions
This file was deleted.

tests/test_trtllm_gen_context.py

Lines changed: 15 additions & 2 deletions
@@ -7,6 +7,8 @@
 import flashinfer
 from flashinfer.utils import FP4Tensor
 
+global_workspace_buffer = None
+
 
 def flip_coin(*args, **kwargs):
     # Use any test parameters to deterministically decide branch
@@ -97,7 +99,12 @@ def test_trtllm_batch_context_wrapper(
     kv_last_page_len_cpu = torch.full(
         (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
     )
-    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+    global global_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.zeros(
+            256 * 1024 * 1024, dtype=torch.int8, device="cuda:0"
+        )
+    workspace_buffer = global_workspace_buffer
 
     # reference
     q_indptr_gpu = q_indptr_cpu.to(device)
@@ -337,7 +344,13 @@ def test_trtllm_batch_prefill(
     o_sf_vec_size = 16 if o_dtype == "nvfp4" else None
     sm_scale = float(1.0 / (head_dim**0.5))
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+    global global_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.zeros(
+            128 * 1024 * 1024, dtype=torch.int8, device="cuda:0"
+        )
+    workspace_buffer = global_workspace_buffer
+
     q_indptr = torch.cat(
         [
             torch.tensor([0], dtype=torch.int32, device=device),

tests/test_trtllm_gen_decode.py

Lines changed: 14 additions & 2 deletions
@@ -8,6 +8,8 @@
 import flashinfer
 from flashinfer.utils import FP4Tensor
 
+global_workspace_buffer = None
+
 
 def flip_coin(*args, **kwargs):
     # Use any test parameters to deterministically decide branch
@@ -235,7 +237,12 @@ def test_trtllm_batch_decode_fmha(
 
     sm_scale = float(1.0 / (head_dim**0.5))
 
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+    global global_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.zeros(
+            128 * 1024 * 1024, dtype=torch.int8, device="cuda:0"
+        )
+    workspace_buffer = global_workspace_buffer
 
     # Compute kv_indptr as cumulative sum of blocks per sequence
     kv_indptr = torch.cat(
@@ -469,7 +476,12 @@ def test_trtllm_batch_decode_mla(
 
     # Allocate workspace buffer
     # todo(Yingyi): calculate the actual size of workspace buffer
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+    global global_workspace_buffer
+    if global_workspace_buffer is None:
+        global_workspace_buffer = torch.zeros(
+            128 * 1024 * 1024, dtype=torch.int8, device="cuda:0"
+        )
+    workspace_buffer = global_workspace_buffer
 
     bmm1_log2_scale_tensor = (
         torch.tensor(
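
Both test files adopt the same pattern: allocate one zero-filled workspace at module scope and reuse it across tests, so every kernel still sees a zeroed semaphore region on its first use without paying for a fresh 128 or 256 MB allocation (and fill) per test case. A sketch of the same idea factored into a helper (the name `get_global_workspace` is illustrative, not part of the test code):

```python
import torch

_global_workspace_buffer = None  # module-level cache, mirroring the tests above


def get_global_workspace(num_bytes: int = 128 * 1024 * 1024,
                         device: str = "cuda:0") -> torch.Tensor:
    """Return a shared zero-initialized workspace, allocating it on first call.

    Note: the size is fixed by whichever call happens first, just like the
    module-level globals in the tests.
    """
    global _global_workspace_buffer
    if _global_workspace_buffer is None:
        # torch.zeros, not torch.empty: the trtllm-gen kernels require the
        # semaphore region to be zero on its first use.
        _global_workspace_buffer = torch.zeros(num_bytes, dtype=torch.int8, device=device)
    return _global_workspace_buffer
```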
