Skip to content

Commit 4b30a91

Browse files
authored
Bugfix: some typos in Persistent kernel (#1562)
<!-- .github/pull_request_template.md --> ## πŸ“Œ Description tests pass ## πŸ” Related Issues <!-- Link any related issues here --> ## πŸš€ Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### βœ… Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## πŸ§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent d6c3e33 commit 4b30a91

File tree

2 files changed

+2
-5
lines changed

2 files changed

+2
-5
lines changed

include/flashinfer/attention/persistent.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ struct BlockBatchPagedAttentionPersistent {
266266
const auto [q_indptr, kv_indptr, o_indptr, q_len, kv_len, packed_qo_start, kv_start, kv_end,
267267
kv_head_idx, len_kv_chunk] = get_block_coord(params, work_idx);
268268

269-
const uint32_t kv_chunk_idx = ceil_div(kv_start, len_kv_chunk);
269+
const uint32_t kv_chunk_idx = kv_start / len_kv_chunk;
270270
const uint32_t num_kv_chunks = ceil_div(
271271
CAUSAL
272272
? min((kv_len - q_len) + (packed_qo_start + cluster_tile_q) / gqa_group_size, kv_len)

include/flashinfer/attention/scheduler.cuh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,8 +1189,7 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
11891189
cluster_kv_start(num_clusters, std::vector<IdType>()),
11901190
cluster_kv_end(num_clusters, std::vector<IdType>()),
11911191
cluster_kv_head_idx(num_clusters, std::vector<IdType>()),
1192-
cluster_partial_indptr(num_clusters, std::vector<IdType>()),
1193-
cluster_len_kv_chunk(num_clusters, std::vector<IdType>());
1192+
cluster_partial_indptr(num_clusters, std::vector<IdType>());
11941193

11951194
for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec[task]) {
11961195
int packed_qo_len = qo_len * gqa_group_size;
@@ -1218,7 +1217,6 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
12181217
cluster_kv_indptr[cluster_idx].push_back(kv_indptr_h[i]);
12191218

12201219
// use kv_chunk to rematerialize num_kv_tiles and kv_tile_idx
1221-
cluster_len_kv_chunk[cluster_idx].push_back(kv_len_limit);
12221220
cluster_partial_indptr[cluster_idx].push_back(partial_o_nnz);
12231221

12241222
cluster_q_start[cluster_idx].push_back(qo_tile_idx * cluster_tile_q);
@@ -1265,7 +1263,6 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
12651263
auto kv_start_vec = flatten(cluster_kv_start, total_num_works);
12661264
auto kv_end_vec = flatten(cluster_kv_end, total_num_works);
12671265
auto kv_head_idx_vec = flatten(cluster_kv_head_idx, total_num_works);
1268-
auto len_kv_chunk_vec = flatten(cluster_len_kv_chunk, total_num_works);
12691266

12701267
plan_info.tasks[task].q_indptr_offset =
12711268
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_indptr");

0 commit comments

Comments
(0)