4 changes: 2 additions & 2 deletions include/infinicore/ops/paged_caching.hpp
@@ -8,10 +8,10 @@ namespace infinicore::op {
 class PagedCaching {
 public:
     using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor);
-    static void execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
+    static void execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
     static common::OpDispatcher<schema> &dispatcher();
 };
 
-void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping);
+void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping);
 
 } // namespace infinicore::op
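
Note: the user-visible effect of this reordering is that the mutable cache pools now lead the argument list, matching the infiniop C API below. A minimal C++ usage sketch (tensor construction omitted; cache_kv is an illustrative helper, not part of this PR):

#include "infinicore/ops/paged_caching.hpp"

using infinicore::Tensor;

// Illustrative helper: appends new K/V entries into the paged cache pools.
// Destination caches come first, then the source tensors.
void cache_kv(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
    infinicore::op::paged_caching_(k_cache, v_cache, k, v, slot_mapping);
}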
16 changes: 8 additions & 8 deletions include/infiniop/ops/paged_caching.h
@@ -14,20 +14,20 @@ typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;
  *
  * @param handle The handle to the InfiniOP library context.
  * @param desc_ptr A pointer to store the created descriptor.
- * @param k_desc Descriptor for the source key tensor.
- * @param v_desc Descriptor for the source value tensor.
  * @param k_cache_desc Descriptor for the key cache pool tensor.
  * @param v_cache_desc Descriptor for the value cache pool tensor.
+ * @param k_desc Descriptor for the source key tensor.
+ * @param v_desc Descriptor for the source value tensor.
  * @param slot_mapping_desc Descriptor for the slot mapping tensor.
  * @return infiniStatus_t Status code of the operation.
  */
 __C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
     infiniopHandle_t handle,
     infiniopPagedCachingDescriptor_t *desc_ptr,
-    infiniopTensorDescriptor_t k_desc,
-    infiniopTensorDescriptor_t v_desc,
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t k_desc,
+    infiniopTensorDescriptor_t v_desc,
     infiniopTensorDescriptor_t slot_mapping_desc);
 
 /**
@@ -46,10 +46,10 @@ __C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
  * @param desc The Paged Caching descriptor.
  * @param workspace Pointer to the workspace memory.
  * @param workspace_size The size of the workspace.
- * @param k Pointer to the source key tensor data.
- * @param v Pointer to the source value tensor data.
  * @param k_cache Pointer to the key cache pool data.
  * @param v_cache Pointer to the value cache pool data.
+ * @param k Pointer to the source key tensor data.
+ * @param v Pointer to the source value tensor data.
  * @param slot_mapping Pointer to the slot mapping data.
  * @param stream The CUDA stream for the operation. Can be NULL.
  * @return infiniStatus_t Status code of the operation.
@@ -58,10 +58,10 @@ __C __export infiniStatus_t infiniopPagedCaching(
     infiniopPagedCachingDescriptor_t desc,
     void *workspace,
     size_t workspace_size,
-    const void *k,
-    const void *v,
     void *k_cache,
     void *v_cache,
+    const void *k,
+    const void *v,
     const void *slot_mapping,
     void *stream);
 
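Note: a sketch of the full C-API call sequence under the new parameter order, assuming the handle, descriptors, device pointers, workspace, and stream were created elsewhere (cache_kv_c is illustrative, not part of this PR):

// Sketch only: error handling abbreviated.
infiniStatus_t cache_kv_c(infiniopHandle_t handle,
                          infiniopTensorDescriptor_t k_cache_desc,
                          infiniopTensorDescriptor_t v_cache_desc,
                          infiniopTensorDescriptor_t k_desc,
                          infiniopTensorDescriptor_t v_desc,
                          infiniopTensorDescriptor_t slot_mapping_desc,
                          void *k_cache, void *v_cache,
                          const void *k, const void *v,
                          const void *slot_mapping,
                          void *workspace, size_t workspace_size,
                          void *stream) {
    infiniopPagedCachingDescriptor_t desc;
    // Cache descriptors now precede the source K/V descriptors.
    infiniStatus_t status = infiniopCreatePagedCachingDescriptor(
        handle, &desc, k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    // Mutable cache pointers first, then the const source tensors.
    return infiniopPagedCaching(desc, workspace, workspace_size,
                                k_cache, v_cache, k, v, slot_mapping, stream);
}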
8 changes: 4 additions & 4 deletions python/infinicore/ops/paged_caching.py
@@ -3,18 +3,18 @@
 
 
 def paged_caching(
-    k: Tensor,
-    v: Tensor,
     k_cache: Tensor,
     v_cache: Tensor,
+    k: Tensor,
+    v: Tensor,
     slot_mapping: Tensor,
 ):
     Tensor(
         _infinicore.paged_caching_(
-            k._underlying,
-            v._underlying,
             k_cache._underlying,
             v_cache._underlying,
+            k._underlying,
+            v._underlying,
             slot_mapping._underlying,
         )
     )
12 changes: 6 additions & 6 deletions src/infinicore/ops/paged_caching/paged_caching.cc
@@ -9,14 +9,14 @@ common::OpDispatcher<PagedCaching::schema> &PagedCaching::dispatcher() {
     return dispatcher_;
 };
 
-void PagedCaching::execute(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k, v, k_cache, v_cache, slot_mapping);
-    infinicore::context::setDevice(k->device());
-    dispatcher().lookup(k->device().getType())(k, v, k_cache, v_cache, slot_mapping);
+void PagedCaching::execute(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(k_cache, v_cache, k, v, slot_mapping);
+    infinicore::context::setDevice(k_cache->device());
+    dispatcher().lookup(k_cache->device().getType())(k_cache, v_cache, k, v, slot_mapping);
 }
 
-void paged_caching_(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
-    PagedCaching::execute(k, v, k_cache, v_cache, slot_mapping);
+void paged_caching_(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
+    PagedCaching::execute(k_cache, v_cache, k, v, slot_mapping);
 }
 
 } // namespace infinicore::op
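
Note: execute() routes through a per-device dispatcher, which is why the reorder only needs to touch these wrappers; each backend registers its own implementation per device type. A standalone sketch of that dispatch-table pattern (stand-in types, not the real common::OpDispatcher):

#include <stdexcept>
#include <unordered_map>

enum class DeviceType { CPU, NVIDIA };

// Backends register one function per device type; the front end looks the
// function up at call time and invokes it.
template <typename Schema>
class SimpleDispatcher {
public:
    void add(DeviceType dev, Schema fn) { table_[dev] = fn; }
    Schema lookup(DeviceType dev) const {
        auto it = table_.find(dev);
        if (it == table_.end()) {
            throw std::runtime_error("no backend registered for device");
        }
        return it->second;
    }

private:
    std::unordered_map<DeviceType, Schema> table_;
};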
8 changes: 4 additions & 4 deletions src/infinicore/ops/paged_caching/paged_caching_infiniop.cc
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedCachingDescriptor_t> caches(
     }
 });
 
-void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
-    size_t seed = hash_combine(k, v, k_cache, v_cache, slot_mapping);
+void calculate(Tensor k_cache, Tensor v_cache, Tensor k, Tensor v, Tensor slot_mapping) {
+    size_t seed = hash_combine(k_cache, v_cache, k, v, slot_mapping);
 
     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreatePagedCachingDescriptor(
             context::getInfiniopHandle(device), &desc,
-            k->desc(), v->desc(), k_cache->desc(), v_cache->desc(), slot_mapping->desc()));
+            k_cache->desc(), v_cache->desc(), k->desc(), v->desc(), slot_mapping->desc()));
         cache.put(seed, desc);
     } else {
         desc = *desc_opt;
@@ -39,7 +39,7 @@ void calculate(Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, Tensor slot_mapping) {
 
     INFINICORE_CHECK_ERROR(infiniopPagedCaching(
         desc, workspace->data(), workspace_size,
-        k->data(), v->data(), k_cache->data(), v_cache->data(), slot_mapping->data(), context::getStream()));
+        k_cache->data(), v_cache->data(), k->data(), v->data(), slot_mapping->data(), context::getStream()));
 }
 
 static bool registered = []() {
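Note: calculate() memoizes descriptors in a thread-local cache keyed by a hash of the argument tensors, so the new order has to be applied consistently in hash_combine, the create call, and the final launch. A generic sketch of that memoization pattern (stand-in types, not the real common::OpCache):

#include <cstddef>
#include <unordered_map>

// Stand-in for an expensive-to-create, reusable descriptor.
struct Desc { /* shape/stride metadata ... */ };

Desc *createDesc() { return new Desc{}; }

// Reuse a descriptor when the seed hash was seen before; create it once
// otherwise. Repeated calls with identical shapes/strides hit the cache.
Desc *getOrCreate(std::unordered_map<std::size_t, Desc *> &cache, std::size_t seed) {
    auto it = cache.find(seed);
    if (it != cache.end()) {
        return it->second; // cache hit
    }
    Desc *desc = createDesc(); // cache miss
    cache.emplace(seed, desc);
    return desc;
}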
4 changes: 2 additions & 2 deletions src/infinicore/pybind11/ops/paged_caching.hpp
@@ -11,10 +11,10 @@ namespace infinicore::ops {
 inline void bind_paged_caching(py::module &m) {
     m.def("paged_caching_",
           &op::paged_caching_,
-          py::arg("k"),
-          py::arg("v"),
           py::arg("k_cache"),
           py::arg("v_cache"),
+          py::arg("k"),
+          py::arg("v"),
           py::arg("slot_mapping"),
           R"doc(Paged caching of key and value tensors.)doc");
 }
6 changes: 2 additions & 4 deletions src/infiniop/ops/paged_attention/info.h
@@ -67,11 +67,9 @@ class PagedAttentionInfo {
         size_t num_heads = q_shape[1];
         size_t head_size = q_shape[2];
 
-        if (head_size != 128) {
-            // Report the specific error cause and the offending parameter value
-            std::cerr << "[Error] Now only supports head_size = 128, but got "
+        if (head_size != 16 && head_size != 32 && head_size != 64 && head_size != 128 && head_size != 256) {
+            std::cerr << "[Error] Now only supports head_size = 16/32/64/128/256, but got "
                       << head_size << "." << std::endl;
-            // Consider returning a SHAPE-related error code
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
 
60 changes: 36 additions & 24 deletions src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu
@@ -98,37 +98,49 @@ infiniStatus_t Descriptor::calculate(
     const void *block_tables, const void *seq_lens, const void *alibi_slopes,
     void *stream_) const {
     cudaStream_t stream = (cudaStream_t)stream_;
+
+#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
+    launchKernel<__H_SIZE, __B_SIZE>( \
+        out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
+        _info.num_heads, _info.num_seqs, \
+        _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
+        _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
+        stream);
+
+#define SWITCH_HEAD_SIZE(__B_SIZE) \
+    switch (_info.head_size) { \
+    case 16: \
+        LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
+        break; \
+    case 32: \
+        LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
+        break; \
+    case 64: \
+        LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
+        break; \
+    case 128: \
+        LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
+        break; \
+    case 256: \
+        LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
+        break; \
+    default: \
+        return INFINI_STATUS_BAD_TENSOR_SHAPE; \
+    }
+
     if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
-        if (_info.head_size == 128) {
-            launchKernel<128, CUDA_BLOCK_SIZE_1024>(
-                out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
-                _info.num_heads, _info.num_seqs,
-                _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
-                _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
-                stream);
-        }
+        SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_1024)
     } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
-        if (_info.head_size == 128) {
-            launchKernel<128, CUDA_BLOCK_SIZE_512>(
-                out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
-                _info.num_heads, _info.num_seqs,
-                _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
-                _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
-                stream);
-        }
+        SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
-        if (_info.head_size == 128) {
-            launchKernel<128, CUDA_BLOCK_SIZE_4096>(
-                out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes,
-                _info.num_heads, _info.num_seqs,
-                _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size,
-                _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride,
-                stream);
-        }
+        SWITCH_HEAD_SIZE(CUDA_BLOCK_SIZE_4096)
     } else {
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }
 
+#undef LAUNCH_HEADSIZE_BLOCKSIZE
+#undef SWITCH_HEAD_SIZE
+
     return INFINI_STATUS_SUCCESS;
 }
 
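Note: the macro pair replaces the previous hard-coded head_size == 128 branch with a runtime-to-compile-time dispatch, mapping the runtime head size onto a template parameter so each supported size gets a specialized kernel. A standalone sketch of the same pattern (illustrative names; prints instead of launching a kernel):

#include <cstddef>
#include <cstdio>

// Specialized at compile time per head size; the real code launches a CUDA
// kernel templated on HEAD_SIZE instead of printing.
template <std::size_t HEAD_SIZE>
void run_attention() {
    std::printf("kernel specialized for head_size=%zu\n", HEAD_SIZE);
}

// Unsupported sizes fall through, mirroring INFINI_STATUS_BAD_TENSOR_SHAPE.
bool dispatch_head_size(std::size_t head_size) {
    switch (head_size) {
    case 16:  run_attention<16>();  return true;
    case 32:  run_attention<32>();  return true;
    case 64:  run_attention<64>();  return true;
    case 128: run_attention<128>(); return true;
    case 256: run_attention<256>(); return true;
    default:  return false;
    }
}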
42 changes: 23 additions & 19 deletions src/infiniop/ops/paged_attention/operator.cc
@@ -5,9 +5,9 @@
 #ifdef ENABLE_NVIDIA_API
 #include "nvidia/paged_attention_nvidia.cuh"
 #endif
-#ifdef ENABLE_METAX_API
-#include "metax/paged_attention_metax.h"
-#endif
+// #ifdef ENABLE_METAX_API
+// #include "metax/paged_attention_metax.h"
+// #endif
 
 __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
     infiniopHandle_t handle,
@@ -34,11 +34,12 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
 #ifdef ENABLE_NVIDIA_API
         CREATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-#ifdef ENABLE_METAX_API
-        CREATE(INFINI_DEVICE_METAX, metax)
-#endif
+    // #ifdef ENABLE_METAX_API
+    //     CREATE(INFINI_DEVICE_METAX, metax)
+    // #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 
 __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
@@ -54,11 +55,12 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
 #ifdef ENABLE_NVIDIA_API
         GET(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-#ifdef ENABLE_METAX_API
-        GET(INFINI_DEVICE_METAX, metax)
-#endif
+    // #ifdef ENABLE_METAX_API
+    //     GET(INFINI_DEVICE_METAX, metax)
+    // #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 
 __C infiniStatus_t infiniopPagedAttention(
@@ -78,11 +80,12 @@ __C infiniStatus_t infiniopPagedAttention(
 #ifdef ENABLE_NVIDIA_API
         CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-#ifdef ENABLE_METAX_API
-        CALCULATE(INFINI_DEVICE_METAX, metax)
-#endif
+    // #ifdef ENABLE_METAX_API
+    //     CALCULATE(INFINI_DEVICE_METAX, metax)
+    // #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
 
 __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
@@ -97,9 +100,10 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
 #ifdef ENABLE_NVIDIA_API
         DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
 #endif
-#ifdef ENABLE_METAX_API
-        DESTROY(INFINI_DEVICE_METAX, metax)
-#endif
+    // #ifdef ENABLE_METAX_API
+    //     DESTROY(INFINI_DEVICE_METAX, metax)
+    // #endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
-    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
 }
4 changes: 2 additions & 2 deletions src/infiniop/ops/paged_caching/info.h
@@ -28,10 +28,10 @@ class PagedCachingInfo {
     ptrdiff_t v_cache_block_stride;
 
     static utils::Result<PagedCachingInfo> create(
-        infiniopTensorDescriptor_t k_desc,
-        infiniopTensorDescriptor_t v_desc,
         infiniopTensorDescriptor_t k_cache_desc,
         infiniopTensorDescriptor_t v_cache_desc,
+        infiniopTensorDescriptor_t k_desc,
+        infiniopTensorDescriptor_t v_desc,
         infiniopTensorDescriptor_t slot_mapping_desc) {
 
         auto dtype = k_desc->dtype();
8 changes: 4 additions & 4 deletions src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu
@@ -31,13 +31,13 @@ Descriptor::~Descriptor() {
 infiniStatus_t Descriptor::create(
     infiniopHandle_t handle,
     Descriptor **desc_ptr,
-    infiniopTensorDescriptor_t k_desc,
-    infiniopTensorDescriptor_t v_desc,
     infiniopTensorDescriptor_t k_cache_desc,
     infiniopTensorDescriptor_t v_cache_desc,
+    infiniopTensorDescriptor_t k_desc,
+    infiniopTensorDescriptor_t v_desc,
     infiniopTensorDescriptor_t slot_mapping_desc) {
 
-    auto info = PagedCachingInfo::create(k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc);
+    auto info = PagedCachingInfo::create(k_cache_desc, v_cache_desc, k_desc, v_desc, slot_mapping_desc);
     CHECK_RESULT(info);
 
     // Create and return the Descriptor instance.
@@ -121,8 +121,8 @@ infiniStatus_t launchKernel(const PagedCachingInfo &info,
 // Execution method implementation
 infiniStatus_t Descriptor::calculate(
     void *workspace, size_t workspace_size,
-    const void *k, const void *v,
     void *k_cache, void *v_cache,
+    const void *k, const void *v,
     const void *slot_mapping,
     void *stream_) const {
 