Skip to content

Commit 96551cb

Browse files
issue/867: fix paged caching API; paged attention supports more head dims
1 parent 01a4a0c commit 96551cb

File tree

19 files changed

+171
-149
lines changed

19 files changed

+171
-149
lines changed

include/infinicore/ops/paged_caching.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,10 @@ namespace infinicore::op {
88
class PagedCaching {
99
public:
1010
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
11-
static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
11+
static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
1212
static common::OpDispatcher<schema> &dispatcher();
1313
};
1414

15-
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
15+
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
1616

1717
} // namespace infinicore::op

include/infiniop/ops/paged_caching.h

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
1414
*
1515
* @param handle The handle to the InfiniOP library context.
1616
* @param desc_ptr A pointer to store the created descriptor.
17-
* @param k_desc Descriptor for the source key tensor.
18-
* @param v_desc Descriptor for the source value tensor.
1917
* @param k_cache_desc Descriptor for the key cache pool tensor.
2018
* @param v_cache_desc Descriptor for the value cache pool tensor.
19+
* @param k_desc Descriptor for the source key tensor.
20+
* @param v_desc Descriptor for the source value tensor.
2121
* @param slot_mapping_desc Descriptor for the slot mapping tensor.
2222
* @return infiniStatus_t Status code of the operation.
2323
*/
2424
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
2525
infiniopHandle_t handle,
2626
infiniopPagedCachingDescriptor_t *desc_ptr,
27-
infiniopTensorDescriptor_t k_desc,
28-
infiniopTensorDescriptor_t v_desc,
2927
infiniopTensorDescriptor_t k_cache_desc,
3028
infiniopTensorDescriptor_t v_cache_desc,
29+
infiniopTensorDescriptor_t k_desc,
30+
infiniopTensorDescriptor_t v_desc,
3131
infiniopTensorDescriptor_t slot_mapping_desc);
3232

