Commit 499b1dc

issue/867 pass total kv lens as paged attn args
1 parent 0a2839a commit 499b1dc

14 files changed, +136 −122 lines

include/infinicore/ops/paged_attention_prefill.hpp

Lines changed: 6 additions & 6 deletions
@@ -16,15 +16,15 @@ class PagedAttentionPrefill {
 * 3. k_cache: Physical Key cache (Paged format)
 * 4. v_cache: Physical Value cache (Paged format)
 * 5. block_tables: Mapping table from logical blocks to physical blocks
-* 6. history_lens: Historical KV lengths (existing length of each sequence in cache)
+* 6. total_kv_lens: lengths of Complete Key/Value for each request
 * 7. cu_seqlens_q: Cumulative sequence lengths of Query (prefix sum for variable-length batch)
 * 8. alibi_slopes: ALiBi bias slopes (optional)
 * 9. scale: Scaling factor (typically 1/sqrt(head_size))
 */
 using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);

 static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
-Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+Tensor block_tables, Tensor total_kv_lens, Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes, float scale);

 static common::OpDispatcher<schema> &dispatcher();
@@ -34,8 +34,8 @@ Tensor paged_attention_prefill(Tensor q,
 Tensor k_cache,
 Tensor v_cache,
 Tensor block_tables,
-Tensor history_lens,
-Tensor cu_seqlens_q,
+Tensor total_kv_lens,
+Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes,
 float scale);

@@ -44,8 +44,8 @@ void paged_attention_prefill_(Tensor out,
 Tensor k_cache,
 Tensor v_cache,
 Tensor block_tables,
-Tensor history_lens,
-Tensor cu_seqlens_q,
+Tensor total_kv_lens,
+Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes,
 float scale);

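For orientation, a minimal call sketch against the renamed entry point follows. It is illustrative only: the wrapper function run_prefill, the assumption that Tensor lives in the infinicore namespace, and the layout comments are not part of this commit; only the argument order and the new names total_kv_lens / cum_seqlens_q follow the header above.

// Hypothetical sketch, assuming the relevant infinicore headers are included
// and the input tensors were already built with the documented layouts.
#include <cmath>
#include <cstddef>
#include <optional>

infinicore::Tensor run_prefill(infinicore::Tensor q,             // query tokens for the whole batch
                               infinicore::Tensor k_cache,       // paged key cache
                               infinicore::Tensor v_cache,       // paged value cache
                               infinicore::Tensor block_tables,  // logical-to-physical block mapping, int64
                               infinicore::Tensor total_kv_lens, // total KV length per request, int64
                               infinicore::Tensor cum_seqlens_q, // prefix sum of query lengths, int64
                               size_t head_size) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_size)); // 1/sqrt(head_size), as documented
    return infinicore::op::paged_attention_prefill(q, k_cache, v_cache, block_tables,
                                                   total_kv_lens, cum_seqlens_q,
                                                   /*alibi_slopes=*/std::nullopt, scale);
}
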
include/infiniop/ops/paged_attention_prefill.h

Lines changed: 4 additions & 4 deletions
@@ -20,7 +20,7 @@ typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;
 * Shape: [max_num_blocks, num_kv_heads, block_size, head_size]
 * @param block_tables_desc Descriptor for the block tables mapping logic to physical blocks.
 * Shape: [batch_size, max_blocks_per_seq]
-* @param history_lens_desc Descriptor for the KV history lengths of each sequence.
+* @param seq_lens_desc Descriptor for the total KV lengths of each sequence.
 * Shape: [batch_size]
 * @param cum_seq_lens_q_desc Descriptor for the cumulative start position (prefix sum) of each Q sequence.
 * Shape: [batch_size + 1]
@@ -37,7 +37,7 @@ __C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
 infiniopTensorDescriptor_t k_cache_desc,
 infiniopTensorDescriptor_t v_cache_desc,
 infiniopTensorDescriptor_t block_tables_desc,
-infiniopTensorDescriptor_t history_lens_desc,
+infiniopTensorDescriptor_t seq_lens_desc,
 infiniopTensorDescriptor_t cum_seq_lens_q_desc,
 infiniopTensorDescriptor_t alibi_slopes_desc,
 float scale);
@@ -58,7 +58,7 @@ __C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
 * @param k_cache Pointer to the global key cache data.
 * @param v_cache Pointer to the global value cache data.
 * @param block_tables Pointer to the block tables data.
-* @param history_lens Pointer to the KV history lengths data.
+* @param seq_lens Pointer to the KV lengths data.
 * @param cum_seq_lens_q Pointer to the Q cumulative sequence lengths data (prefix sum).
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The device stream (e.g., cudaStream_t) for the operation.
@@ -73,7 +73,7 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill(
 const void *k_cache,
 const void *v_cache,
 const void *block_tables,
-const void *history_lens,
+const void *seq_lens,
 const void *cum_seq_lens_q,
 const void *alibi_slopes,
 void *stream);

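The renamed seq_lens tensor and cum_seq_lens_q are related through their shapes, [batch_size] versus [batch_size + 1]. A hypothetical host-side layout for a two-request batch is sketched below; the numbers are invented for illustration and are not part of this commit.

// Hypothetical values: request 0 has 4 cached tokens plus 3 new query tokens,
// request 1 has 0 cached tokens plus 5 new query tokens.
int64_t seq_lens[2]       = {7, 5};    // total KV length per request, shape [batch_size]
int64_t cum_seq_lens_q[3] = {0, 3, 8}; // prefix sum of query lengths, shape [batch_size + 1]
// With block_size = 4, block_tables would need ceil(7/4) = 2 physical blocks for
// request 0 and ceil(5/4) = 2 for request 1, within [batch_size, max_blocks_per_seq].
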
src/infinicore/ops/paged_attention/paged_attention.cc

Lines changed: 7 additions & 7 deletions
@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
 return dispatcher_;
 };

-void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
+void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens);
 infinicore::context::setDevice(out->device());
-dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 }

-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
 auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+paged_attention_(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 return out;
 }

-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+PagedAttention::execute(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);
 }

 } // namespace infinicore::op

src/infinicore/ops/paged_attention/paged_attention_infiniop.cc

Lines changed: 4 additions & 4 deletions
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
 }
 });

-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor kv_lens, std::optional<Tensor> alibi_slopes, float scale) {
+size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, alibi_slopes, scale);

 auto device = context::getDevice();
 auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
 if (!desc_opt) {
 INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
 context::getInfiniopHandle(device), &desc,
-out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
+out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), kv_lens->desc(),
 alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
 scale));
 cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc

 INFINICORE_CHECK_ERROR(infiniopPagedAttention(
 desc, workspace->data(), workspace_size,
-out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
+out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), kv_lens->data(),
 alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
 context::getStream()));
 }

src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc

Lines changed: 7 additions & 8 deletions
@@ -10,31 +10,30 @@ common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::disp
 };

 void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
-Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes, float scale) {
-
-INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q);
+INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q);

 infinicore::context::setDevice(out->device());

 dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables,
-history_lens, cu_seqlens_q, alibi_slopes, scale);
+kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }

 Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache,
-Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes, float scale) {

 auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
+paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
 return out;
 }

 void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
-Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes, float scale) {

-PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
+PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);
 }

 } // namespace infinicore::op

src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc

Lines changed: 6 additions & 7 deletions
@@ -16,10 +16,9 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t>
 });

 void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
-Tensor block_tables, Tensor history_lens, Tensor cu_seqlens_q,
+Tensor block_tables, Tensor kv_lens, Tensor cum_seqlens_q,
 std::optional<Tensor> alibi_slopes, float scale) {
-
-size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes, scale);
+size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, kv_lens, cum_seqlens_q, alibi_slopes, scale);

 auto device = context::getDevice();
 auto &cache = caches.getCache(device);
@@ -35,8 +34,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
 k_cache->desc(),
 v_cache->desc(),
 block_tables->desc(),
-history_lens->desc(),
-cu_seqlens_q->desc(),
+kv_lens->desc(),
+cum_seqlens_q->desc(),
 alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
 scale));
 cache.put(seed, desc);
@@ -57,8 +56,8 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache,
 k_cache->data(),
 v_cache->data(),
 block_tables->data(),
-history_lens->data(),
-cu_seqlens_q->data(),
+kv_lens->data(),
+cum_seqlens_q->data(),
 alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
 context::getStream()));
 }

src/infinicore/pybind11/ops/paged_attention_prefill.hpp

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@ Tensor py_paged_attention_prefill(Tensor q,
 if (!alibi_slopes.is_none()) {
 alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
 }
-return op::paged_attention_prefill(q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
+return op::paged_attention_prefill(
+q, k_cache, v_cache, block_tables, history_lens, cu_seqlens_q, alibi_slopes_tensor, scale);
 }

 void py_paged_attention_prefill_(Tensor out,

src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh

Lines changed: 19 additions & 16 deletions
@@ -22,12 +22,13 @@ template <typename Tdata, typename Tcompute>
 __global__ void pagedAttentionPrefillKernel(
 Tdata *out_, const Tdata *q_, const Tdata *k_cache_, const Tdata *v_cache_,
 const int64_t *block_tables_,
-const int64_t *history_lens_,
+const int64_t *total_kv_lens_,
 const int64_t *cum_seq_lens_q_,
 const float *alibi_slopes_,
 const size_t num_heads, const size_t num_kv_heads, const float scale,
 const size_t max_num_blocks_per_seq, const size_t block_size,
 const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride,
+const ptrdiff_t q_stride, const ptrdiff_t q_head_stride,
 const size_t head_size,
 const size_t num_seqs) {

@@ -44,10 +45,12 @@ __global__ void pagedAttentionPrefillKernel(

 size_t q_token_idx = global_token_idx - cum_seq_lens_q_[seq_idx];

-const int64_t history_len = history_lens_[seq_idx];
-const int64_t causal_limit = history_len + q_token_idx;
+const size_t total_kv_len = total_kv_lens_[seq_idx];
+const size_t q_len = cum_seq_lens_q_[seq_idx + 1] - cum_seq_lens_q_[seq_idx];
+const size_t history_len = total_kv_len - q_len;
+const size_t causal_limit = history_len + q_token_idx;

-const Tdata *q_vec = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
+const Tdata *q_vec = q_ + global_token_idx * q_stride + head_idx * q_head_stride;
 Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;

 const size_t num_queries_per_kv = num_heads / num_kv_heads;
@@ -57,10 +60,10 @@ __global__ void pagedAttentionPrefillKernel(
 const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx];

 Tcompute max_score = -FLT_MAX;
-for (int64_t t = 0; t <= causal_limit; ++t) {
-const int64_t b_idx = t / block_size;
-const int64_t t_off = t % block_size;
-const int64_t physical_block_id = block_table[b_idx];
+for (size_t t = 0; t <= causal_limit; ++t) {
+const size_t b_idx = t / block_size;
+const size_t t_off = t % block_size;
+const ptrdiff_t physical_block_id = block_table[b_idx];
 const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;

 Tcompute score = 0.0f;
@@ -77,10 +80,10 @@ __global__ void pagedAttentionPrefillKernel(
 }

 Tcompute sum_exp = 0.0f;
-for (int64_t t = 0; t <= causal_limit; ++t) {
-const int64_t b_idx = t / block_size;
-const int64_t t_off = t % block_size;
-const int64_t physical_block_id = block_table[b_idx];
+for (size_t t = 0; t <= causal_limit; ++t) {
+const size_t b_idx = t / block_size;
+const size_t t_off = t % block_size;
+const ptrdiff_t physical_block_id = block_table[b_idx];
 const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;

 Tcompute score = 0.0f;
@@ -96,10 +99,10 @@ __global__ void pagedAttentionPrefillKernel(

 Tcompute acc = 0.0f;
 Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
-for (int64_t t = 0; t <= causal_limit; ++t) {
-const int64_t b_idx = t / block_size;
-const int64_t t_off = t % block_size;
-const int64_t physical_block_id = block_table[b_idx];
+for (size_t t = 0; t <= causal_limit; ++t) {
+const size_t b_idx = t / block_size;
+const size_t t_off = t % block_size;
+const ptrdiff_t physical_block_id = block_table[b_idx];

 const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
 Tcompute score = 0.0f;

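The behavioural change is in the second hunk: the kernel no longer receives the history length directly, it derives it from the total KV length and the query prefix sums. A worked example of that arithmetic, with made-up numbers:

// Illustrative numbers only. Suppose total_kv_lens_[seq_idx] = 7 and
// cum_seq_lens_q_ = {0, 3, 8}, so sequence 0 contributes 3 new query tokens.
//   q_len        = cum_seq_lens_q_[1] - cum_seq_lens_q_[0]  = 3
//   history_len  = total_kv_len - q_len                     = 7 - 3 = 4
//   causal_limit = history_len + q_token_idx                = 4, 5, 6 for q_token_idx = 0, 1, 2
// Each new token attends to all cached tokens plus the new tokens up to and
// including itself (the loops run t = 0 .. causal_limit), which matches the
// behaviour previously obtained by passing history_lens_ directly.
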
src/infiniop/ops/paged_attention_prefill/info.h

Lines changed: 10 additions & 7 deletions
@@ -25,6 +25,7 @@ class PagedAttentionPrefillInfo {
 size_t total_q_tokens;

 ptrdiff_t q_stride;
+ptrdiff_t q_head_stride;
 ptrdiff_t kv_block_stride;
 ptrdiff_t kv_head_stride;
 ptrdiff_t o_stride;
@@ -35,7 +36,7 @@ class PagedAttentionPrefillInfo {
 infiniopTensorDescriptor_t k_cache_desc,
 infiniopTensorDescriptor_t v_cache_desc,
 infiniopTensorDescriptor_t block_tables_desc,
-infiniopTensorDescriptor_t history_lens_desc,
+infiniopTensorDescriptor_t seq_lens_desc,
 infiniopTensorDescriptor_t cum_seq_lens_q_desc,
 const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
 float scale) {
@@ -47,7 +48,7 @@ class PagedAttentionPrefillInfo {
 return INFINI_STATUS_BAD_TENSOR_DTYPE;
 }

-if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || history_lens_desc->dtype() != INFINI_DTYPE_I64) {
+if (cum_seq_lens_q_desc->dtype() != INFINI_DTYPE_I64 || seq_lens_desc->dtype() != INFINI_DTYPE_I64) {
 return INFINI_STATUS_BAD_TENSOR_DTYPE;
 }

@@ -57,7 +58,7 @@ class PagedAttentionPrefillInfo {
 auto k_shape = k_cache_desc->shape();
 auto v_shape = v_cache_desc->shape();
 auto block_tables_shape = block_tables_desc->shape();
-auto history_lens_shape = history_lens_desc->shape();
+auto seq_lens_shape = seq_lens_desc->shape();
 auto cum_seq_lens_q_shape = cum_seq_lens_q_desc->shape();

 if (k_shape.size() != 4 || v_shape.size() != 4) {
@@ -68,10 +69,11 @@ class PagedAttentionPrefillInfo {
 return INFINI_STATUS_BAD_TENSOR_SHAPE;
 }

-if (history_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
+if (seq_lens_shape.size() != 1 || cum_seq_lens_q_shape.size() != 1) {
 return INFINI_STATUS_BAD_TENSOR_SHAPE;
 }
-if (cum_seq_lens_q_shape[0] != history_lens_shape[0] + 1) {
+
+if (cum_seq_lens_q_shape[0] != seq_lens_shape[0] + 1) {
 return INFINI_STATUS_BAD_PARAM;
 }

@@ -88,13 +90,13 @@ class PagedAttentionPrefillInfo {
 return INFINI_STATUS_BAD_PARAM;
 }

-size_t num_seqs = history_lens_shape[0];
-
+size_t num_seqs = seq_lens_shape[0];
 size_t num_kv_heads = k_shape[1];
 size_t block_size = k_shape[2];
 size_t max_num_blocks_per_seq = block_tables_shape[1];

 ptrdiff_t q_stride = q_desc->stride(0);
+ptrdiff_t q_head_stride = q_desc->stride(1);
 ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
 ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
 ptrdiff_t o_stride = out_desc->stride(0);
@@ -110,6 +112,7 @@ class PagedAttentionPrefillInfo {
 max_num_blocks_per_seq,
 total_q_tokens,
 q_stride,
+q_head_stride,
 kv_block_stride,
 kv_head_stride,
 o_stride});
