
Commit 5451029

fix: remove redundant zero_init from trtllm-gen attn (#1444)
1 parent bfccf68 commit 5451029

File tree

2 files changed: +17 −18 lines


csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 0 additions & 2 deletions
```diff
@@ -151,8 +151,6 @@ void trtllm_paged_attention_launcher(
     runner_params.multiCtasKvScratchPtr = reinterpret_cast<void*>(
         static_cast<char*>(workspace_buffer) + num_semaphores * sizeof(uint32_t));
     runner_params.multiCtasKvCounterPtr = reinterpret_cast<int32_t*>(workspace_buffer);
-    zero_gmem_semaphore_launcher(runner_params.multiCtasKvCounterPtr, num_semaphores,
-                                 /*enable_pdl=*/true, stream);
   }

   auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);
```
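
The deleted zero_gmem_semaphore_launcher call re-zeroed the multi-CTA KV counter semaphores on every launch. The surviving lines show the layout: the int32 counters sit at the front of workspace_buffer, and the scratch region begins num_semaphores * sizeof(uint32_t) bytes in. A minimal Python sketch of that layout and of the caller-side contract the removal implies (num_semaphores and the buffer size here are illustrative assumptions, not values taken from the launcher):

```python
import torch

# Illustrative sizes only; the real launcher computes num_semaphores itself.
num_semaphores = 512
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)

# Layout implied by the C++ above: int32 counters at the front of the
# buffer, multi-CTA KV scratch space immediately after them.
counters = workspace_buffer[: num_semaphores * 4].view(torch.int32)
scratch = workspace_buffer[num_semaphores * 4 :]

# With the device-side zeroing removed, the counters start at zero only
# because the buffer was allocated with torch.zeros rather than torch.empty.
assert counters.eq(0).all()
```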

tests/test_trtllm_gen_decode.py

Lines changed: 17 additions & 16 deletions
```diff
@@ -465,7 +465,7 @@ def test_trtllm_batch_decode_mla(

     # Allocate workspace buffer
     # todo(Yingyi): calculate the actual size of workspace buffer
-    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
+    workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8, device=device)

     bmm1_log2_scale_tensor = (
         torch.tensor(
@@ -483,21 +483,22 @@ def test_trtllm_batch_decode_mla(
     )

     # Run decode-MLA
-    output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
-        query=query,
-        kv_cache=kv_cache.unsqueeze(1),
-        workspace_buffer=workspace_buffer,
-        qk_nope_head_dim=qk_nope_head_dim,
-        kv_lora_rank=kv_lora_rank,
-        qk_rope_head_dim=qk_rope_head_dim,
-        block_tables=block_tables,
-        seq_lens=seq_lens_tensor,
-        max_seq_len=max_seq_len,
-        bmm1_scale=scale / ((128 + 64) ** 0.5),
-        bmm2_scale=1.0,
-        bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
-        bmm2_scale_tensor=bmm2_scale_tensor,
-    )
+    for _ in range(3):
+        output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+            query=query,
+            kv_cache=kv_cache.unsqueeze(1),
+            workspace_buffer=workspace_buffer,
+            qk_nope_head_dim=qk_nope_head_dim,
+            kv_lora_rank=kv_lora_rank,
+            qk_rope_head_dim=qk_rope_head_dim,
+            block_tables=block_tables,
+            seq_lens=seq_lens_tensor,
+            max_seq_len=max_seq_len,
+            bmm1_scale=scale / ((128 + 64) ** 0.5),
+            bmm2_scale=1.0,
+            bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
+            bmm2_scale_tensor=bmm2_scale_tensor,
+        )

     # Run reference attention and align output
     sm_scale = scale / (
```
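
The test now allocates the workspace with torch.zeros instead of torch.empty and drives the decode three times over the same buffer. That pattern only holds together if the kernel leaves the counter region zeroed when it completes, so no host-side re-initialization is needed between calls. A self-contained sketch of that invariant, with fake_decode_kernel as a hypothetical stand-in for the real launcher:

```python
import torch

# Illustrative value; the real launcher derives num_semaphores internally.
NUM_SEMAPHORES = 512

def fake_decode_kernel(workspace: torch.Tensor) -> None:
    """Hypothetical stand-in for the real kernel: it bumps the counters
    while 'running', then hands the buffer back with them re-zeroed."""
    counters = workspace[: NUM_SEMAPHORES * 4].view(torch.int32)
    counters += 1      # simulate CTAs incrementing the semaphores
    counters.zero_()   # kernel restores the zeroed state on completion

workspace = torch.zeros(1024 * 1024, dtype=torch.int8)

for _ in range(3):  # mirrors the test's reuse loop
    # No re-zeroing between iterations: each call relies on the previous
    # call having left the counters at zero.
    assert workspace[: NUM_SEMAPHORES * 4].view(torch.int32).eq(0).all()
    fake_decode_kernel(workspace)
```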
