Commit 9b861cd

bugfix: fix merge_attention_state in BatchAttention w/ gqa-group-size in Qwen family (#1614)
## 📌 Description

This PR fixes precision issues in BatchAttention (the Persistent FA2 kernel from #1137) when `CTA_TILE_Q` is not a multiple of `gqa_group_size` (e.g., Qwen-family models). The prior implementation assumed that, for a given token, all `qo_heads` belonging to one `kv_head` are either all split-KV or all non-split-KV. However, when `gqa_group_size == 7`, some of those `qo_heads` can be non-split while the remaining ones are split.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @Edenzzzz
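To make the failure mode concrete, here is a small standalone sketch (toy values, not FlashInfer code): with a hypothetical tile of 16 packed `(token, qo_head)` rows and `gqa_group_size = 7`, the tile boundary cuts through one token's head group, so heads of the same token/`kv_head` can land in tiles with different split-KV decisions.

```cpp
// Toy illustration of the misalignment (assumed tile size of 16; not the real kernel).
#include <cstdio>

int main() {
  const int gqa_group_size = 7;  // e.g., 28 qo heads shared by 4 kv heads
  const int cta_tile_q = 16;     // hypothetical CTA_TILE_Q, not a multiple of 7
  for (int packed = 0; packed < 2 * cta_tile_q; ++packed) {
    int token = packed / gqa_group_size;  // query token index
    int head = packed % gqa_group_size;   // qo head within the kv-head group
    int tile = packed / cta_tile_q;       // CTA tile that owns this packed row
    if (token == 2) {  // token 2's 7 heads straddle the tile boundary at row 16
      std::printf("packed=%2d -> token=%d head=%d tile=%d\n", packed, token, head, tile);
    }
  }
  // Heads 0-1 of token 2 fall in tile 0 and heads 2-6 in tile 1, so one tile may be
  // scheduled as split-KV while the other is not, breaking the old assumption.
  return 0;
}
```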
1 parent 8e926de commit 9b861cd

File tree

3 files changed (+30, −16 lines)

include/flashinfer/attention/persistent.cuh

Lines changed: 5 additions & 6 deletions
```diff
@@ -269,7 +269,8 @@ struct BlockBatchPagedAttentionPersistent {
     const uint32_t kv_chunk_idx = kv_start / len_kv_chunk;
     const uint32_t num_kv_chunks = ceil_div(
         CAUSAL
-            ? min((kv_len - q_len) + (packed_qo_start + cluster_tile_q) / gqa_group_size, kv_len)
+            ? min((kv_len - q_len) + ceil_div(packed_qo_start + cluster_tile_q, gqa_group_size),
+                  kv_len)
             : kv_len,
         len_kv_chunk);
     const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * CTA_TILE_Q +
```
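A sketch of the arithmetic behind this change, under assumed toy lengths (the real kernel takes `kv_len`, `q_len`, etc. from the plan): the exclusive upper bound on query tokens covered by the tile is `ceil_div(packed_qo_start + cluster_tile_q, gqa_group_size)`, and plain integer division drops the partially covered last token whenever the tile end is not a multiple of the group size.

```cpp
// Floor vs. ceiling division for the causal KV upper bound (assumed toy values).
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t gqa_group_size = 7;
  const uint32_t packed_qo_start = 0, cluster_tile_q = 16;  // tile covers tokens 0, 1 and part of 2
  const uint32_t q_len = 8, kv_len = 1024;                  // hypothetical sequence lengths
  const uint32_t packed_end = packed_qo_start + cluster_tile_q;
  uint32_t old_bound = std::min((kv_len - q_len) + packed_end / gqa_group_size, kv_len);
  uint32_t new_bound = std::min((kv_len - q_len) + ceil_div(packed_end, gqa_group_size), kv_len);
  std::printf("old bound = %u, new bound = %u\n", old_bound, new_bound);  // 1018 vs 1019
  // The old bound misses the last KV position that token 2 is causally allowed to see,
  // which can under-count num_kv_chunks for this tile.
  return 0;
}
```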
```diff
@@ -517,23 +518,21 @@ struct BlockBatchReductionPersistent {

       // remap workload
       uint32_t packed_qo_idx = i / num_kv_heads;
+      uint32_t kv_head_idx = i % num_kv_heads;
       const uint32_t num_index_sets = indptr[packed_qo_idx + 1] - indptr[packed_qo_idx];
       if (num_index_sets == 0 || num_index_sets == 1) {
         // already write through, bypass
         PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kReduction);
         continue;
       }

-      uint32_t kv_head_idx = i % num_kv_heads;
-      uint32_t qo_head_idx = packed_qo_idx % gqa_group_size;
-
       // index calculation
       auto partial_idx_to_offset = [&](uint32_t off) {
         return (indptr[packed_qo_idx] + off) * num_kv_heads + kv_head_idx;
       };
       auto merge_idx_to_offset = [&]() {
-        return (o_indices[packed_qo_idx] * num_kv_heads + kv_head_idx) * gqa_group_size +
-               qo_head_idx;
+        // NOTE (Yilong): qo_head_idx has been calculated in schedule.plan
+        return o_indices[packed_qo_idx] + kv_head_idx * gqa_group_size;
       };

       state_t<vec_size> st;
```
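A standalone consistency check of the new indexing (toy sizes; the token and head values are made up): the scheduler-side `o_indices` entry already folds in the query token and the qo head within the group, so adding only `kv_head_idx * gqa_group_size` in the kernel yields the row-major row index into a `[qo_len, num_kv_heads, gqa_group_size]` output, matching the layout noted in the scheduler change below.

```cpp
// Sanity-check sketch of the new merged-output offset (toy values, not FlashInfer code).
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t num_kv_heads = 4, gqa_group_size = 7;
  const uint32_t token = 5, qo_head_in_group = 3;  // hypothetical query row
  // Host side (scheduler): token base and qo head are pre-folded into o_indices.
  const uint32_t o_index = token * num_kv_heads * gqa_group_size + qo_head_in_group;
  // Device side (reduction kernel): only the kv-head stride is added.
  for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
    const uint32_t offset = o_index + kv_head_idx * gqa_group_size;
    // Reference: full row-major index into [qo_len, num_kv_heads, gqa_group_size].
    const uint32_t expected =
        (token * num_kv_heads + kv_head_idx) * gqa_group_size + qo_head_in_group;
    assert(offset == expected);
  }
  return 0;
}
```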

include/flashinfer/attention/scheduler.cuh

Lines changed: 5 additions & 2 deletions
```diff
@@ -1235,8 +1235,11 @@ inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspa
         // non-split kv is directly written through
         for (int row = 0; row < row_tile_size; ++row) {
           merge_indptr.push_back(merge_indptr.back() + num_kv_tiles);
-          merge_o_indices.push_back(qo_indptr_h[i] +
-                                    (qo_tile_idx * cluster_tile_q + row) / gqa_group_size);
+          // output layout: [qo_len, num_kv_heads, gqa_group_size, head_dim]
+          // merge_o_indices is the indices of `gqa_group_size` dimension
+          auto q = (qo_tile_idx * cluster_tile_q + row) / gqa_group_size,
+               r = (qo_tile_idx * cluster_tile_q + row) % gqa_group_size;
+          merge_o_indices.push_back((qo_indptr_h[i] + q) * num_kv_heads * gqa_group_size + r);
         }
         partial_o_nnz += row_tile_size * num_kv_tiles;
       }
```
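For concreteness, a toy recomputation of the values this loop pushes (made-up sizes; `qo_indptr_base` is a stand-in for `qo_indptr_h[i]`, and the real plan iterates tiles per request): each row is decomposed into its query token `q` and its qo head `r` within the kv-head group, so rows of one token keep distinct output slots even when they straddle a tile boundary.

```cpp
// Toy recomputation of the new merge_o_indices values (hypothetical sizes).
#include <cstdio>
#include <vector>

int main() {
  const int num_kv_heads = 4, gqa_group_size = 7;
  const int cluster_tile_q = 16;  // hypothetical tile size, not a multiple of 7
  const int qo_indptr_base = 0;   // stand-in for qo_indptr_h[i]
  std::vector<int> merge_o_indices;
  for (int qo_tile_idx = 0; qo_tile_idx < 2; ++qo_tile_idx) {
    for (int row = 0; row < cluster_tile_q; ++row) {
      const int packed = qo_tile_idx * cluster_tile_q + row;
      const int q = packed / gqa_group_size;  // query token within the request
      const int r = packed % gqa_group_size;  // qo head within the kv-head group
      merge_o_indices.push_back((qo_indptr_base + q) * num_kv_heads * gqa_group_size + r);
    }
  }
  // Token 2 occupies packed rows 14..20, which straddle the tile boundary at row 16,
  // yet each row still maps to its own (token, qo_head) slot: indices 56..62 here.
  for (int packed = 14; packed <= 20; ++packed) {
    std::printf("packed=%2d -> merge_o_index=%d\n", packed, merge_o_indices[packed]);
  }
  return 0;
}
```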

tests/test_batch_attention.py

Lines changed: 20 additions & 8 deletions
```diff
@@ -57,17 +57,15 @@ def _build_seq_len_configs():
     torch.manual_seed(42)
 
     seq_len_configs = [
+        [(146, 146)],
+        [(67, 67)],
         [(8190, 7939)],
-        [(2, 235)]
-        + [(1, 13353)],  # corner case with a large number of masked out tokens
-        [(67, 1)],
-        [(182, 1)],
-        [(2011, 1)],
         [(2048, 1)] * 77,  # decode-only
         [(4099, 129)] * 2,  # prefill-only
         [(600, 1)] * 132 * 2 + [(5000, 3)] * 128,
         [(1024, 1)] * 100 + [(8192, 17)] * 8,  # speculative decode
         [(766, 2)] * 99 + [(1024, 512)] * 1,  # chunked prefill
+        [(2, 235)] + [(1, 13353)],  # real workload
     ]
 
     # Construct random seqlen tests
@@ -142,7 +140,7 @@ def _run_attention(
 
     # --------- old scheduler --------- #
     wrapper_old = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=dev),
+        torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=dev),
         kv_layout=layout,
         backend="fa2",
     )
@@ -190,8 +188,8 @@ def _run_attention(
 # ------------------------- PyTest test case ----------------------------- #
 @pytest.mark.parametrize("seq_len_pairs", _build_seq_len_configs())
 @pytest.mark.parametrize("page_block_size", [1, 8, 16])
-@pytest.mark.parametrize("num_kv_heads", [8, 1, 4])
-@pytest.mark.parametrize("gqa_group_size", [1, 4, 7])
+@pytest.mark.parametrize("num_kv_heads", [1, 4])
+@pytest.mark.parametrize("gqa_group_size", [1, 4, 7, 8])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
 @pytest.mark.parametrize("causal", [False, True])
 @pytest.mark.parametrize("layout", ["HND", "NHD"])
@@ -225,3 +223,17 @@ def test_batch_attention_correctness(
         logits_soft_cap=logits_soft_cap,
         device="cuda",
     )
+
+
+if __name__ == "__main__":
+    test_batch_attention_correctness(
+        seq_len_pairs=[(1000, 1000)],
+        page_block_size=1,
+        num_kv_heads=4,
+        gqa_group_size=7,
+        head_dim=128,
+        causal=True,
+        layout="NHD",
+        test_dtype=torch.bfloat16,
+        logits_soft_cap=0.0,
+    )
```
