Commit 0c850d3

Fixes to the benchmark attention wrapper
1 parent cf1b2d2 commit 0c850d3

File tree

1 file changed: +9 −0 lines changed


benchmarks/routines/attention.py

Lines changed: 9 additions & 0 deletions
@@ -799,6 +799,15 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
         .int()
         .to(device)
     )
+    qo_indptr_cudnn = torch.cat(
+        [
+            torch.tensor([0], device=device),
+            torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
+            * head_dim_qk
+            * num_qo_heads,
+        ]
+    ).int()
+
     # Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
     kv_indptr = (
         torch.cat(
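The added qo_indptr_cudnn builds ragged start offsets for the cuDNN attention path: a prefix sum of the per-request query lengths, scaled by num_qo_heads * head_dim_qk so the offsets count elements rather than tokens, with a leading zero prepended. Below is a minimal, self-contained sketch of the same computation; the sequence lengths and head sizes are hypothetical example values, whereas in the benchmark they come from the surrounding setup in attention.py.

    # Standalone sketch of the qo_indptr_cudnn computation added in this commit.
    # The concrete values below are made up for illustration only.
    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical per-request query lengths and head configuration.
    actual_seq_lens_q_device = torch.tensor([3, 5, 2], device=device)
    num_qo_heads = 8
    head_dim_qk = 128

    # Prefix sum of query lengths, scaled to element offsets
    # (tokens * num_qo_heads * head_dim_qk), with a leading 0 so that
    # entry i is the start offset of request i and entry i+1 is its end.
    qo_indptr_cudnn = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q_device.view(-1), dim=0)
            * head_dim_qk
            * num_qo_heads,
        ]
    ).int()

    print(qo_indptr_cudnn)  # tensor([0, 3072, 8192, 10240], ...) for the sizes above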
