Skip to content

Commit 7b3cb4c

Browse files
committed
Fixes to the benchmark attention wrapper
1 parent 49d744d commit 7b3cb4c

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

benchmarks/routines/attention.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -814,6 +814,15 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
814814
.int()
815815
.to(device)
816816
)
817+
qo_indptr_cudnn = torch.cat(
818+
[
819+
torch.tensor([0], device=device),
820+
torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
821+
* head_dim_qk
822+
* num_qo_heads,
823+
]
824+
).int()
825+
817826
# Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
818827
kv_indptr = (
819828
torch.cat(
@@ -935,12 +944,14 @@ def run_backend_wrapper(backend):
935944
workspace_buffer,
936945
max_token_per_sequence=s_qo,
937946
max_sequence_kv=s_kv,
938-
actual_seq_lens_q=actual_seq_lens_q,
939-
actual_seq_lens_kv=actual_seq_lens_kv,
947+
actual_seq_lens_q=actual_seq_lens_q_device,
948+
actual_seq_lens_kv=actual_seq_lens_kv_device,
940949
block_tables=block_tables,
941950
causal=causal,
942951
return_lse=True,
943952
is_cuda_graph_compatible=is_cuda_graph_compatible,
953+
batch_offsets_q=qo_indptr_cudnn,
954+
batch_offsets_o=qo_indptr_cudnn,
944955
)[0]
945956
elif backend == "fa2":
946957
return fi_fa2_paged_wrapper.run(

0 commit comments

Comments (0)