
Commit f0b235c

Heuristics + testing unification + CUDA Graphs (#1306)
## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent 238f1d2

4 files changed (+1435, -1524 lines)

flashinfer/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -117,3 +117,4 @@
 from .sampling import top_p_renorm_probs as top_p_renorm_probs
 from .sampling import top_p_sampling_from_probs as top_p_sampling_from_probs
 from .sparse import BlockSparseAttentionWrapper as BlockSparseAttentionWrapper
+from .utils import next_positive_power_of_2 as next_positive_power_of_2
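
This re-exports the new helper at the package top level. A minimal usage sketch (assuming a `flashinfer` build that includes this commit; the CUDA-graph bucketing use is inferred from the commit title, not stated in the diff):

```python
import flashinfer

# Round a size up to the next power of two, e.g. to pick a fixed set of
# batch-size buckets when replaying CUDA-graph-captured kernels
# (a plausible use given the commit title, not confirmed by the diff).
assert flashinfer.next_positive_power_of_2(1) == 1
assert flashinfer.next_positive_power_of_2(5) == 8
assert flashinfer.next_positive_power_of_2(1024) == 1024
assert flashinfer.next_positive_power_of_2(0) == 1  # inputs < 1 clamp to 1
```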

flashinfer/utils.py

Lines changed: 17 additions & 0 deletions

@@ -84,6 +84,23 @@ def _expand_4d(x: torch.Tensor, kv_layout: str) -> torch.Tensor:
     return x


+def next_positive_power_of_2(x: int) -> int:
+    if x < 1:
+        return 1
+
+    # The following is equivalent to 1 << (x - 1).bit_length(), but avoids
+    # bit_length() so that it stays traceable by torch.compile.
+    # It correctly handles 64-bit numbers, which should be enough for now.
+    n = x - 1
+    n |= n >> 1
+    n |= n >> 2
+    n |= n >> 4
+    n |= n >> 8
+    n |= n >> 16
+    n |= n >> 32
+    return n + 1
+
+
 def _check_pos_encoding_mode(pos_encoding_mode: str) -> None:
     if not hasattr(PosEncodingMode, pos_encoding_mode):
         raise KeyError("Invalid pos_encoding_mode {}".format(pos_encoding_mode))
