Commit bd487ee

fix: pass workspace for trtllm-gen attention (#1635)
<!-- .github/pull_request_template.md -->

## 📌 Description

<!-- What does this PR do? Briefly describe the changes and why they’re needed. -->

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 1649e23 commit bd487ee

2 files changed: 3 additions and 1 deletion

flashinfer/decode.py

Lines changed: 1 addition & 0 deletions
@@ -1929,6 +1929,7 @@ def paged_run(
     assert page_size is not None
     assert max_kv_len is not None
     assert enable_pdl is not None
+    assert workspace_size > 0, "workspace_size must be greater than 0"
     o = module._paged_run(
         q.contiguous(),  # NOTE(Siyuan): without contiguous, the result is incorrect
         paged_k_cache,

flashinfer/prefill.py

Lines changed: 2 additions & 1 deletion
@@ -554,6 +554,7 @@ def paged_run(
     assert cum_seq_lens_q is not None
     assert cum_seq_lens_kv is not None
     assert enable_pdl is not None
+    assert workspace_size > 0, "workspace_size must be greater than 0"
     o = paged_run_func(
         q.contiguous(),  # NOTE(Siyuan): without contiguous, the result is incorrect
         paged_k_cache,
@@ -2147,8 +2148,8 @@ def run(
     None,  # scale_v
     rope_scale,
     rope_theta,
-    self._workspace_size,
     self._token_pos_in_items_len,
+    self._workspace_size,
     self._num_qo_heads,
     self._num_kv_heads,
     self._block_tables,
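
Reviewer-style note: one way to read this change is that the `run(` hunk fixes a positional-argument swap, and the new `assert workspace_size > 0` lines guard against a similar mistake in future. The sketch below illustrates that reading only. `paged_run_sketch`, its three-parameter signature, and the example values are hypothetical and are not FlashInfer's actual API; it also assumes, without confirmation from the diff, that `token_pos_in_items_len` can be `0` at the call site.

```python
def paged_run_sketch(rope_theta, token_pos_in_items_len, workspace_size):
    """Hypothetical stand-in for a wrapper that receives positional arguments."""
    # Same guard as the one added in the diff.
    assert workspace_size > 0, "workspace_size must be greater than 0"
    return {
        "token_pos_in_items_len": token_pos_in_items_len,
        "workspace_size": workspace_size,
    }


# Illustrative values only (assumptions, not taken from the commit).
workspace_size = 128 * 1024 * 1024
token_pos_in_items_len = 0

# Argument order before the fix: the two values land in each other's slot,
# so the callee sees workspace_size == 0 and the guard fires.
try:
    paged_run_sketch(1e4, workspace_size, token_pos_in_items_len)
    swapped_ok = True
    error_message = ""
except AssertionError as exc:
    swapped_ok = False
    error_message = str(exc)

# Argument order after the fix: each value reaches the intended parameter.
result = paged_run_sketch(1e4, token_pos_in_items_len, workspace_size)
```

Passing such arguments by keyword at the call site would rule out this class of ordering bug, at the cost of a more verbose call.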
