@@ -9,20 +9,31 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
     return dispatcher_;
 };
 
-void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
+void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
+                                    Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
+                                    std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q);
+
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
+                                                 kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }
 
-Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
+                               Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
+                               std::optional<Tensor> alibi_slopes, float scale) {
+
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
     return out;
 }
 
-void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
+                              Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
+                              std::optional<Tensor> alibi_slopes, float scale) {
+
+    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }
 
 } // namespace infinicore::op
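
Caller-side note (not part of the commit): the diff replaces the old cache_lens/seq_lens/seq_offsets arguments with kv_lens and cum_seqlens_q. The sketch below shows how a call site might look against the new signature; only paged_attention_prefill and its parameter list come from this diff, while the run_prefill helper name and the shape comments are assumptions about the layout, not something this commit specifies.

#include <optional>

using namespace infinicore;

// Hypothetical wrapper: forwards already-constructed tensors to the updated operator.
Tensor run_prefill(Tensor q,             // assumed [total_q_tokens, num_q_heads, head_dim]
                   Tensor k_cache,       // assumed [num_blocks, block_size, num_kv_heads, head_dim]
                   Tensor v_cache,       // assumed same layout as k_cache
                   Tensor block_tables,  // assumed [num_seqs, max_blocks_per_seq]
                   Tensor kv_lens,       // assumed [num_seqs]: KV-cache length per sequence
                   Tensor cum_seqlens_q, // assumed [num_seqs + 1]: cumulative query lengths
                   float scale) {
    // No ALiBi bias in this sketch, so alibi_slopes is std::nullopt.
    return op::paged_attention_prefill(q, k_cache, v_cache, block_tables,
                                       kv_lens, cum_seqlens_q,
                                       std::nullopt, scale);
}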