Skip to content

Commit 5ec5371

Browse files
authored
test: Add trtllm-gen prefill test. Fix related wrapper issue. (#1346)
<!-- .github/pull_request_template.md --> ## 📌 Description Make the old prefill test cover the direct function call and fp8. Also add a new prefill test, which is more comprehensive and have very similar structure to decode test, I plan to merge prefill/decode test after added the output_dtype API in prefill test. And we may able to remove the old test Also fix some wrapper issue captured. ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 7fdae77 commit 5ec5371

File tree

3 files changed

+394
-58
lines changed

3 files changed

+394
-58
lines changed

flashinfer/prefill.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,6 +1812,7 @@ def run(
18121812
q: torch.Tensor,
18131813
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
18141814
*args,
1815+
q_scale: Optional[float] = None,
18151816
k_scale: Optional[float] = None,
18161817
v_scale: Optional[float] = None,
18171818
out: Optional[torch.Tensor] = None,
@@ -1890,6 +1891,8 @@ def run(
18901891
logits_soft_cap = 0.0
18911892
if sm_scale is None:
18921893
sm_scale = 1.0 / math.sqrt(q.size(-1))
1894+
if q_scale is not None:
1895+
sm_scale *= q_scale
18931896
if k_scale is not None:
18941897
sm_scale *= k_scale
18951898
if rope_scale is None:
@@ -1994,8 +1997,6 @@ def run(
19941997
]
19951998

19961999
self._cached_module.paged_run(*run_args)
1997-
if v_scale is not None:
1998-
out *= v_scale
19992000

20002001
return (out, lse) if return_lse else out
20012002

0 commit comments

Comments
 (0)