3333
/**
@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
4646
* @param desc The Paged Caching descriptor.
4747
* @param workspace Pointer to the workspace memory.
4848
* @param workspace_size The size of the workspace.
49-
* @param k Pointer to the source key tensor data.
50-
* @param v Pointer to the source value tensor data.
5149
* @param k_cache Pointer to the key cache pool data.
5250
* @param v_cache Pointer to the value cache pool data.
51+
* @param k Pointer to the source key tensor data.
52+
* @param v Pointer to the source value tensor data.
5353
* @param slot_mapping Pointer to the slot mapping data.
5454
* @param stream The CUDA stream for the operation. Can be NULL.
5555
* @return infiniStatus_t Status code of the operation.
@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
5858
infiniopPagedCachingDescriptor_t desc,
5959
void *workspace,
6060
size_t workspace_size,
61-
const void *k,
62-
const void *v,
6361
void *k_cache,
6462
void *v_cache,
63+
const void *k,
64+
const void *v,
6565
const void *slot_mapping,
6666
void *stream);
6767

python/infinicore/ops/paged_caching.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,18 +3,18 @@
33

44

55
def paged_caching(
6-
k: Tensor,
7-
v: Tensor,
86
k_cache: Tensor,
97
v_cache: Tensor,
8+
k: Tensor,
9+
v: Tensor,
1010
slot_mapping: Tensor,
1111
):
1212
Tensor(
1313
_infinicore.paged_caching_(
14-
k._underlying,
15-
v._underlying,
1614
k_cache._underlying,
1715
v_cache._underlying,
16+
k._underlying,
17+
v._underlying,
1818
slot_mapping._underlying,
1919
)
2020
)

src/infinicore/ops/paged_caching/paged_caching.cc

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
99
return dispatcher_;
1010
};
1111

12-
void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
13-
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping);
14-
infinicore::context::setDevice(k->device());
15-
dispatcher().lookup(k->device().getType())(k, v, k_cache, v_cache, slot_mapping);
12+
void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
13+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping);
14+
infinicore::context::setDevice(k_cache->device());
15+
dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping);
1616
}
1717

18-
void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
19-
PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping);
18+
void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
19+
PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping);
2020
}
2121

2222
} // namespace infinicore::op

src/infinicore/ops/paged_caching/paged_caching_infiniop.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
1515
}
1616
});
1717

18-
void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
19-
size_t seed = hash_combine(k, v, k_cache, v_cache, slot_mapping);
18+
void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
19+
size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping);
2020

2121
auto device = context::getDevice();
2222
auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
2727
if (!desc_opt) {
2828
INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
2929
context::getInfiniopHandle(device), &desc,
30-
k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc()));
30+
k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc()));
3131
cache.put(seed, desc);
3232
} else {
3333
desc = *desc_opt;
@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_m
3939

4040
INFINICORE_CHECK_ERROR(infiniopPagedCaching(
4141
desc, workspace->data(), workspace_size,
42-
k->data(), v->data(), k_cache->data(), v_cache->data(), slot_mapping->data(), context::getStream()));
42+
k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream()));
4343
}
4444

4545
static bool registered = []() {

src/infinicore/pybind11/ops/paged_caching.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,10 +11,10 @@ namespace infinicore::ops {
1111
inline void bind_paged_caching(py::module &m) {
1212
m.def("paged_caching_",
1313
&op::paged_caching_,
14-
py::arg("k"),
15-
py::arg("v"),
1614
py::arg("k_cache"),
1715
py::arg("v_cache"),
16+
py::arg("k"),
17+
py::arg("v"),
1818
py::arg("slot_mapping"),
1919
R"doc(Paged caching of key and value tensors.)doc");
2020
}

src/infiniop/ops/paged_attention/info.h

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -67,11 +67,9 @@ class PagedAttentionInfo {
6767
size_t num_heads = q_shape[1];
6868
size_t head_size = q_shape[2];
6969

70-
if (head_size != 128) {
71-
// 输出具体的错误原因和当前的参数值
72-
std::cerr << "[Error] Now only supports head_size = 128, but got "
70+
if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) {
71+
std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got "
7372
<< head_size << "." << std::endl;
74-
// 建议返回 SHAPE 相关的错误码
7573
return INFINI_STATUS_BAD_TENSOR_SHAPE;
7674
}
7775

src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu

Lines changed: 36 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
9898
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
9999
void *stream_) const {
100100
cudaStream_t stream = (cudaStream_t)stream_;
101+
102+
#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
103+
launchKernel<__H_SIZE, __B_SIZE>( \
104+
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
105+
_info.num_heads, _info.num_seqs, \
106+
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
107+
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
108+
stream);
109+
110+
#define SWITCH_HEAD_SIZE(__B_SIZE) \
111+
switch (_info.head_size) { \
112+
case 16: \
113+
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
114+
break; \
115+
case 32: \
116+
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
117+
break; \
118+
case 64: \
119+
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
120+
break; \
121+
case 128: \
122+
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
123+
break; \
124+
case 256: \
125+
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
126+
break; \
127+
default: \
128+
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
129+
}
130+
101131
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
102-
if (_info.head_size == 128) {
103-
launchKernel<128, CUDA_BLOCK_SIZE_1024>(
104-
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
105-
_info.num_heads, _info.num_seqs,
106-
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
107-
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
108-
stream);
109-
}
132+
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
110133
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
111-
if (_info.head_size == 128) {
112-
launchKernel<128, CUDA_BLOCK_SIZE_512>(
113-
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
114-
_info.num_heads, _info.num_seqs,
115-
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
116-
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
117-
stream);
118-
}
134+
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
119135
} else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
120-
if (_info.head_size == 128) {
121-
launchKernel<128, CUDA_BLOCK_SIZE_4096>(
122-
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
123-
_info.num_heads, _info.num_seqs,
124-
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
125-
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
126-
stream);
127-
}
136+
SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
128137
} else {
129138
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
130139
}
131140

141+
#undef LAUNCH_HEADSIZE_BLOCKSIZE
142+
#undef SWITCH_HEAD_SIZE
143+
132144
return INFINI_STATUS_SUCCESS;
133145
}
134146

src/infiniop/ops/paged_attention/operator.cc

Lines changed: 23 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
#ifdef ENABLE_NVIDIA_API
66
#include "nvidia/paged_attention_nvidia.cuh"
77
#endif
8-
#ifdef ENABLE_METAX_API
9-
#include "metax/paged_attention_metax.h"
10-
#endif
8+
// #ifdef ENABLE_METAX_API
9+
// #include "metax/paged_attention_metax.h"
10+
// #endif
1111

1212
__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
1313
infiniopHandle_t handle,
@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
3434
#ifdef ENABLE_NVIDIA_API
3535
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
3636
#endif
37-
#ifdef ENABLE_METAX_API
38-
CREATE(INFINI_DEVICE_METAX, metax)
39-
#endif
37+
// #ifdef ENABLE_METAX_API
38+
// CREATE(INFINI_DEVICE_METAX, metax)
39+
// #endif
40+
default:
41+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
4042
}
41-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
4243
}
4344

4445
__C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
5455
#ifdef ENABLE_NVIDIA_API
5556
GET(INFINI_DEVICE_NVIDIA, nvidia)
5657
#endif
57-
#ifdef ENABLE_METAX_API
58-
GET(INFINI_DEVICE_METAX, metax)
59-
#endif
58+
// #ifdef ENABLE_METAX_API
59+
// GET(INFINI_DEVICE_METAX, metax)
60+
// #endif
61+
default:
62+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
6063
}
61-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
6264
}
6365

6466
__C infiniStatus_t infiniopPagedAttention(
@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention(
7880
#ifdef ENABLE_NVIDIA_API
7981
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
8082
#endif
81-
#ifdef ENABLE_METAX_API
82-
CALCULATE(INFINI_DEVICE_METAX, metax)
83-
#endif
83+
// #ifdef ENABLE_METAX_API
84+
// CALCULATE(INFINI_DEVICE_METAX, metax)
85+
// #endif
86+
default:
87+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
8488
}
85-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
8689
}
8790

8891
__C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
97100
#ifdef ENABLE_NVIDIA_API
98101
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
99102
#endif
100-
#ifdef ENABLE_METAX_API
101-
DESTROY(INFINI_DEVICE_METAX, metax)
102-
#endif
103+
// #ifdef ENABLE_METAX_API
104+
// DESTROY(INFINI_DEVICE_METAX, metax)
105+
// #endif
106+
default:
107+
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
103108
}
104-
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
105109
}

src/infiniop/ops/paged_caching/info.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,10 @@ class PagedCachingInfo {
2828
ptrdiff_t v_cache_block_stride;
2929

3030
static utils::Result<PagedCachingInfo> create(
31-
infiniopTensorDescriptor_t k_desc,
32-
infiniopTensorDescriptor_t v_desc,
3331
infiniopTensorDescriptor_t k_cache_desc,
3432
infiniopTensorDescriptor_t v_cache_desc,
33+
infiniopTensorDescriptor_t k_desc,
34+
infiniopTensorDescriptor_t v_desc,
3535
infiniopTensorDescriptor_t slot_mapping_desc) {
3636

3737
auto dtype = k_desc->dtype();

0 commit comments

Comments (0)