Skip to content

Commit df306f6

Browse files
authored
Revert "fix: remote redundant zero_init from trtllm-gen attn (#1444)" (#1459)
This reverts commit 5451029. <!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 5cd9805 commit df306f6

File tree

3 files changed

+19
-18
lines changed

3 files changed

+19
-18
lines changed

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ void trtllm_paged_attention_launcher(
151151
runner_params.multiCtasKvScratchPtr = reinterpret_cast<void*>(
152152
static_cast<char*>(workspace_buffer) + num_semaphores * sizeof(uint32_t));
153153
runner_params.multiCtasKvCounterPtr = reinterpret_cast<int32_t*>(workspace_buffer);
154+
zero_gmem_semaphore_launcher(runner_params.multiCtasKvCounterPtr, num_semaphores,
155+
/*enable_pdl=*/true, stream);
154156
}
155157

156158
auto [foundKernels, kinfo] = fmha_runner->isSupportedWithInfo(runner_params);

tests/test_trtllm_gen_decode.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def test_trtllm_batch_decode_mla(
469469

470470
# Allocate workspace buffer
471471
# todo(Yingyi): calculate the actual size of workspace buffer
472-
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8, device=device)
472+
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
473473

474474
bmm1_log2_scale_tensor = (
475475
torch.tensor(
@@ -487,22 +487,21 @@ def test_trtllm_batch_decode_mla(
487487
)
488488

489489
# Run decode-MLA
490-
for _ in range(3):
491-
output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
492-
query=query,
493-
kv_cache=kv_cache.unsqueeze(1),
494-
workspace_buffer=workspace_buffer,
495-
qk_nope_head_dim=qk_nope_head_dim,
496-
kv_lora_rank=kv_lora_rank,
497-
qk_rope_head_dim=qk_rope_head_dim,
498-
block_tables=block_tables,
499-
seq_lens=seq_lens_tensor,
500-
max_seq_len=max_seq_len,
501-
bmm1_scale=scale / ((128 + 64) ** 0.5),
502-
bmm2_scale=1.0,
503-
bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
504-
bmm2_scale_tensor=bmm2_scale_tensor,
505-
)
490+
output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
491+
query=query,
492+
kv_cache=kv_cache.unsqueeze(1),
493+
workspace_buffer=workspace_buffer,
494+
qk_nope_head_dim=qk_nope_head_dim,
495+
kv_lora_rank=kv_lora_rank,
496+
qk_rope_head_dim=qk_rope_head_dim,
497+
block_tables=block_tables,
498+
seq_lens=seq_lens_tensor,
499+
max_seq_len=max_seq_len,
500+
bmm1_scale=scale / ((128 + 64) ** 0.5),
501+
bmm2_scale=1.0,
502+
bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
503+
bmm2_scale_tensor=bmm2_scale_tensor,
504+
)
506505

507506
# Run reference attention and align output
508507
sm_scale = scale / (

version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.2.11
1+
0.2.11.post1

0 commit comments

Comments (0)