File tree Expand file tree Collapse file tree 1 file changed +13
-2
lines changed Expand file tree Collapse file tree 1 file changed +13
-2
lines changed Original file line number Diff line number Diff line change @@ -814,6 +814,15 @@ def testBatchPrefillWithPagedKVCacheWrapper(args):
814
814
.int ()
815
815
.to (device )
816
816
)
817
+ qo_indptr_cudnn = torch .cat (
818
+ [
819
+ torch .tensor ([0 ], device = device ),
820
+ torch .cumsum (actual_seq_lens_q_device .view (- 1 ), dim = 0 )
821
+ * head_dim_qk
822
+ * num_qo_heads ,
823
+ ]
824
+ ).int ()
825
+
817
826
# Because actual_seq_lens_kv is the same as actual_seq_lens_q, kv_indptr will become the same as qo_indptr
818
827
kv_indptr = (
819
828
torch .cat (
@@ -935,12 +944,14 @@ def run_backend_wrapper(backend):
935
944
workspace_buffer ,
936
945
max_token_per_sequence = s_qo ,
937
946
max_sequence_kv = s_kv ,
938
- actual_seq_lens_q = actual_seq_lens_q ,
939
- actual_seq_lens_kv = actual_seq_lens_kv ,
947
+ actual_seq_lens_q = actual_seq_lens_q_device ,
948
+ actual_seq_lens_kv = actual_seq_lens_kv_device ,
940
949
block_tables = block_tables ,
941
950
causal = causal ,
942
951
return_lse = True ,
943
952
is_cuda_graph_compatible = is_cuda_graph_compatible ,
953
+ batch_offsets_q = qo_indptr_cudnn ,
954
+ batch_offsets_o = qo_indptr_cudnn ,
944
955
)[0 ]
945
956
elif backend == "fa2" :
946
957
return fi_fa2_paged_wrapper .run (
You can’t perform that action at this time.
0 commit comments