Skip to content

Commit 905f755

Browse files
authored
test: minor update on trtllm-gen attn speculative-decoding test (#1760)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues comments in #1453 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 50319b2 commit 905f755

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

tests/test_trtllm_gen_attention.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,6 @@ def test_trtllm_batch_decode(
536536
workspace_buffer = global_trtllm_gen_fmha_workspace_buffer
537537
workspace_buffer_ref = global_workspace_buffer
538538

539-
# Run reference wrapper
540-
wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
541-
workspace_buffer_ref, kv_layout, use_tensor_cores=True
542-
)
543539
plan_params = {
544540
"indptr": kv_indptr,
545541
"indices": all_page_ids,
@@ -553,11 +549,14 @@ def test_trtllm_batch_decode(
553549
"q_data_type": ref_q.dtype,
554550
"window_left": window_left,
555551
}
556-
wrapper_ref.plan(**plan_params)
557-
output_ref = wrapper_ref.run(ref_q, ref_kv_cache)
558-
559-
if q_len_per_req > 1:
560-
# hide the output_ref from decode wrapper for speculative decoding test
552+
# Run reference wrapper
553+
if q_len_per_req == 1:
554+
wrapper_ref = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
555+
workspace_buffer_ref, kv_layout, use_tensor_cores=True
556+
)
557+
wrapper_ref.plan(**plan_params)
558+
output_ref = wrapper_ref.run(ref_q, ref_kv_cache)
559+
else:
561560
wrapper_ref = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
562561
workspace_buffer_ref, kv_layout
563562
)

0 commit comments

Comments
 (0)