Commit 30319e7

weireweire and yyihuang authored
refactor: Unify and modularize decode and prefill test. (#1375)
## 📌 Description

Unify the decode and prefill attention tests for trtllm-gen by extracting common functions. Support testing of fp8/fp4 qkvo scales, o_sf, and the o_sf offset.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: Avery Yingyi Huang <[email protected]>
1 parent 3628a54 commit 30319e7

6 files changed: +756 -1110 lines changed

flashinfer/decode.py

Lines changed: 7 additions & 0 deletions
@@ -1288,6 +1288,13 @@ def run(
 
     self._cached_module.paged_run(*run_args)
 else:
+    # trtllm-gen does not need plan info
+    if self._backend == "trtllm-gen" and self._plan_info is None:
+        plan_info: List[int] = []
+    else:
+        plan_info = self._plan_info
+        assert plan_info is not None, "plan info is not initialized"
+
     run_args = [
         self._float_workspace_buffer,
         self._int_workspace_buffer,
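
The fallback added in this hunk can be read in isolation as a small function. The sketch below is illustrative only: `resolve_plan_info` is a hypothetical helper that does not exist in FlashInfer, and `"fa2"` in the usage note simply stands in for any backend other than `"trtllm-gen"`. It assumes the `assert` applies only to the non-trtllm-gen path, which is how the diff reads.

```python
from typing import List, Optional


def resolve_plan_info(backend: str, plan_info: Optional[List[int]]) -> List[int]:
    """Hypothetical stand-in for the plan-info selection added to run()."""
    # trtllm-gen does not need plan info, so an empty list is passed instead.
    if backend == "trtllm-gen" and plan_info is None:
        return []
    # Any other combination must already carry plan info.
    assert plan_info is not None, "plan info is not initialized"
    return plan_info
```

Under these assumptions, `resolve_plan_info("trtllm-gen", None)` yields an empty list, while `resolve_plan_info("fa2", None)` raises `AssertionError`, matching the behaviour the hunk appears to introduce.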

scripts/run_test_blackwell_attention_kernels.sh

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@ pytest -s tests/test_blackwell_fmha.py
 pytest -s tests/test_deepseek_mla.py
 
 # trtllm-gen
-pytest -s tests/test_trtllm_gen_context.py
-pytest -s tests/test_trtllm_gen_decode.py
+pytest -s tests/test_trtllm_gen_attention.py
+pytest -s tests/test_trtllm_gen_mla.py
 
 # cudnn
 pytest -s tests/test_cudnn_decode.py
