
Commit 0b1d8a6

Edenzzzz and happierpig authored
Bugfix: fix o_strides in persistent kernel (#1865)
## πŸ“Œ Description

The kernel produces incorrect results when `q` is non-contiguous (e.g. a view from `torch.split`), because the write-through path reused `q`'s strides for the output tensor.

## πŸ” Related Issues

## πŸš€ Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### βœ… Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## πŸ§ͺ Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

---------

Co-authored-by: happierpig <[email protected]>
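A minimal sketch (with illustrative shapes, not the test's actual configuration) of why a view produced by `torch.split`/`torch.chunk` breaks stride reuse: the view keeps its parent's strides, so its row stride no longer equals `num_qo_heads * head_dim`.

```python
import torch

# Illustrative shapes only; the regression test below uses num_qo_heads = 1, head_dim = 64.
num_qo_heads, head_dim = 8, 64
q_base = torch.rand(16, num_qo_heads, head_dim * 2)

# torch.chunk / torch.split return views that inherit the parent's strides.
q = torch.chunk(q_base, 2, dim=-1)[0]  # shape [16, 8, 64], but not contiguous

print(q.is_contiguous())        # False
print(q.stride())               # (1024, 128, 1): the row stride still reflects head_dim * 2
print(num_qo_heads * head_dim)  # 512: the row stride a contiguous [16, 8, 64] tensor would have
```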
1 parent d3e9b44 commit 0b1d8a6

2 files changed (+58, -5 lines)

include/flashinfer/attention/persistent.cuh

Lines changed: 5 additions & 2 deletions
@@ -408,10 +408,13 @@ struct BlockBatchPagedAttentionPersistent {
                          warp_idx, lane_idx, tid);
     } else {
       // write through
+      // o_stride_n = num_qo_heads * head_dim
+      const uint32_t o_stride_n = num_kv_heads * gqa_group_size * HEAD_DIM_VO,
+                     o_stride_h = HEAD_DIM_VO;
       DTypeO* o_ptr_base =
-          params.final_o + q_indptr * q_stride_n + (kv_head_idx * gqa_group_size) * q_stride_h;
+          params.final_o + q_indptr * o_stride_n + (kv_head_idx * gqa_group_size) * o_stride_h;
       write_o_reg_gmem<KTraits>(o_frag, &q_smem, o_ptr_base, qo_packed_idx_base, q_len,
-                                q_stride_n, q_stride_h, gqa_group_size, tid);
+                                o_stride_n, o_stride_h, gqa_group_size, tid);
     }

     if constexpr (variant.use_softmax) {
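The fix derives the output strides from the shape of `final_o` instead of reusing `q_stride_n`/`q_stride_h`, which only match when `q` itself is contiguous. A small sketch of that stride arithmetic (illustrative sizes, assuming a contiguous `[total_q, num_qo_heads, HEAD_DIM_VO]` output buffer, which is what the new code implies):

```python
import torch

# Illustrative sizes, not the kernel's actual configuration.
total_q, num_kv_heads, gqa_group_size, HEAD_DIM_VO = 32, 2, 4, 64
num_qo_heads = num_kv_heads * gqa_group_size

# A contiguous output buffer's strides follow from its shape alone.
final_o = torch.empty(total_q, num_qo_heads, HEAD_DIM_VO)
o_stride_n, o_stride_h, _ = final_o.stride()

assert o_stride_n == num_kv_heads * gqa_group_size * HEAD_DIM_VO  # 512, matching the kernel's formula
assert o_stride_h == HEAD_DIM_VO                                  # 64
```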

tests/attention/test_batch_attention.py

Lines changed: 53 additions & 3 deletions
@@ -97,6 +97,7 @@ def _run_attention(
     logits_soft_cap=0.0,
     device="cuda",
     causal=True,
+    is_chunked_q=False,
 ):
     """
     Run both implementations and return (output_old, lse_old, output_new, lse_new)
@@ -116,9 +117,19 @@ def _run_attention(

     num_blocks = kv_indptr[-1].item()

-    q = torch.rand(
-        q_indptr[-1].item(), num_qo_heads, head_dim, dtype=test_dtype, device=dev
-    )
+    if is_chunked_q:
+        q_base = torch.rand(
+            q_indptr[-1].item(),
+            num_qo_heads,
+            head_dim * 2,
+            dtype=test_dtype,
+            device=dev,
+        )
+        q = torch.chunk(q_base, 2, dim=-1)[0]
+    else:
+        q = torch.rand(
+            q_indptr[-1].item(), num_qo_heads, head_dim, dtype=test_dtype, device=dev
+        )
     if layout == "NHD":
         kv_data = torch.randn(
             num_blocks,
@@ -190,6 +201,45 @@ def _run_attention(


 # ------------------------- PyTest test case ----------------------------- #
+@pytest.mark.xfail(
+    get_compute_capability(torch.device(device="cuda"))[0] == 12,
+    reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
+)
+def test_batch_attention_with_noncontiguous_q():
+    # Use the sequence-length pairs from the first config
+    seq_len_pairs = _build_seq_len_configs()[0]
+    kv_lens = [p[0] for p in seq_len_pairs]
+    qo_lens = [p[1] for p in seq_len_pairs]
+
+    # Fixed single-case parameters
+    page_block_size = 1
+    num_kv_heads = 1
+    gqa_group_size = 1
+    num_qo_heads = num_kv_heads * gqa_group_size
+    head_dim = 64
+    test_dtype = torch.bfloat16
+    layout = "NHD"
+    logits_soft_cap = 0.0
+    v_scale = None
+    causal = True
+
+    _run_attention(
+        kv_lens=kv_lens,
+        qo_lens=qo_lens,
+        page_block_size=page_block_size,
+        num_kv_heads=num_kv_heads,
+        num_qo_heads=num_qo_heads,
+        head_dim=head_dim,
+        v_scale=v_scale,
+        causal=causal,
+        layout=layout,
+        test_dtype=test_dtype,
+        logits_soft_cap=logits_soft_cap,
+        device="cuda",
+        is_chunked_q=True,
+    )
+
+
 @pytest.mark.xfail(
     get_compute_capability(torch.device(device="cuda"))[0] == 12,
     reason="Expected failure for SM120/121 for now since the tile size/number of stages is too large.",
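Assuming a CUDA-capable device, the new regression test should be runnable on its own with `pytest tests/attention/test_batch_attention.py::test_batch_attention_with_noncontiguous_q`.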
