Commit 7fdae77

bugfix: Support passing kv_data_type to MultiLevelCascadeAttentionWrapper.plan() (#1350)
## 📌 Description

`MultiLevelCascadeAttentionWrapper.plan()` ends up calling `plan()` on `BatchPrefillWithPagedKVCacheWrapper`. `BatchPrefillWithPagedKVCacheWrapper.plan()` supports `kv_data_type`, but `MultiLevelCascadeAttentionWrapper.plan()` does not, so the argument cannot be passed through.

## 🔍 Related Issues

Fixes vllm-project/vllm#21822

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

Signed-off-by: Yong Hoon Shin <[email protected]>
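The shape of the bug can be sketched with hypothetical stand-in functions (these are not the real FlashInfer signatures, which take many more planning arguments): before the fix, the outer `plan()` simply lacked the `kv_data_type` parameter, so passing it raised a `TypeError` even though the inner `plan()` already accepted it.

```python
def inner_plan(q_data_type="float16", kv_data_type=None):
    # Stand-in for BatchPrefillWithPagedKVCacheWrapper.plan(), which already
    # accepts kv_data_type and defaults it to q_data_type when None.
    return kv_data_type if kv_data_type is not None else q_data_type

def outer_plan_before_fix(q_data_type="float16"):
    # Pre-fix MultiLevelCascadeAttentionWrapper.plan(): no kv_data_type param,
    # so callers cannot reach the inner wrapper's kv_data_type support.
    return inner_plan(q_data_type=q_data_type)

def outer_plan_after_fix(q_data_type="float16", kv_data_type=None):
    # Post-fix: the parameter is accepted and forwarded unchanged.
    return inner_plan(q_data_type=q_data_type, kv_data_type=kv_data_type)

try:
    outer_plan_before_fix(kv_data_type="float8_e4m3fn")
except TypeError as e:
    print("pre-fix:", e)  # unexpected keyword argument 'kv_data_type'

print("post-fix:", outer_plan_after_fix(kv_data_type="float8_e4m3fn"))
```

Defaulting `kv_data_type` to `None` (rather than to a concrete dtype) keeps the outer wrapper from overriding the inner wrapper's own fallback to `q_data_type`.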
1 parent 647d31e commit 7fdae77

File tree

1 file changed: +5 −1 lines changed


flashinfer/cascade.py

Lines changed: 5 additions & 1 deletion
```diff
--- a/flashinfer/cascade.py
+++ b/flashinfer/cascade.py
@@ -16,7 +16,7 @@
 
 import functools
 from functools import cache
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 
@@ -418,6 +418,7 @@ def plan(
         rope_scale: Optional[float] = None,
         rope_theta: Optional[float] = None,
         q_data_type: str = "float16",
+        kv_data_type: Optional[Union[str, torch.dtype]] = None,
     ):
         r"""Create auxiliary data structures for multi-level cascade attention for multiple
         forward calls within the same decode step. Please check
@@ -476,6 +477,8 @@ def plan(
             The theta used in RoPE, if not provided, will be set to ``1e4``.
         q_data_type : Optional[Union[str, torch.dtype]]
             The data type of the query tensor. If None, will be set to torch.float16.
+        kv_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         """
         for i, (
             wrapper,
@@ -510,6 +513,7 @@ def plan(
                 rope_scale=rope_scale,
                 rope_theta=rope_theta,
                 q_data_type=q_data_type,
+                kv_data_type=kv_data_type,
             )
 
     begin_forward = plan
```
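The forwarding pattern in the last hunk can be sketched with minimal stand-in classes (hypothetical names, not the real FlashInfer API): the cascade-level `plan()` loops over its per-level wrappers and passes `kv_data_type` straight through to each inner `plan()` call, where the usual default-to-`q_data_type` fallback applies.

```python
class _InnerWrapper:
    # Stand-in for BatchPrefillWithPagedKVCacheWrapper; only records dtypes.
    def plan(self, q_data_type="float16", kv_data_type=None):
        self.q_data_type = q_data_type
        # The inner plan() defaults kv_data_type to q_data_type when None.
        self.kv_data_type = kv_data_type if kv_data_type is not None else q_data_type

class _CascadeWrapper:
    # Stand-in for MultiLevelCascadeAttentionWrapper with two cascade levels.
    def __init__(self):
        self._wrappers = [_InnerWrapper(), _InnerWrapper()]

    def plan(self, q_data_type="float16", kv_data_type=None):
        # Mirrors the patched loop: every level receives the same kv dtype.
        for wrapper in self._wrappers:
            wrapper.plan(q_data_type=q_data_type, kv_data_type=kv_data_type)

cascade = _CascadeWrapper()
cascade.plan(q_data_type="bfloat16", kv_data_type="float8_e4m3fn")
print([w.kv_data_type for w in cascade._wrappers])
```

After this change, a caller using an FP8 KV cache with a higher-precision query dtype can configure every cascade level in one `plan()` call instead of being limited to `q_data_type` for both.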
