From 59ce336ee0f7584688be4f43248e4dbfcaaa9676 Mon Sep 17 00:00:00 2001 From: suss <1152623206@qq.com> Date: Tue, 26 Aug 2025 09:24:44 +0800 Subject: [PATCH 1/4] paged attn v1 --- include/infiniop.h | 2 + include/infiniop/ops/paged_attention.h | 88 +++++ include/infiniop/ops/paged_caching.h | 77 +++++ scripts/python_test.py | 2 + src/infiniop-test/include/ops.hpp | 2 + src/infiniop-test/src/ops/paged_attention.cpp | 163 +++++++++ .../nvidia/causal_softmax_nvidia.cu | 1 - .../cpu/paged_attention_cpu.cc | 150 +++++++++ .../paged_attention/cpu/paged_attention_cpu.h | 10 + .../ops/paged_attention/cuda/kernel.cuh | 222 +++++++++++++ src/infiniop/ops/paged_attention/info.h | 100 ++++++ .../nvidia/paged_attention_nvidia.cu | 162 +++++++++ .../nvidia/paged_attention_nvidia.cuh | 8 + src/infiniop/ops/paged_attention/operator.cc | 138 ++++++++ .../ops/paged_attention/paged_attention.h | 53 +++ .../ops/paged_caching/cuda/kernel.cuh | 147 +++++++++ src/infiniop/ops/paged_caching/info.h | 77 +++++ .../nvidia/paged_caching_nvidia.cu | 143 ++++++++ .../nvidia/paged_caching_nvidia.cuh | 8 + src/infiniop/ops/paged_caching/operator.cc | 93 ++++++ .../ops/paged_caching/paged_caching.h | 52 +++ .../testcases/paged_attention.py | 110 +++++++ test/infiniop/libinfiniop/op_register.py | 84 +++++ test/infiniop/paged_attention.py | 308 ++++++++++++++++++ test/infiniop/paged_caching.py | 243 ++++++++++++++ xmake/nvidia.lua | 6 + 26 files changed, 2448 insertions(+), 1 deletion(-) create mode 100644 include/infiniop/ops/paged_attention.h create mode 100644 include/infiniop/ops/paged_caching.h create mode 100644 src/infiniop-test/src/ops/paged_attention.cpp create mode 100644 src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.cc create mode 100644 src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.h create mode 100644 src/infiniop/ops/paged_attention/cuda/kernel.cuh create mode 100644 src/infiniop/ops/paged_attention/info.h create mode 100644 src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu create mode 100644 src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cuh create mode 100644 src/infiniop/ops/paged_attention/operator.cc create mode 100644 src/infiniop/ops/paged_attention/paged_attention.h create mode 100644 src/infiniop/ops/paged_caching/cuda/kernel.cuh create mode 100644 src/infiniop/ops/paged_caching/info.h create mode 100644 src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu create mode 100644 src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh create mode 100644 src/infiniop/ops/paged_caching/operator.cc create mode 100644 src/infiniop/ops/paged_caching/paged_caching.h create mode 100644 test/infiniop-test/test_generate/testcases/paged_attention.py create mode 100644 test/infiniop/paged_attention.py create mode 100644 test/infiniop/paged_caching.py diff --git a/include/infiniop.h b/include/infiniop.h index cc1c19be6..b8bf32a50 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -10,6 +10,8 @@ #include "infiniop/ops/dequantize.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/mul.h" +#include "infiniop/ops/paged_attention.h" +#include "infiniop/ops/paged_caching.h" #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/relu.h" diff --git a/include/infiniop/ops/paged_attention.h b/include/infiniop/ops/paged_attention.h new file mode 100644 index 000000000..2cabf3a1b --- /dev/null +++ b/include/infiniop/ops/paged_attention.h @@ -0,0 +1,88 @@ +#ifndef __INFINIOP_PAGED_ATTENTION_API_H__ 
+#define __INFINIOP_PAGED_ATTENTION_API_H__ + +#include "../operator_descriptor.h" + +// Define an opaque handle for the Paged Attention descriptor. +typedef struct InfiniopDescriptor *infiniopPagedAttentionDescriptor_t; + +/** + * @brief Creates a descriptor for the Paged Attention v1 operation. + * + * This function initializes a descriptor that holds all the metadata needed + * for the paged attention computation. + * + * @param handle The handle to the InfiniOP library context. + * @param desc_ptr A pointer to store the created descriptor. + * @param out_desc Descriptor for the output tensor. + * @param q_desc Descriptor for the query tensor. + * @param k_cache_desc Descriptor for the key cache tensor. + * @param v_cache_desc Descriptor for the value cache tensor. + * @param block_tables_desc Descriptor for the block tables tensor. + * @param seq_lens_desc Descriptor for the sequence lengths tensor. + * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL. + * @param scale The attention scaling factor. + * @param max_num_blocks_per_seq The maximum number of batched blocks tables. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopCreatePagedAttentionDescriptor( + infiniopHandle_t handle, + infiniopPagedAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t alibi_slopes_desc, + float scale); + +/** + * @brief Retrieves the workspace size required for the Paged Attention operation. + * + * @param desc The Paged Attention descriptor. + * @param size A pointer to store the required workspace size in bytes. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( + infiniopPagedAttentionDescriptor_t desc, size_t *size); + +/** + * @brief Executes the Paged Attention v1 operation. + * + * @param desc The Paged Attention descriptor. + * @param workspace Pointer to the workspace memory. + * @param workspace_size The size of the workspace. + * @param out Pointer to the output tensor data. + * @param q Pointer to the query tensor data. + * @param k_cache Pointer to the key cache data. + * @param v_cache Pointer to the value cache data. + * @param block_tables Pointer to the block tables data. + * @param seq_lens Pointer to the sequence lengths data. + * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL. + * @param stream The CUDA stream for the operation. Can be NULL. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopPagedAttention( + infiniopPagedAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *alibi_slopes, + void *stream); + +/** + * @brief Destroys a Paged Attention descriptor. + * + * @param desc The descriptor to be destroyed. + * @return infiniStatus_t Status code of the operation. 
+ */ +__C __export infiniStatus_t infiniopDestroyPagedAttentionDescriptor( + infiniopPagedAttentionDescriptor_t desc); + +#endif // __INFINIOP_PAGED_ATTENTION_API_H__ \ No newline at end of file diff --git a/include/infiniop/ops/paged_caching.h b/include/infiniop/ops/paged_caching.h new file mode 100644 index 000000000..807f4e924 --- /dev/null +++ b/include/infiniop/ops/paged_caching.h @@ -0,0 +1,77 @@ +#ifndef __INFINIOP_PAGED_CACHING_API_H__ +#define __INFINIOP_PAGED_CACHING_API_H__ + +#include "../operator_descriptor.h" + +// Define an opaque handle for the Paged Caching descriptor. +typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t; + +/** + * @brief Creates a descriptor for the Paged Caching operation. + * + * This function initializes a descriptor that holds all the metadata needed + * to copy key/value vectors into their respective cache pools. + * + * @param handle The handle to the InfiniOP library context. + * @param desc_ptr A pointer to store the created descriptor. + * @param k_desc Descriptor for the source key tensor. + * @param v_desc Descriptor for the source value tensor. + * @param k_cache_desc Descriptor for the key cache pool tensor. + * @param v_cache_desc Descriptor for the value cache pool tensor. + * @param slot_mapping_desc Descriptor for the slot mapping tensor. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor( + infiniopHandle_t handle, + infiniopPagedCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t slot_mapping_desc); + +/** + * @brief Retrieves the workspace size required for the Paged Caching operation. + * + * @param desc The Paged Caching descriptor. + * @param size A pointer to store the required workspace size in bytes (typically 0). + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize( + infiniopPagedCachingDescriptor_t desc, size_t *size); + +/** + * @brief Executes the Paged Caching operation. + * + * @param desc The Paged Caching descriptor. + * @param workspace Pointer to the workspace memory. + * @param workspace_size The size of the workspace. + * @param k Pointer to the source key tensor data. + * @param v Pointer to the source value tensor data. + * @param k_cache Pointer to the key cache pool data. + * @param v_cache Pointer to the value cache pool data. + * @param slot_mapping Pointer to the slot mapping data. + * @param stream The CUDA stream for the operation. Can be NULL. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopPagedCaching( + infiniopPagedCachingDescriptor_t desc, + void *workspace, + size_t workspace_size, + const void *k, + const void *v, + void *k_cache, + void *v_cache, + const void *slot_mapping, + void *stream); + +/** + * @brief Destroys a Paged Caching descriptor. + * + * @param desc The descriptor to be destroyed. + * @return infiniStatus_t Status code of the operation. 
+ */ +__C __export infiniStatus_t infiniopDestroyPagedCachingDescriptor( + infiniopPagedCachingDescriptor_t desc); + +#endif // __INFINIOP_PAGED_CACHING_API_H__ \ No newline at end of file diff --git a/scripts/python_test.py b/scripts/python_test.py index eb2d4319e..0edf56b00 100644 --- a/scripts/python_test.py +++ b/scripts/python_test.py @@ -18,6 +18,8 @@ def run_tests(args): "clip.py", "gemm.py", "mul.py", + "paged_attention.py", + "paged_caching.py", "random_sample.py", "rearrange.py", "rms_norm.py", diff --git a/src/infiniop-test/include/ops.hpp b/src/infiniop-test/include/ops.hpp index 3820f7cfd..e350e8605 100644 --- a/src/infiniop-test/include/ops.hpp +++ b/src/infiniop-test/include/ops.hpp @@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add) DECLARE_INFINIOP_TEST(causal_softmax) DECLARE_INFINIOP_TEST(rearrange) DECLARE_INFINIOP_TEST(sub) +DECLARE_INFINIOP_TEST(paged_attention) #define REGISTER_INFINIOP_TEST(name) \ { \ @@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub) REGISTER_INFINIOP_TEST(causal_softmax) \ REGISTER_INFINIOP_TEST(rearrange) \ REGISTER_INFINIOP_TEST(sub) \ + REGISTER_INFINIOP_TEST(paged_attention) \ } namespace infiniop_test { diff --git a/src/infiniop-test/src/ops/paged_attention.cpp b/src/infiniop-test/src/ops/paged_attention.cpp new file mode 100644 index 000000000..245172138 --- /dev/null +++ b/src/infiniop-test/src/ops/paged_attention.cpp @@ -0,0 +1,163 @@ +#include "ops.hpp" +#include "utils.hpp" +#include +#include +#include + +namespace infiniop_test::paged_attention { + +// The Test class for the paged_attention operator. +struct Test::Attributes { + // Paged attention uses tensors for most parameters, but scale is a scalar. + std::shared_ptr scale; + + // Tensors for the operation. + std::shared_ptr q; + std::shared_ptr k_cache; + std::shared_ptr v_cache; + std::shared_ptr block_tables; + std::shared_ptr seq_lens; + std::shared_ptr alibi_slopes; // Can be null + std::shared_ptr ans; + std::shared_ptr out; + + // MODIFIED: op_desc and workspace are removed from here. + // They will be managed as local variables within the run() function, + // which is a cleaner, safer pattern demonstrated by the causal_softmax example. +}; + +// Factory method to build a Test object from GGUF data. +std::shared_ptr Test::build( + std::unordered_map> attributes, + std::unordered_map> tensors, + double rtol, double atol) { + auto test = std::shared_ptr(new Test(rtol, atol)); + test->_attributes = new Attributes(); + if (!check_names(tensors, Test::tensor_names())) { + throw std::runtime_error("Invalid Test: Missing tensors."); + } + + test->_attributes->scale = tensors["scale"]; + test->_attributes->q = tensors["q"]; + test->_attributes->k_cache = tensors["k_cache"]; + test->_attributes->v_cache = tensors["v_cache"]; + test->_attributes->block_tables = tensors["block_tables"]; + test->_attributes->seq_lens = tensors["seq_lens"]; + if (tensors.count("alibi_slopes")) { + test->_attributes->alibi_slopes = tensors["alibi_slopes"]; + } else { + test->_attributes->alibi_slopes = nullptr; + } + test->_attributes->ans = tensors["ans"]; + test->_attributes->out = tensors["out"]; + + return test; +} + +// Executes the test case. +std::shared_ptr Test::run( + infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) { + + // MODIFIED: op_desc and workspace are now local variables. 
+ infiniopPagedAttentionDescriptor_t op_desc = nullptr; + void *workspace = nullptr; + + // Move tensors to the target device + auto scale_tensor = _attributes->scale->to(device, device_id); + auto q = _attributes->q->to(device, device_id); + auto k_cache = _attributes->k_cache->to(device, device_id); + auto v_cache = _attributes->v_cache->to(device, device_id); + auto block_tables = _attributes->block_tables->to(device, device_id); + auto seq_lens = _attributes->seq_lens->to(device, device_id); + auto out = _attributes->out->to(device, device_id); + std::shared_ptr alibi_slopes = nullptr; + if (_attributes->alibi_slopes) { + alibi_slopes = _attributes->alibi_slopes->to(device, device_id); + } + + float scale_val = *reinterpret_cast(scale_tensor->data()); + + // Create operator descriptor + CHECK_OR(infiniopCreatePagedAttentionDescriptor( + handle, &op_desc, out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), + block_tables->desc(), seq_lens->desc(), + alibi_slopes ? alibi_slopes->desc() : nullptr, scale_val), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor.")); + + // Get workspace size and allocate memory + size_t workspace_size; + CHECK_OR(infiniopGetPagedAttentionWorkspaceSize(op_desc, &workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size.")); + if (workspace_size > 0) { + CHECK_OR(infinirtMalloc(&workspace, workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace.")); + } + + // Execute the operator for the first time + CHECK_OR(infiniopPagedAttention(op_desc, workspace, workspace_size, + out->data(), q->data(), k_cache->data(), v_cache->data(), + block_tables->data(), seq_lens->data(), + alibi_slopes ? alibi_slopes->data() : nullptr, nullptr), + return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution.")); + + // Verify the result + try { + allClose(out, _attributes->ans, _rtol, _atol); + } catch (const std::exception &e) { + return TEST_FAILED(RESULT_INCORRECT, e.what()); + } + + + // Benchmark the operation + double elapsed_time = 0.; + elapsed_time = benchmark( + [=]() { // Use reference capture to ensure local variables are available + infiniopPagedAttention(op_desc, workspace, workspace_size, + out->data(), q->data(), k_cache->data(), v_cache->data(), + block_tables->data(), seq_lens->data(), + alibi_slopes ? alibi_slopes->data() : nullptr, nullptr); + }, + warm_ups, iterations); + // return TEST_PASSED(elapsed_time); + + // Cleanup and return success + if (op_desc) { infiniopDestroyPagedAttentionDescriptor(op_desc); } + if (workspace) { infinirtFree(workspace); } + return TEST_PASSED(elapsed_time); +} + +// Define expected attribute and tensor names for validation. +std::vector Test::attribute_names() { return {}; } +std::vector Test::tensor_names() { + return {"scale", "q", "k_cache", "v_cache", "block_tables", "seq_lens", "ans", "out"}; +} +std::vector Test::output_names() { return {"out"}; } + +// MODIFIED: Added a toString() method for better debugging and logging, mimicking the reference file. 
+std::string Test::toString() const { + std::ostringstream oss; + oss << op_name() << std::endl; + oss << "- q: " << _attributes->q->info() << std::endl; + oss << "- k_cache: " << _attributes->k_cache->info() << std::endl; + oss << "- v_cache: " << _attributes->v_cache->info() << std::endl; + oss << "- block_tables: " << _attributes->block_tables->info() << std::endl; + oss << "- seq_lens: " << _attributes->seq_lens->info() << std::endl; + if (_attributes->alibi_slopes) { + oss << "- alibi_slopes: " << _attributes->alibi_slopes->info() << std::endl; + } + oss << "- out: " << _attributes->out->info() << std::endl; + oss << "- ans: " << _attributes->ans->info() << std::endl; + oss << std::scientific << std::setprecision(2); + oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; + return oss.str(); +} + +// Destructor to clean up resources. +// MODIFIED: The destructor is now simpler as it only needs to free the attributes struct. +Test::~Test() { + if (_attributes) { + delete _attributes; + } +} + +} // namespace infiniop_test::paged_attention \ No newline at end of file diff --git a/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu b/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu index 6dae5af61..57cf64da5 100644 --- a/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu +++ b/src/infiniop/ops/causal_softmax/nvidia/causal_softmax_nvidia.cu @@ -3,7 +3,6 @@ #include "../../../devices/nvidia/nvidia_kernel_common.cuh" #include - #include "../../../reduce/cuda/reduce.cuh" #include "../cuda/kernel.cuh" diff --git a/src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.cc b/src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.cc new file mode 100644 index 000000000..0e8dada05 --- /dev/null +++ b/src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.cc @@ -0,0 +1,150 @@ +#include "paged_attention_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include +#include + +// TODO finish cpu version + +namespace op::paged_attention::cpu { + +Descriptor::~Descriptor() {} + +// Factory function to create a CPU descriptor for Paged Attention. +// NOTE: This part is already well-structured and consistent with the CUDA version, so no changes are needed. +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + const std::optional& alibi_slopes_desc, + float scale + ) { + + // auto result = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, + // block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); + // CHECK_RESULT(result); + // *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// ================================================================================= +// MODIFIED: The core CPU logic is completely refactored below. +// ================================================================================= +// template +// infiniStatus_t paged_attention(const PagedAttentionInfo *info, +// T *out, const T *q, +// const T *k_cache, const T *v_cache, +// const int *block_tables, const int *seq_lens, +// const float *alibi_slopes) { +// Parallelize the operation over sequences and heads using OpenMP. 
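+// A worked example of the flattened index used below (illustrative numbers only): with
+// num_heads == 32, i == 70 maps to seq_idx == 70 / 32 == 2 and head_idx == 70 % 32 == 6,
+// so every (sequence, head) pair is an independent unit of work for OpenMP.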
+// #pragma omp parallel for +// for (ptrdiff_t i = 0; i < ptrdiff_t(info->num_seqs * info->num_heads); ++i) { +// const size_t seq_idx = i / info->num_heads; +// const size_t head_idx = i % info->num_heads; +// const size_t seq_len = seq_lens[seq_idx]; + +// if (seq_len == 0) continue; + +// // MODIFIED: Pointer arithmetic now strictly uses strides from the info struct. +// // We cast to char* to perform byte-level stride calculations, which is the safest way. +// const char* q_base_ptr = (const char*)q + seq_idx * info->q_stride; +// const T* q_vec = (const T*)(q_base_ptr) + head_idx * info->head_size; + +// char* out_base_ptr = (char*)out + seq_idx * info->q_stride; // Output has same layout as query +// T* out_vec = (T*)(out_base_ptr) + head_idx * info->head_size; + +// const size_t kv_head_idx = head_idx / (info->num_heads / info->num_kv_heads); + +// std::vector logits(seq_len); +// float max_logit = -INFINITY; + +// // 1. Compute QK dot products and find max logit +// for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) { +// const size_t block_table_idx = seq_idx * info->max_num_blocks_per_seq + token_idx / info->block_size; +// const size_t block_num = block_tables[block_table_idx]; +// const size_t block_off = token_idx % info->block_size; + +// // MODIFIED: K-Cache access logic now matches the CUDA kernel's high-performance layout. +// // Layout assumption: [num_blocks, num_kv_heads, BLOCK_SIZE, HEAD_SIZE] +// const char* k_block_ptr = (const char*)k_cache + block_num * info->kv_block_stride; +// const char* k_head_ptr = k_block_ptr + kv_head_idx * info->kv_head_stride; +// const T* k_vec_ptr = (const T*)k_head_ptr + block_off * info->head_size; + +// float qk = 0.0f; +// for (size_t h = 0; h < info->head_size; ++h) { +// qk += utils::cast(q_vec[h]) * utils::cast(k_vec_ptr[h]); +// } + +// logits[token_idx] = qk * info->scale; +// if (info->has_alibi) { +// logits[token_idx] += alibi_slopes[head_idx] * (token_idx - seq_len + 1); +// } +// if (logits[token_idx] > max_logit) { +// max_logit = logits[token_idx]; +// } +// } + +// // 2. Compute Softmax +// float exp_sum = 0.0f; +// for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) { +// float val = std::exp(logits[token_idx] - max_logit); +// logits[token_idx] = val; +// exp_sum += val; +// } + +// const float inv_sum = 1.0f / (exp_sum + 1e-6f); +// for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) { +// logits[token_idx] *= inv_sum; +// } + +// // 3. Aggregate V values +// std::vector acc(info->head_size, 0.0f); +// for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) { +// const size_t block_table_idx = seq_idx * info->max_num_blocks_per_seq + token_idx / info->block_size; +// const size_t block_num = block_tables[block_table_idx]; +// const size_t block_off = token_idx % info->block_size; +// const float prob = logits[token_idx]; + +// // MODIFIED: V-Cache access logic also matches the CUDA kernel's layout. +// // We assume K and V have the same layout and strides. 
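+// Illustrative pointer walk under that assumption, for the layout
+// [num_blocks, num_kv_heads, BLOCK_SIZE, HEAD_SIZE]:
+//   block_num selects the physical block        (stride kv_block_stride),
+//   kv_head_idx selects the head inside it      (stride kv_head_stride),
+//   block_off * head_size selects the token's vector inside that head,
+// and the next head_size contiguous values are that token's V (or K) vector.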
+// const char* v_block_ptr = (const char*)v_cache + block_num * info->kv_block_stride; +// const char* v_head_ptr = v_block_ptr + kv_head_idx * info->kv_head_stride; +// const T* v_vec_ptr = (const T*)v_head_ptr + block_off * info->head_size; + +// for (size_t h = 0; h < info->head_size; ++h) { +// acc[h] += prob * utils::cast(v_vec_ptr[h]); +// } +// } + +// for(size_t h = 0; h < info->head_size; ++h) { +// out_vec[h] = utils::cast(acc[h]); +// } +// } +// return INFINI_STATUS_SUCCESS; +// } + +// Dispatches the call to the correct templated implementation based on dtype. +// NOTE: This part is also consistent, no changes needed. +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + const void *block_tables, const void *seq_lens, const void *alibi_slopes, + void *stream) const { + + // // NOTE: CPU version typically uses F32 for computation. If F16/BF16 support + // // is needed, conversions or specialized libraries would be required. + // if (_info.dtype == INFINI_DTYPE_F32) { + // CHECK_STATUS(paged_attention(&_info, (float *)out, (const float *)q, (const float *)k_cache, + // (const float *)v_cache, (const int *)block_tables, + // (const int *)seq_lens, (const float *)alibi_slopes)); + // } else { + // return INFINI_STATUS_BAD_TENSOR_DTYPE; + // } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::paged_attention::cpu diff --git a/src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.h b/src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.h new file mode 100644 index 000000000..e8744ecde --- /dev/null +++ b/src/infiniop/ops/paged_attention/cpu/paged_attention_cpu.h @@ -0,0 +1,10 @@ +#ifndef __PAGED_ATTENTION_CPU_H__ +#define __PAGED_ATTENTION_CPU_H__ + +#include "../paged_attention.h" + +// Use the DESCRIPTOR macro to generate the class declaration +// for the 'cpu' namespace. +DESCRIPTOR(cpu) + +#endif // __PAGED_ATTENTION_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_attention/cuda/kernel.cuh b/src/infiniop/ops/paged_attention/cuda/kernel.cuh new file mode 100644 index 000000000..1f50e68ea --- /dev/null +++ b/src/infiniop/ops/paged_attention/cuda/kernel.cuh @@ -0,0 +1,222 @@ +#ifndef __PAGED_ATTENTION_KERNEL_CUH__ +#define __PAGED_ATTENTION_KERNEL_CUH__ + +#include +#include +#include +#include + +// This kernel is refactored to be high-performance, adopting parallelism strategies +// from industry-standard implementations like vLLM. It fixes functional and performance +// issues in the original draft. + +namespace op::paged_attention::cuda { + +//================================================================================ +// MODIFICATION: Step 1 - Create our own parallel dot product helper function. +// This function is directly inspired by the `sum` function in `reduce.cuh`. +//================================================================================ +// template +// __device__ __forceinline__ Tcompute dot_product( +// const Tcompute* q_shared, +// const Tdata* k_vec_ptr +// ) { +// // Phase 1: Each thread computes a partial sum of (q * k) +// Tcompute partial_sum = 0.0f; +// for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { +// partial_sum += q_shared[i] * static_cast(k_vec_ptr[i]); +// } + +// // Phase 2: Reduce the partial sums from all threads in the block using CUB. 
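+// A minimal sketch of the block-wide sum this helper intends, assuming Tcompute and
+// NUM_THREADS template parameters as in the commented code below (CUB's standard pattern):
+//
+//   using BlockReduce = cub::BlockReduce<Tcompute, NUM_THREADS>;
+//   __shared__ typename BlockReduce::TempStorage temp_storage;
+//   Tcompute block_sum = BlockReduce(temp_storage).Sum(partial_sum);
+//
+// Note that only thread 0 receives the valid aggregate; broadcast it via shared memory
+// if every thread needs the result.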
+// using BlockReduce = cub::BlockReduce; +// __shared__ typename BlockReduce::TempStorage temp_storage; + +// return BlockReduce(temp_storage).Sum(partial_sum); +// } + + +template +__device__ void pagedAttentionKernel( + Tdata* out_, + const Tdata* q_, + const Tdata* k_cache_, + const Tdata* v_cache_, + const int32_t* block_tables_, + const int32_t* seq_lens_, + const float* alibi_slopes_, + const size_t num_kv_heads, + const float scale, + const size_t max_num_blocks_per_seq, + const size_t block_size, + const ptrdiff_t q_stride, + const ptrdiff_t kv_block_stride, + const ptrdiff_t kv_head_stride +) { + //================================================================================ + // 1. Setup & Query Loading (No changes in this section) + //================================================================================ + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("pagedAttentionKernel;\n"); + } + __syncthreads(); + const int seq_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int batch_size = gridDim.y; + const int32_t seq_len = seq_lens_[seq_idx]; + if (seq_len == 0) return; + + const size_t num_queries_per_kv = num_heads / num_kv_heads; + const size_t kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = (alibi_slopes_ == nullptr) ? 0.0f : alibi_slopes_[head_idx]; + + const int32_t* block_table = block_tables_ + seq_idx * max_num_blocks_per_seq; + + const Tdata* q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + Tdata* out_ptr = out_ + seq_idx * q_stride + head_idx * HEAD_SIZE; + + extern __shared__ char shared_mem_char[]; + Tcompute* shared_mem = reinterpret_cast(shared_mem_char); + Tcompute* q_shared = shared_mem; + Tcompute* logits = shared_mem + HEAD_SIZE; + + // printf("static_cast(q_ptr[i]);"); + for (size_t i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + q_shared[i] = static_cast(q_ptr[i]); + } + __syncthreads(); + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("q_shared over;\n"); + } + __syncthreads(); + + const int32_t physical_block_num1 = block_table[62]; + + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("k_shared start;%d\n", physical_block_num1); + } + __syncthreads(); + + //================================================================================ + // 2. Compute QK Dot Product & Find Max Logit + //================================================================================ + for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) { + const int32_t block_idx = token_idx / block_size; + const int32_t token_in_block_idx = token_idx % block_size; + if (token_idx == seq_len-1 || token_idx == 0) { + printf("block_idx: %d\n", block_idx); + } + // const int32_t physical_block_num = 0 + const int32_t physical_block_num = block_table[block_idx]; + if (token_idx == seq_len-1 || token_idx == 0) { + printf("physical_block_num: %d\n", physical_block_num); + } + + const Tdata* k_vec_ptr = k_cache_ + physical_block_num * kv_block_stride + kv_head_idx * kv_head_stride + token_in_block_idx * HEAD_SIZE; + + //================================================================================ + // MODIFICATION: Step 2 - Integrate the new parallel dot product function. + // The slow, serial `for` loop is replaced with a single call. 
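+        // Note on the variant used below instead: threads are parallel over token_idx (the
+        // enclosing loop strides by NUM_THREADS), so each thread owns a whole (q, k) pair and
+        // a plain per-thread dot product over HEAD_SIZE needs no cross-thread reduction; the
+        // 8-way manual unrolling only amortizes loop overhead. The block-wide dot_product
+        // sketched above would require processing one token per block iteration instead.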
+ //================================================================================ + // Tcompute qk = dot_product(q_shared, k_vec_ptr); + // __syncthreads(); // Sync is needed here because dot_product uses shared memory. + // printf("Integrate the new parallel dot product;"); + Tcompute qk = 0.0f; +#pragma unroll + for (size_t i = 0; i < HEAD_SIZE / 8; ++i) { + const size_t offset = i * 8; + + // 手动展开8次计算 + qk += q_shared[offset + 0] * static_cast(k_vec_ptr[offset + 0]); + qk += q_shared[offset + 1] * static_cast(k_vec_ptr[offset + 1]); + qk += q_shared[offset + 2] * static_cast(k_vec_ptr[offset + 2]); + qk += q_shared[offset + 3] * static_cast(k_vec_ptr[offset + 3]); + qk += q_shared[offset + 4] * static_cast(k_vec_ptr[offset + 4]); + qk += q_shared[offset + 5] * static_cast(k_vec_ptr[offset + 5]); + qk += q_shared[offset + 6] * static_cast(k_vec_ptr[offset + 6]); + qk += q_shared[offset + 7] * static_cast(k_vec_ptr[offset + 7]); + } + + qk *= scale; + if (alibi_slope != 0.0f) { + qk += alibi_slope * (token_idx - seq_len + 1); + } + + logits[token_idx] = qk; + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("compute global_qk_max start;\n"); + } + __syncthreads(); + + __shared__ Tcompute global_qk_max; + Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max(logits, seq_len); + + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("Cmmon_cuda::reduce_op over;\n"); + } + __syncthreads(); + if (threadIdx.x == 0) { + global_qk_max = global_qk_max_0; + } + __syncthreads(); + + //================================================================================ + // 3. Compute Softmax (No changes in this section) + //================================================================================ + // printf("Compute Softmax;"); + // Tcompute exp_sum = 0.0f; + for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) { + Tcompute val = expf(logits[i] - global_qk_max); // 使用全局最大值 + logits[i] = val; + } + __syncthreads(); + + __shared__ Tcompute inv_sum; + Tcompute exp_sum_0 = op::common_cuda::reduce_op::sum(logits, seq_len); + if (threadIdx.x == 0) { + inv_sum = 1.0f / (exp_sum_0 + 1e-6f); + } + __syncthreads(); + + for (size_t i = threadIdx.x; i < seq_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("gregate Values (V) weighted by probabil start;\n"); + } + __syncthreads(); + + //================================================================================ + // 4. 
Aggregate Values (V) weighted by probabilities + //================================================================================ + // printf("Aggregate Values;"); + for (size_t h_dim = threadIdx.x; h_dim < HEAD_SIZE; h_dim += NUM_THREADS) { + Tcompute acc = 0.0f; + + for (size_t token_idx = 0; token_idx < seq_len; ++token_idx) { + const size_t block_idx = token_idx / block_size; + const size_t token_in_block_idx = token_idx % block_size; + const int32_t physical_block_num = block_table[block_idx]; + const Tcompute prob = logits[token_idx]; + + const Tdata* v_vec_ptr = v_cache_ + + physical_block_num * kv_block_stride + + kv_head_idx * kv_head_stride + + token_in_block_idx * HEAD_SIZE; + + const Tdata v_val = v_vec_ptr[h_dim]; + acc += prob * static_cast(v_val); + } + out_ptr[h_dim] = static_cast(acc); + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + printf("gregate Values (V) weighted by probabil over;\n"); + } + __syncthreads(); +} + +} // namespace op::paged_attention::cuda + +#endif // __PAGED_ATTENTION_KERNEL_CUH__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_attention/info.h b/src/infiniop/ops/paged_attention/info.h new file mode 100644 index 000000000..db63a649f --- /dev/null +++ b/src/infiniop/ops/paged_attention/info.h @@ -0,0 +1,100 @@ +#ifndef __PAGED_ATTENTION_INFO_H__ +#define __PAGED_ATTENTION_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include + +namespace op::paged_attention { + +class PagedAttentionInfo { + PagedAttentionInfo() = default; + +public: + // --- Data Types and Scale --- + infiniDtype_t dtype; + float scale; + + // --- Shape Dimensions --- + size_t num_seqs; + size_t num_heads; + size_t num_kv_heads; + size_t head_size; + size_t block_size; + size_t max_num_blocks_per_seq; + // size_t max_seq_len; + + // --- Strides for Memory Layout --- + ptrdiff_t q_stride; + ptrdiff_t kv_block_stride; + ptrdiff_t kv_head_stride; + + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + const std::optional& alibi_slopes_desc, + float scale) { + + auto dtype = q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || + block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // --- Extract shape dimensions --- + auto q_shape = q_desc->shape(); + auto k_cache_shape = k_cache_desc->shape(); + // k_cache_shape: [num_blocks, num_kv_heads, block_size, head_size] + + size_t num_seqs = q_shape[0]; + size_t num_heads = q_shape[1]; + size_t head_size = q_shape[2]; + + size_t num_kv_heads = k_cache_shape[1]; + size_t block_size = v_cache_desc->shape()[2]; // 使用V cache的block size维度更可靠 + size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // --- Calculate max_seq_len for shared memory allocation --- + // This is a safe upper bound. 
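+        // For context, the NVIDIA launcher sizes dynamic shared memory from this bound as
+        //   (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(Tdata),
+        // e.g. 128 + 64 * 16 + 2 = 1154 half-precision values, roughly 2.3 KB (illustrative
+        // numbers). Long contexts must keep this below the device's dynamic shared-memory
+        // limit, typically 48 KB per block unless a larger carve-out is requested.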
+ // info.max_seq_len = info.max_num_blocks_per_seq * info.block_size; + + // --- Extract strides for memory access --- + ptrdiff_t q_stride = q_desc->stride(0); + ptrdiff_t kv_block_stride = k_cache_desc->stride(0); + ptrdiff_t kv_head_stride = k_cache_desc->stride(1); + // Note: We assume k_cache and v_cache have compatible strides. + // A more robust implementation could check v_cache_desc->stride() as well. + + // --- Check for optional features --- + // info.has_alibi = alibi_slopes_desc.has_value(); + + return utils::Result(PagedAttentionInfo{ + dtype, + scale, + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_num_blocks_per_seq, + q_stride, + kv_block_stride, + kv_head_stride + }); + } +}; + +} // namespace op::paged_attention + +#endif // __PAGED_ATTENTION_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu new file mode 100644 index 000000000..9a9679f71 --- /dev/null +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu @@ -0,0 +1,162 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "paged_attention_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" +#include +#include + +template +INFINIOP_CUDA_KERNEL pagedAttention( + Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache, + const int32_t *block_tables, const int32_t *seq_lens, const float *alibi_slopes, + const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq, + const size_t block_size, + const ptrdiff_t q_stride, const ptrdiff_t kv_block_stride, const ptrdiff_t kv_head_stride + ) { + op::paged_attention::cuda::pagedAttentionKernel( + out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale, + max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride); +} + +namespace op::paged_attention::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + const std::optional& alibi_slopes_desc, + float scale + ) { + auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale); + CHECK_RESULT(info); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache, + infiniDtype_t dtype, + const void *block_tables, const void *seq_lens, const void *alibi_slopes, + size_t num_heads, size_t num_seqs, + size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, + ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, + cudaStream_t stream) { + printf("launchKernel; num_heads: %zu, num_seqs: %zu, num_kv_heads: %zu, scale: %f, max_num_blocks_per_seq: %zu, block_size: %zu, q_stride: %zu, kv_block_stride: %zu, 
kv_head_stride: %zu\n", num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride); + dim3 grid(uint32_t(num_heads), uint32_t(num_seqs), 1); + dim3 block(NUM_THREADS); + + // size_t shared_mem_size = 16; + if (dtype == INFINI_DTYPE_F16) { + size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(half); + // size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(Tcompute); + pagedAttention + <<>>( + (half*)out, + (const half*)q, (const half*)k_cache, (const half*)v_cache, + (const int32_t*)block_tables, (const int32_t*)seq_lens, (const float*)alibi_slopes, num_kv_heads, + scale, max_num_blocks_per_seq, block_size, + q_stride, kv_block_stride, kv_head_stride + ); + } else if (dtype == INFINI_DTYPE_BF16) { + size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(__nv_bfloat16); + pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> + <<>>( + (__nv_bfloat16*)out, (const __nv_bfloat16*)q, (const __nv_bfloat16*)k_cache, (const __nv_bfloat16*)v_cache, + (const int32_t*)block_tables, (const int32_t*)seq_lens, (const float*)alibi_slopes, num_kv_heads, + scale, max_num_blocks_per_seq, block_size, + q_stride, kv_block_stride, kv_head_stride + ); + } else if (dtype == INFINI_DTYPE_F32) { + size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); + pagedAttention + <<>>( + (float*)out, (const float*)q, (const float*)k_cache, (const float*)v_cache, + (const int32_t*)block_tables, (const int32_t*)seq_lens, (const float*)alibi_slopes, num_kv_heads, + scale, max_num_blocks_per_seq, block_size, + q_stride, kv_block_stride, kv_head_stride + ); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + const void *block_tables, const void *seq_lens, const void *alibi_slopes, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + if (_info.head_size == 128) { + launchKernel<128, CUDA_BLOCK_SIZE_1024>( + out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, + _info.num_heads, _info.num_seqs, + _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, + _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, + stream); + + } + // else if head_size=128, block_size=16)for llama + else { + printf("head_size: %zu\n", _info.head_size); + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + if (_info.head_size == 128 ) { + launchKernel<128, CUDA_BLOCK_SIZE_512>( + out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, + _info.num_heads, _info.num_seqs, + _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, + _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, + stream); + + } + // else if head_size=128, block_size=16)for llama + else { + printf("head_size: %zu\n", _info.head_size); + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + if (_info.head_size == 128 ) { + launchKernel<128, CUDA_BLOCK_SIZE_4096>( + out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, + 
_info.num_heads, _info.num_seqs, + _info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, + _info.q_stride, _info.kv_block_stride, _info.kv_head_stride, + stream); + } + // else if head_size=128, block_size=16)for llama + else { + printf("head_size: %zu", _info.head_size); + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + + return INFINI_STATUS_SUCCESS; +} + +} \ No newline at end of file diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cuh b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cuh new file mode 100644 index 000000000..886c716b2 --- /dev/null +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __PAGED_ATTENTION_NVIDIA_H__ +#define __PAGED_ATTENTION_NVIDIA_H__ + +#include "../paged_attention.h" + +DESCRIPTOR(nvidia) + +#endif // __PAGED_ATTENTION_NVIDIA_H__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_attention/operator.cc b/src/infiniop/ops/paged_attention/operator.cc new file mode 100644 index 000000000..7defbd490 --- /dev/null +++ b/src/infiniop/ops/paged_attention/operator.cc @@ -0,0 +1,138 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/paged_attention.h" + +// Add necessary includes for different platforms +#ifdef ENABLE_CPU_API +// #include "cpu/paged_attention_cpu.h" // Placeholder for future CPU implementation +#endif +#if defined(ENABLE_NVIDIA_API) +#include "nvidia/paged_attention_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +// #include "metax/paged_attention_metax.h" // Placeholder +#endif +#ifdef ENABLE_ASCEND_API +// #include "ascend/paged_attention_ascend.h" // Placeholder +#endif + +__C infiniStatus_t infiniopCreatePagedAttentionDescriptor( + infiniopHandle_t handle, + infiniopPagedAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t alibi_slopes_desc, + float scale + ) { + + std::optional alibi_opt = + (alibi_slopes_desc == nullptr) ? 
std::nullopt : std::optional(alibi_slopes_desc); + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::paged_attention::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_opt, scale); + + switch (handle->device) { +#ifdef ENABLE_CPU_API + // CREATE(INFINI_DEVICE_CPU, cpu) // Uncomment when CPU implementation is ready +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + // CREATE(INFINI_DEVICE_METAX, metax) // Placeholder for future Metax implementation +#endif +#ifdef ENABLE_ASCEND_API + // CREATE(INFINI_DEVICE_ASCEND, ascend) // Placeholder for future Ascend implementation +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize( + infiniopPagedAttentionDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // GET(INFINI_DEVICE_CPU, cpu) // Uncomment when CPU implementation is ready +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + // GET(INFINI_DEVICE_METAX, metax) // Placeholder +#endif +#ifdef ENABLE_ASCEND_API + // GET(INFINI_DEVICE_ASCEND, ascend) // Placeholder +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopPagedAttention( + infiniopPagedAttentionDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, const void *q, const void *k_cache, const void *v_cache, + const void *block_tables, const void *seq_lens, const void *alibi_slopes, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \ + seq_lens, alibi_slopes, stream); + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // CALCULATE(INFINI_DEVICE_CPU, cpu) // Uncomment when CPU implementation is ready +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + // CALCULATE(INFINI_DEVICE_METAX, metax) // Placeholder +#endif +#ifdef ENABLE_ASCEND_API + // CALCULATE(INFINI_DEVICE_ASCEND, ascend) // Placeholder +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopDestroyPagedAttentionDescriptor( + infiniopPagedAttentionDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + // DESTROY(INFINI_DEVICE_CPU, cpu) // Uncomment when CPU implementation is ready +#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_METAX_API + // DESTROY(INFINI_DEVICE_METAX, metax) // Placeholder +#endif +#ifdef ENABLE_ASCEND_API + // DESTROY(INFINI_DEVICE_ASCEND, ascend) // Placeholder +#endif + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} \ No newline at end of file diff --git a/src/infiniop/ops/paged_attention/paged_attention.h b/src/infiniop/ops/paged_attention/paged_attention.h new file mode 100644 index 000000000..53c1fe83f --- /dev/null +++ b/src/infiniop/ops/paged_attention/paged_attention.h @@ -0,0 +1,53 @@ +#ifndef PAGED_ATTENTION_H +#define PAGED_ATTENTION_H + +#include "../../operator.h" 
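+// DESCRIPTOR(NAMESPACE) below declares op::paged_attention::NAMESPACE::Descriptor; each
+// backend header instantiates it once (e.g. DESCRIPTOR(nvidia) in
+// nvidia/paged_attention_nvidia.cuh) and the corresponding .cu / .cc file defines that
+// backend's create() and calculate().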
+#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::paged_attention::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + PagedAttentionInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + PagedAttentionInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_cache_desc, \ + infiniopTensorDescriptor_t v_cache_desc, \ + infiniopTensorDescriptor_t block_tables_desc, \ + infiniopTensorDescriptor_t seq_lens_desc, \ + const std::optional& alibi_slopes_desc,\ + float scale); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, const void *q, const void *k_cache, const void *v_cache,\ + const void *block_tables, const void *seq_lens, \ + const void *alibi_slopes, \ + void *stream) const; \ + }; \ + } + +#endif // PAGED_ATTENTION_H \ No newline at end of file diff --git a/src/infiniop/ops/paged_caching/cuda/kernel.cuh b/src/infiniop/ops/paged_caching/cuda/kernel.cuh new file mode 100644 index 000000000..8418aaf03 --- /dev/null +++ b/src/infiniop/ops/paged_caching/cuda/kernel.cuh @@ -0,0 +1,147 @@ +#ifndef __PAGED_CACHING_KERNEL_CUH__ +#define __PAGED_CACHING_KERNEL_CUH__ + +#include + +//================================================================================ +// Paged Caching Operator CUDA Kernel +// +// This kernel implements the "paged_caching" operation, which copies Key and Value +// vectors from a contiguous source tensor into a paged, non-contiguous KV Cache. +// +// Design Principles: +// 1. Token-Centric Parallelism: A 1D grid of `num_tokens` is launched. Each CUDA +// block is responsible for caching one full token (all its heads). +// 2. Coalesced Memory Access: This grid strategy ensures that threads within a +// block read a large, contiguous chunk of memory from the source tensors, +// maximizing memory bandwidth utilization. +// 3. Vectorization: The copy operation is vectorized to further enhance memory +// throughput, processing multiple data elements in a single instruction. 
+//================================================================================ + +namespace op::paged_caching::cuda { + +template < + typename Tdata, // Data type of the tensors (e.g., half, __nv_bfloat16) + int NUM_THREADS // Number of threads per block, configured at launch time +> +__device__ void pagedCachingKernel( + // ----- Output Tensors ----- + Tdata* k_cache_ptr, // Pointer to the destination K cache pool + Tdata* v_cache_ptr, // Pointer to the destination V cache pool + // ----- Input Tensors ----- + const Tdata* k_ptr, // Pointer to the source Keys, shape [ntok, nkvh, dh] + const Tdata* v_ptr, // Pointer to the source Values, shape [ntok, nkvh, dh] + const int* slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok] + // ----- Metadata ----- + const int num_heads, // Number of key/value heads (nkvh) + const int head_size, // Dimension of each head (dh) + const int block_size, // Number of tokens per block in the KV cache + // ----- Stride Information ----- + const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor + const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor + const ptrdiff_t k_cache_block_stride, // Stride between blocks in the K cache pool + const ptrdiff_t v_cache_block_stride // Stride between blocks in the V cache pool +) { + //================================================================================ + // 1. Identify Work Unit & Calculate Addresses + //================================================================================ + + // Each block processes one token. + const int token_idx = blockIdx.x; + const int head_idx = blockIdx.y; + + // Retrieve the destination slot for the current token. + const int slot_idx = slot_mapping_ptr[token_idx]; + + // Handle padding: if slot_idx is negative, this token is padding and should be ignored. + if (slot_idx < 0) { + return; + } + + // if (blockIdx.x == 0 && threadIdx.x == 0) { + // printf("[Block %d, Thread %d] Debug Start\n", blockIdx.x, threadIdx.x); + // printf(" - token_idx: %d\n", token_idx); + // printf(" - slot_idx from mapping: %d\n", slot_idx); + // printf(" - Metadata: num_heads=%d, head_size=%d, block_size=%d\n", num_heads, head_size, block_size); + // } + + // Calculate the physical block index and the offset within that block. + const int physical_block_idx = slot_idx / block_size; + const int block_offset = slot_idx % block_size; + + // Calculate base pointers for source and destination for this specific token. + const Tdata* k_src_head_ptr = k_ptr + token_idx * k_src_stride + head_idx * head_size; + const Tdata* v_src_head_ptr = v_ptr + token_idx * v_src_stride + head_idx * head_size; + + + // Destination pointer calculation assumes a [num_blocks, block_size, num_heads, head_size] layout. + // We point to the beginning of the memory region for this token's slot. 
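+    // Note: the arithmetic below (head_idx * block_size * head_size + block_offset * head_size)
+    // actually corresponds to a [num_blocks, num_heads, block_size, head_size] layout, i.e. the
+    // same layout PagedCachingInfo and the paged_attention kernel assume; the comment above
+    // lists the middle dimensions in the opposite order.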
+ const ptrdiff_t cache_head_stride = block_size * head_size; + + Tdata* k_cache_block_base_ptr = k_cache_ptr + physical_block_idx * k_cache_block_stride; + Tdata* k_dst_head_ptr = k_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; + + Tdata* v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride; + Tdata* v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; + + + // if (blockIdx.x == 0 && threadIdx.x == 0) { + // printf("[Block %d, Thread %d] Address Calculation\n", blockIdx.x, threadIdx.x); + // printf(" - physical_block_idx: %d\n", physical_block_idx); + // printf(" - block_offset: %d\n", block_offset); + // printf(" - k_src_stride: %ld, v_src_stride: %ld\n", k_src_stride, v_src_stride); + // printf(" - k_cache_block_stride: %ld, v_cache_block_stride: %ld\n", k_cache_block_stride, v_cache_block_stride); + // printf(" - Calculated dst_token_offset_k: %d\n", dst_token_offset_k); + // printf(" - Source Ptr (k): %p\n", k_src_token_ptr); + // printf(" - Dest Ptr (k_cache): %p\n", k_cache_dst_ptr); + // } + + + // //================================================================================ + // // 2. Perform Vectorized Data Copy + // //================================================================================ + + // // Total number of elements to copy for one token (all heads). + // const int total_elements_per_token = num_heads * head_size; + + // // Use vectorization to copy data more efficiently. + // // For Tdata=half (2 bytes), float4 (16 bytes) can process 8 elements at once. + // constexpr int VEC_SIZE = sizeof(float4) / sizeof(Tdata); + + // // Cast pointers to the vectorized type. + // const float4* k_src_vec_ptr = reinterpret_cast(k_src_token_ptr); + // const float4* v_src_vec_ptr = reinterpret_cast(v_src_token_ptr); + // float4* k_cache_dst_vec_ptr = reinterpret_cast(k_cache_dst_ptr); + // float4* v_cache_dst_vec_ptr = reinterpret_cast(v_cache_dst_ptr); + + // // if (blockIdx.x == 0 && threadIdx.x == 0) { + // // printf("[Block %d, Thread %d] Vectorized Copy Start\n", blockIdx.x, threadIdx.x); + // // printf(" - Total elements per token: %d\n", total_elements_per_token); + // // printf(" - Vector size: %d\n", VEC_SIZE); + // // printf(" - float4 size: %d\n", sizeof(float4)); + // // printf(" - Tdata size: %d\n", sizeof(Tdata)); + // // // printf(" - Vector size: %d\n", k_src_vec_ptr[i]); + // // } + + // // Each thread copies one vector (VEC_SIZE elements) per iteration. + // // The loop iterates over the vectorized chunks of data for the token. + // for (int i = threadIdx.x; i < total_elements_per_token / VEC_SIZE; i += NUM_THREADS) { + + // k_cache_dst_vec_ptr[i] = k_src_vec_ptr[i]; + // v_cache_dst_vec_ptr[i] = v_src_vec_ptr[i]; + // } + // } + + //================================================================================ + // 2. 
Perform Element-wise Data Copy (Safe, Non-Vectorized) + //================================================================================ + for (int i = threadIdx.x; i < head_size; i += NUM_THREADS) { + k_dst_head_ptr[i] = k_src_head_ptr[i]; + v_dst_head_ptr[i] = v_src_head_ptr[i]; + } +} + +} // namespace op::paged_caching::cuda + +#endif // __PAGED_CACHING_KERNEL_CUH__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_caching/info.h b/src/infiniop/ops/paged_caching/info.h new file mode 100644 index 000000000..23a4da0ac --- /dev/null +++ b/src/infiniop/ops/paged_caching/info.h @@ -0,0 +1,77 @@ +// File: infiniop/ops/paged_caching/info.h + +#ifndef __PAGED_CACHING_INFO_H__ +#define __PAGED_CACHING_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include + +namespace op::paged_caching { + +class PagedCachingInfo { + PagedCachingInfo() = default; + +public: + // --- Data Type --- + infiniDtype_t dtype; + + // --- Shape Dimensions --- + size_t num_tokens; + size_t num_heads; + size_t head_size; + size_t block_size; + + // --- Strides for Memory Layout --- + ptrdiff_t k_src_stride; + ptrdiff_t v_src_stride; + ptrdiff_t k_cache_block_stride; + ptrdiff_t v_cache_block_stride; + + static utils::Result create( + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t slot_mapping_desc) { + + auto dtype = k_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (v_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (slot_mapping_desc->dtype() != INFINI_DTYPE_I32) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || + v_cache_desc->ndim() < 4 || slot_mapping_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + PagedCachingInfo info; + info.dtype = dtype; + + // --- Extract shape dimensions --- + auto k_shape = k_desc->shape(); + auto k_cache_shape = k_cache_desc->shape(); + + info.num_tokens = slot_mapping_desc->shape()[0]; + info.num_heads = k_shape[1]; + info.head_size = k_shape[2]; + info.block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout + + // --- Extract strides for memory access --- + info.k_src_stride = k_desc->stride(0); + info.v_src_stride = v_desc->stride(0); + info.k_cache_block_stride = k_cache_desc->stride(0); + info.v_cache_block_stride = v_cache_desc->stride(0); + + return utils::Result(info); + } +}; + +} // namespace op::paged_caching + +#endif // __PAGED_CACHING_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu b/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu new file mode 100644 index 000000000..fb63557db --- /dev/null +++ b/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu @@ -0,0 +1,143 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "paged_caching_nvidia.cuh" +#include "../cuda/kernel.cuh" + +// We assume some common headers from your library are available. 
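+// Launch mapping used by launchKernel() below: a 2D grid of (num_tokens, num_heads) blocks,
+// one block per (token, head) pair; the block's threads stride over head_size and copy that
+// head's K and V vectors into their destination slot.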
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +template +INFINIOP_CUDA_KERNEL pagedCaching( + Tdata *k_cache, Tdata *v_cache, + const Tdata *k, const Tdata *v, + const int *slot_mapping, + const int num_heads, const int head_size, const int block_size, + const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, + const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride + ) { + op::paged_caching::cuda::pagedCachingKernel( + k_cache, v_cache, k, v, slot_mapping, num_heads, head_size, block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); +} + +namespace op::paged_caching::nvidia { +// PIMPL struct definition +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +// Destructor implementation +Descriptor::~Descriptor() { + delete _opaque; +} + +// Static factory method implementation +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t slot_mapping_desc) { + + // Use the Info struct's factory to parse and validate tensor metadata. + // NOTE: The implementation of PagedCachingInfo::create is omitted for brevity, + // but it would extract shapes, dtypes, and strides from the descriptors. + auto info = PagedCachingInfo::create(k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc); + CHECK_RESULT(info); + + // Create and return the Descriptor instance. + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + + +// The launchKernel function is a templated helper to encapsulate the CUDA kernel launch. +// It sets up grid/block dimensions and calls the device-side kernel. +template +void launchKernel(const PagedCachingInfo& info, + void *k_cache, void *v_cache, + const void *k, const void *v, + const void *slot_mapping, + cudaStream_t stream) { + + // Grid dimension is 1D, with one block per token, as we decided. + dim3 grid(uint32_t(info.num_tokens), uint32_t(info.num_heads), 1); + // Block dimension is 1D, using the number of threads specified at compile time. + dim3 block(NUM_THREADS); + + // This kernel does not require dynamic shared memory. + size_t shared_mem_size = 0; + + // Launch the device-side CUDA kernel. + pagedCaching + <<>>( + (Tdata*)k_cache, + (Tdata*)v_cache, + (const Tdata*)k, + (const Tdata*)v, + (const int*)slot_mapping, + info.num_heads, + info.head_size, + info.block_size, + info.k_src_stride, + info.v_src_stride, + info.k_cache_block_stride, + info.v_cache_block_stride + ); +} + + +// Execution method implementation +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + const void *k, const void *v, + void *k_cache, void *v_cache, + const void *slot_mapping, + void *stream_) const { + + cudaStream_t stream = (cudaStream_t)stream_; + + // Dispatch logic based on the GPU's maximum threads per block. + // This allows selecting the largest, most efficient block size the hardware supports. + if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) { + // Dispatch based on data type for a 1024-thread block. 
+ if (_info.dtype == INFINI_DTYPE_F16) { + launchKernel( + _info, k_cache, v_cache, k, v, slot_mapping, stream); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + launchKernel<__nv_bfloat16, CUDA_BLOCK_SIZE_1024>( + _info, k_cache, v_cache, k, v, slot_mapping, stream); + } else if (_info.dtype == INFINI_DTYPE_F32) { + launchKernel( + _info, k_cache, v_cache, k, v, slot_mapping, stream); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) { + // Dispatch based on data type for a 512-thread block. + if (_info.dtype == INFINI_DTYPE_F16) { + launchKernel( + _info, k_cache, v_cache, k, v, slot_mapping, stream); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + launchKernel<__nv_bfloat16, CUDA_BLOCK_SIZE_512>( + _info, k_cache, v_cache, k, v, slot_mapping, stream); + } else if (_info.dtype == INFINI_DTYPE_F32) { + launchKernel( + _info, k_cache, v_cache, k, v, slot_mapping, stream); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + // If the GPU is older and supports fewer threads, return an error. + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + // Check for any asynchronous errors launched by the kernel. + // return CHECK_CUDA(cudaGetLastError()); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::paged_caching::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh b/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh new file mode 100644 index 000000000..a6af6ea9c --- /dev/null +++ b/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __PAGED_CACHING_NVIDIA_H__ +#define __PAGED_CACHING_NVIDIA_H__ + +#include "../paged_caching.h" + +DESCRIPTOR(nvidia) + +#endif // __PAGED_CACHING_NVIDIA_H__ \ No newline at end of file diff --git a/src/infiniop/ops/paged_caching/operator.cc b/src/infiniop/ops/paged_caching/operator.cc new file mode 100644 index 000000000..0d1c04626 --- /dev/null +++ b/src/infiniop/ops/paged_caching/operator.cc @@ -0,0 +1,93 @@ +// File: infiniop/ops/paged_caching/operator.cc + +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/paged_caching.h" // Assuming this is the public API header + +// Add necessary includes for different platforms +#ifdef ENABLE_NVIDIA_API +#include "nvidia/paged_caching_nvidia.cuh" +#endif +// ... other platforms + +__C infiniStatus_t infiniopCreatePagedCachingDescriptor( + infiniopHandle_t handle, + infiniopPagedCachingDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t slot_mapping_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::paged_caching::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + k_desc, v_desc, k_cache_desc, v_cache_desc, slot_mapping_desc); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + // ... 
other platforms + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopGetPagedCachingWorkspaceSize( + infiniopPagedCachingDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + // ... other platforms + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopPagedCaching( + infiniopPagedCachingDescriptor_t desc, + void *workspace, size_t workspace_size, + const void *k, const void *v, + void *k_cache, void *v_cache, + const void *slot_mapping, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, k, v, k_cache, v_cache, slot_mapping, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + // ... other platforms + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopDestroyPagedCachingDescriptor( + infiniopPagedCachingDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif + // ... other platforms + } + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} \ No newline at end of file diff --git a/src/infiniop/ops/paged_caching/paged_caching.h b/src/infiniop/ops/paged_caching/paged_caching.h new file mode 100644 index 000000000..39d827272 --- /dev/null +++ b/src/infiniop/ops/paged_caching/paged_caching.h @@ -0,0 +1,52 @@ +// File: infiniop/ops/paged_caching/paged_caching.h + +#ifndef PAGED_CACHING_H +#define PAGED_CACHING_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::paged_caching::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + PagedCachingInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + PagedCachingInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t k_cache_desc, \ + infiniopTensorDescriptor_t v_cache_desc, \ + infiniopTensorDescriptor_t slot_mapping_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + const void *k, const void *v, \ + void *k_cache, void *v_cache, \ + const void *slot_mapping, \ + void *stream) const; \ + }; \ + } + +#endif // PAGED_CACHING_H \ No newline at end of file diff --git a/test/infiniop-test/test_generate/testcases/paged_attention.py b/test/infiniop-test/test_generate/testcases/paged_attention.py new file mode 100644 index 000000000..0429c9cb4 --- /dev/null +++ b/test/infiniop-test/test_generate/testcases/paged_attention.py @@ -0,0 +1,110 @@ +import numpy as np +import gguf +from typing import List +from enum import Enum, auto + +# 
Assuming these helpers are in a shared utility file +from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides + +# ============================================================================== +# NumPy Reference Implementation +# ============================================================================== +def ref_paged_attention_np(q, k_cache, v_cache, block_tables, seq_lens, scale, alibi_slopes): + # This is a simplified NumPy implementation for correctness checking. + # It mirrors the logic of the PyTorch reference. + output = np.empty_like(q, dtype=np.float64) + num_seqs, num_heads, head_size = q.shape + num_kv_heads = v_cache.shape[1] + num_queries_per_kv = num_heads // num_kv_heads + block_size = v_cache.shape[3] + + for i in range(num_seqs): + seq_len = seq_lens[i] + q_i = q[i] + + keys_lst = [] + values_lst = [] + for j in range(seq_len): + block_num = block_tables[i, j // block_size] + block_off = j % block_size + k = k_cache[block_num, :, :, block_off, :].reshape(num_kv_heads, head_size) + v = v_cache[block_num, :, :, block_off] + keys_lst.append(k) + values_lst.append(v) + + keys = np.stack(keys_lst, axis=0) + values = np.stack(values_lst, axis=0) + if num_queries_per_kv > 1: + keys = np.repeat(keys, num_queries_per_kv, axis=1) + values = np.repeat(values, num_queries_per_kv, axis=1) + + # einsum in numpy: qhd,khd->hqk + attn_scores = np.einsum('hd,khd->hk', q_i, keys) * scale + + if alibi_slopes is not None: + pos = np.arange(seq_len) + alibi_bias = (pos - seq_len + 1).astype(np.float32) + alibi_bias = alibi_slopes.reshape(-1, 1) * alibi_bias.reshape(1, -1) + attn_scores += alibi_bias + + exp_scores = np.exp(attn_scores - np.max(attn_scores, axis=-1, keepdims=True)) + probs = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True) + + # einsum in numpy: hqk,khd->qhd -> hd + out_i = np.einsum('hk,khd->hd', probs, values) + output[i] = out_i + + return output + +# ============================================================================== +# Test Case Definition and Generation +# ============================================================================== +class PagedAttentionTestCase(InfiniopTestCase): + def __init__(self, **kwargs): + super().__init__("paged_attention") + self.tensors = kwargs + + def write_test(self, test_writer: "InfiniopTestWriter"): + super().write_test(test_writer) + for name, tensor in self.tensors.items(): + test_writer.add_tensor(test_writer.gguf_key(name), tensor, raw_dtype=np_dtype_to_ggml(tensor.dtype)) + + ans = ref_paged_attention_np( + self.tensors["q"].astype(np.float64), + self.tensors["k_cache"].astype(np.float64), + self.tensors["v_cache"].astype(np.float64), + self.tensors["block_tables"], + self.tensors["seq_lens"], + self.tensors["scale"].item(), + self.tensors.get("alibi_slopes", None), + ) + test_writer.add_tensor(test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64) + + +if __name__ == "__main__": + test_writer = InfiniopTestWriter("paged_attention.gguf") + test_cases = [] + + # Test case configurations + _TEST_CASES_ = [(7, 40, 40, 128, 16, 1024), (5, 64, 8, 80, 32, 2048)] + _TENSOR_DTYPES_ = [np.float16, np.float32] + _NUM_BLOCKS = 2048 + + for dtype in _TENSOR_DTYPES_: + for num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len in _TEST_CASES_: + scale = 1.0 / (head_size**0.5) + x = 16 // dtype().itemsize + + tensors = { + "q": np.random.randn(num_seqs, num_heads, head_size).astype(dtype), + "k_cache": np.random.randn(_NUM_BLOCKS, 
num_kv_heads, head_size // x, block_size, x).astype(dtype), + "v_cache": np.random.randn(_NUM_BLOCKS, num_kv_heads, head_size, block_size).astype(dtype), + "seq_lens": np.random.randint(1, max_seq_len, num_seqs, dtype=np.int32), + "block_tables": np.random.randint(0, _NUM_BLOCKS, (num_seqs, (max_seq_len + block_size - 1) // block_size), dtype=np.int32), + "scale": np.array(scale, dtype=np.float32), + "out": np.empty((num_seqs, num_heads, head_size), dtype=dtype) + } + test_cases.append(PagedAttentionTestCase(**tensors)) + + test_writer.add_tests(test_cases) + test_writer.save() \ No newline at end of file diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index aae69c153..99cde5f16 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -237,6 +237,90 @@ def mul_(lib): ] +@OpRegister.operator +def paged_attention_(lib): + lib.infiniopCreatePagedAttentionDescriptor.restype = c_int32 + lib.infiniopCreatePagedAttentionDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_void_p, + c_float, + ] + + lib.infiniopGetPagedAttentionWorkspaceSize.restype = c_int32 + lib.infiniopGetPagedAttentionWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopPagedAttention.restype = c_int32 + lib.infiniopPagedAttention.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyPagedAttentionDescriptor.restype = c_int32 + lib.infiniopDestroyPagedAttentionDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + +# File: python/infinicore/op_register.py (or similar) + +@OpRegister.operator +def paged_caching_(lib): + lib.infiniopCreatePagedCachingDescriptor.restype = c_int32 + lib.infiniopCreatePagedCachingDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, # k_desc + infiniopTensorDescriptor_t, # v_desc + infiniopTensorDescriptor_t, # k_cache_desc + infiniopTensorDescriptor_t, # v_cache_desc + infiniopTensorDescriptor_t, # slot_mapping_desc + ] + + # infiniopGetPagedCachingWorkspaceSize + lib.infiniopGetPagedCachingWorkspaceSize.restype = c_int32 + lib.infiniopGetPagedCachingWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + # infiniopPagedCaching + lib.infiniopPagedCaching.restype = c_int32 + lib.infiniopPagedCaching.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, # workspace + c_size_t, # workspace_size + c_void_p, # k + c_void_p, # v + c_void_p, # k_cache + c_void_p, # v_cache + c_void_p, # slot_mapping + c_void_p, # stream + ] + + # infiniopDestroyPagedCachingDescriptor + lib.infiniopDestroyPagedCachingDescriptor.restype = c_int32 + lib.infiniopDestroyPagedCachingDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + @OpRegister.operator def random_sample_(lib): lib.infiniopCreateRandomSampleDescriptor.restype = c_int32 diff --git a/test/infiniop/paged_attention.py b/test/infiniop/paged_attention.py new file mode 100644 index 000000000..212d625f2 --- /dev/null +++ b/test/infiniop/paged_attention.py @@ -0,0 +1,308 @@ +import torch +import ctypes +import random +from ctypes import c_uint32, c_float, c_uint64, 
c_size_t, POINTER, addressof +import math +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + +# ============================================================================== +# Reference Implementation +# ============================================================================== +def get_alibi_slopes(n): + # 简化版的ALiBi斜率计算方法 + # 参考: https://github.com/ofirpress/attention_with_linear_biases/blob/master/fairseq/models/transformer.py#L742 + closest_power_of_2 = 2**math.floor(math.log2(n)) + base = 2**(-2**-(math.log2(closest_power_of_2) - 3)) + powers = [base**i for i in range(1, closest_power_of_2 + 1)] + if n > closest_power_of_2: + extra = [base**(i * 2) for i in range(1, 2 * (n - closest_power_of_2) + 1, 2)] + powers += extra + return powers[:n] + +def ref_masked_attention(query, key, value, scale, attn_mask=None): + # Reference implementation for a single masked attention head. + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + +def ref_single_query_cached_kv_attention(query, key_cache, value_cache, block_tables, seq_lens, scale, alibi_slopes): + # Reference implementation for paged attention, iterating through each sequence. + output = torch.empty_like(query) + num_query_heads, num_kv_heads = query.shape[1], value_cache.shape[1] + num_queries_per_kv = num_query_heads // num_kv_heads + head_size, block_size = value_cache.shape[3], value_cache.shape[2] + num_seqs = query.shape[0] + + for i in range(num_seqs): + q = query[i].unsqueeze(0) + seq_len = seq_lens[i].item() + block_table = block_tables[i] + + keys_lst, values_lst = [], [] + for j in range(seq_len): + block_num = block_table[j // block_size].item() + block_off = j % block_size + # k = key_cache[block_num, :, :, block_off, :].reshape(num_kv_heads, head_size) + k = key_cache[block_num, :, block_off, :] + v = value_cache[block_num, :, block_off, :] + keys_lst.append(k) + values_lst.append(v) + + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_queries_per_kv > 1: + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + # alibi_bias = None + # if alibi_slopes is not None: + # pos = torch.arange(seq_len, device=query.device).int() + # alibi_bias = (pos - seq_len + 1).float() + # alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) + alibi_bias = None + if alibi_slopes is not None: + pos = torch.arange(seq_len, device=query.device).int() + alibi_bias = (pos - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + output[i] = out.view(num_query_heads, head_size) + return output + +# ============================================================================== +# Test Configuration +# ============================================================================== +# (num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, use_alibi) +_TEST_CASES_ = [ + # (7, 40, 40, 128, 16, 1024, True), + # (1, 1, 1, 128, 
16, 1024, False), + (1, 40, 40, 128, 16, 1024, False), + # (5, 8, 8, 128, 16, 1024, True), + # (5, 64, 8, 80, 16, 2048, True), + # (5, 64, 8, 80, 16, 2048, False), +] + +# Data types for testing +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, +} + +# Global flags for controlling test behavior +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + +def test( + handle, + device, + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_seq_len, + use_alibi, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing PagedAttention on {InfiniDeviceNames[device]} with " + f"num_seqs={num_seqs}, num_heads={num_heads}, head_size={head_size}, " + f"block_size={block_size}, dtype={InfiniDtypeNames[dtype]}, use_alibi={use_alibi}" + ) + + scale = 1.0 / (head_size**0.5) + # num_blocks = 2048 # A reasonable number for testing + max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + num_blocks = num_seqs*max_blocks_per_seq # A reasonable number for testing + + # Create input tensors + q = TestTensor((num_seqs, num_heads, head_size), None, dtype, device) + out = TestTensor((num_seqs, num_heads, head_size), None, dtype, device) + k_cache = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device) + v_cache = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device) + + + seq_lens_direct = 724 + seq_lens_torch = torch.randint(seq_lens_direct, seq_lens_direct+1, (num_seqs,), dtype=torch.int32) + # seq_lens_torch = torch.randint(max_seq_len-1, max_seq_len, (num_seqs,), dtype=torch.int32) + print(f"seq_lens_torch: {seq_lens_torch}") + seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I32, device) + + + # seq_lens = [random.randint(1, max_seq_len - 1) for _ in range(num_seqs)] + # seq_lens_ptr = (c_size_t * len(seq_lens))(*seq_lens) + # seq_lens_length = len(seq_lens) + # print(f"The length of seq_lens_ is: {seq_lens_length}") + # cpu_address = addressof(seq_lens_ptr) + + # print(f"--- On Python Side (CPU) ---") + # print(f"ctypes CPU array address (integer): {cpu_address}") + # print(f"ctypes CPU array address (hex): {hex(cpu_address)}") + + # block_tables_py = torch.randint(0, num_blocks, (num_seqs, max_blocks_per_seq), dtype=torch.int32) + block_tables_py = torch.arange(0, num_seqs*max_blocks_per_seq, dtype=torch.int32).view(num_seqs, max_blocks_per_seq) + block_tables = TestTensor.from_torch(block_tables_py, InfiniDtype.I32, device) + # block_tables = [[random.randint(0, num_blocks - 1) for _ in range(max_blocks_per_seq)] for _ in range(num_seqs)] + # flat_block_tables = [item for sublist in block_tables for item in sublist] + # block_tables_ptr = (c_size_t * len(flat_block_tables))(*flat_block_tables) + + + alibi_slopes_desc = ctypes.c_void_p(0) + alibi_slopes_data = ctypes.c_void_p(0) + alibi_slopes_torch = None + if use_alibi: + alibi_slopes = TestTensor((num_heads,), None, InfiniDtype.F32, device) + alibi_slopes_desc = alibi_slopes.descriptor + alibi_slopes_data = alibi_slopes.data() + alibi_slopes_torch = alibi_slopes.torch_tensor() + # alibi_slopes_list = [] + # alibi_slopes_ptr = POINTER(c_float)() + # if use_alibi: + # alibi_slopes_list = get_alibi_slopes(num_heads) + # alibi_slopes_ptr = (c_float * len(alibi_slopes_list))(*alibi_slopes_list) + + # Run 
reference implementation + ans = ref_single_query_cached_kv_attention( + q.torch_tensor(), k_cache.torch_tensor(), v_cache.torch_tensor(), + block_tables.torch_tensor(), seq_lens.torch_tensor(), + scale, alibi_slopes_torch) + + if sync: + sync() + + scale = 1.0 / (head_size**0.5) + # Create operator descriptor + descriptor = infiniopOperatorDescriptor_t() + check_error(LIBINFINIOP.infiniopCreatePagedAttentionDescriptor( + handle, ctypes.byref(descriptor), + out.descriptor, q.descriptor, k_cache.descriptor, v_cache.descriptor, + block_tables.descriptor, seq_lens.descriptor, alibi_slopes_desc, + scale, max_blocks_per_seq + )) + + # block_tables_ptr, seq_lens_ptr, alibi_slopes_ptr, c_float(scale) + + # Get workspace size and allocate memory + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetPagedAttentionWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, q.device) + + + # Invalidate descriptors to ensure kernel does not rely on them + q.destroy_desc() + out.destroy_desc() + k_cache.destroy_desc() + v_cache.destroy_desc() + # block_tables.destroy_desc() + # seq_lens.destroy_desc() + # if use_alibi: + # alibi_slopes.destroy_desc() + + # Define the library call as a lambda for profiling + def lib_paged_attention(): + check_error(LIBINFINIOP.infiniopPagedAttention( + descriptor, workspace.data(), workspace_size.value, + out.data(), q.data(), k_cache.data(), v_cache.data(), + block_tables.data(), seq_lens.data(), alibi_slopes_data, None + )) + + # Execute the custom operator + lib_paged_attention() + + if sync: + sync() + + + + # Verify correctness + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: ref_single_query_cached_kv_attention( + q.torch_tensor(), k_cache.torch_tensor(), v_cache.torch_tensor(), + block_tables.torch_tensor(), seq_lens.torch_tensor(), + scale, alibi_slopes_torch), + device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lib_paged_attention, device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + # Clean up resources + check_error(LIBINFINIOP.infiniopDestroyPagedAttentionDescriptor(descriptor)) + + +# if __name__ == "__main__": +# args = get_args() + +# # Configure testing options from command line arguments +# DEBUG = args.debug +# PROFILE = args.profile +# NUM_PRERUN = args.num_prerun +# NUM_ITERATIONS = args.num_iterations + +# # for device in get_test_devices(args): +# # test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) +# # test_operator(device, test_wrapper, _TEST_CASES_, _TENSOR_DTYPES) +# # # first stage +# num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len = 7, 40, 40, 128, 16, 1024 +# for device in get_test_devices(args): +# test(None, device, num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, dtype=InfiniDtype.F16, use_alibi=False, sync=None) +# test(None, device, num_seqs, num_heads, num_kv_heads, head_size, block_size, max_seq_len, dtype=InfiniDtype.F16, use_alibi=True, sync=None) + +# print("\033[92mTest passed!\033[0m") + +# if __name__ == "__main__": +# args = get_args() +# for device in get_test_devices(args): +# for use_alibi_flag in [True, False]: +# # Create a new closure for test_operator to capture `use_alibi` +# def test_wrapper(handle, device, *test_args, dtype, sync): +# test(*((handle, 
device) + test_args), dtype=dtype, use_alibi=use_alibi_flag, sync=sync) +# print("\033[92mTest passed!\033[0m") + +if __name__ == "__main__": + args = get_args() + + # Configure testing options from command line arguments + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/test/infiniop/paged_caching.py b/test/infiniop/paged_caching.py new file mode 100644 index 000000000..aa22ba7cd --- /dev/null +++ b/test/infiniop/paged_caching.py @@ -0,0 +1,243 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + TestWorkspace, +) + +# ============================================================================== +# Reference Implementation +# ============================================================================== +def ref_paged_caching(key, value, key_cache_pool, value_cache_pool, slot_mapping): + """ + Reference implementation for paged_caching operator. + + Args: + key (torch.Tensor): Keys, shape [ntok, nkvh, dh] + value (torch.Tensor): Values, shape [ntok, nkvh, dh] + key_cache_pool (torch.Tensor): K cache pool, shape [num_blocks, nkvh, block_size, dh] + value_cache_pool (torch.Tensor): V cache pool, shape [num_blocks, nkvh, block_size, dh] + slot_mapping (torch.Tensor): Slot mapping, shape [ntok] + """ + ntok = key.shape[0] + block_size = key_cache_pool.shape[2] + + # This reference implementation operates on a cloned cache to avoid modifying the original input tensor, + # mimicking the behavior where the custom operator writes to its output tensor. 
+ k_cache_ref = key_cache_pool.clone() + v_cache_ref = value_cache_pool.clone() + + for i in range(ntok): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + + key_token = key[i] + value_token = value[i] + + k_cache_ref[block_idx, :, block_offset, :] = key_token + v_cache_ref[block_idx, :, block_offset, :] = value_token + + return k_cache_ref, v_cache_ref + +# ============================================================================== +# Test Configuration +# ============================================================================== +# (num_seqs, max_seq_len, num_kv_heads, head_size, block_size) +# _TEST_CASES_ = [ +# (1, 128, 8, 128, 16), +# (5, 512, 40, 128, 16), +# (16, 1024, 8, 64, 32), +# (3, 20, 1, 80, 8), # Test with small values +# ] +_TEST_CASES_ = [ + (1, 128, 8, 128, 16), + (5, 512, 40, 128, 16), + (16, 1024, 8, 64, 32), + (10, 1024, 40, 64, 32), + # (1, 1, 4, 1, 2), # Test with small values +] + +# Data types for testing +_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 5e-3, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-5}, +} + +# Global flags for controlling test behavior +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 100 + +def test( + handle, + device, + num_seqs, # nreq + max_seq_len, + num_kv_heads, # nkvh + head_size, # dh + block_size, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing PagedCaching on {InfiniDeviceNames[device]} with " + f"num_seqs={num_seqs}, max_seq_len={max_seq_len}, num_kv_heads={num_kv_heads}, " + f"head_size={head_size}, block_size={block_size}, dtype={InfiniDtypeNames[dtype]}" + ) + + num_blocks = 4096 # A reasonably large cache pool for testing + + # Create metadata: variable context lengths for each sequence in the batch + context_lens_torch = torch.randint(1, max_seq_len + 1, (num_seqs,), dtype=torch.int32) + ntok = torch.sum(context_lens_torch).item() + + # If ntok is 0 (all sequences have length 0), skip the test + if ntok == 0: + print("Skipping test case with ntok=0") + return + + # Simulate the scheduler's behavior to create the slot_mapping + slot_mapping_list = [] + # We need to ensure allocated blocks do not overlap. Let's simulate a simple block allocator. 
+ allocated_slots = set() + current_slot = 0 + for length in context_lens_torch: + # Find a contiguous chunk of 'length' slots + start_slot = current_slot + slot_mapping_list.extend(range(start_slot, start_slot + length.item())) + current_slot += length.item() + + # Ensure we don't exceed the total number of slots in the cache + assert current_slot <= num_blocks * block_size, "Not enough blocks in the cache pool for this test case" + + slot_mapping_torch = torch.tensor(slot_mapping_list, dtype=torch.int32) + + # Create input tensors based on the calculated total tokens (ntok) + k = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device) + v = TestTensor((ntok, num_kv_heads, head_size), None, dtype, device) + slot_mapping = TestTensor.from_torch(slot_mapping_torch, InfiniDtype.I32, device) + + # The cache pools are the "output" tensors for this operator + k_cache_pool = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device) + v_cache_pool = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device) + + # # Create input tensors based on the calculated total tokens (ntok) + # k_torch = torch.ones((ntok, num_kv_heads, head_size), dtype=torch.float16) + # v_torch = torch.ones((ntok, num_kv_heads, head_size), dtype=torch.float16) + # k = TestTensor.from_torch(k_torch, dtype, device) + # v = TestTensor.from_torch(v_torch, dtype, device) + # slot_mapping = TestTensor.from_torch(slot_mapping_torch, InfiniDtype.I32, device) + + # # The cache pools are the "output" tensors for this operator + # k_cache_pool_torch = torch.zeros((num_blocks, num_kv_heads, block_size, head_size), dtype=torch.float16) + # v_cache_pool_torch = torch.zeros((num_blocks, num_kv_heads, block_size, head_size), dtype=torch.float16) + # k_cache_pool = TestTensor.from_torch(k_cache_pool_torch, dtype, device) + # v_cache_pool = TestTensor.from_torch(v_cache_pool_torch, dtype, device) + + # Run reference implementation + k_cache_ref, v_cache_ref = ref_paged_caching( + k.torch_tensor(), v.torch_tensor(), + k_cache_pool.torch_tensor(), v_cache_pool.torch_tensor(), + slot_mapping.torch_tensor() + ) + + if sync: + sync() + + # Create operator descriptor + descriptor = infiniopOperatorDescriptor_t() + check_error(LIBINFINIOP.infiniopCreatePagedCachingDescriptor( + handle, ctypes.byref(descriptor), + k.descriptor, v.descriptor, + k_cache_pool.descriptor, v_cache_pool.descriptor, + slot_mapping.descriptor + )) + + # Get workspace size (likely 0 for this operator, but good practice to include) + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetPagedCachingWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + # Invalidate descriptors to ensure kernel does not rely on them + k.destroy_desc() + v.destroy_desc() + k_cache_pool.destroy_desc() + v_cache_pool.destroy_desc() + slot_mapping.destroy_desc() + + # Define the library call as a lambda for profiling + def lib_paged_caching(): + check_error(LIBINFINIOP.infiniopPagedCaching( + descriptor, workspace.data(), workspace_size.value, + k.data(), v.data(), + k_cache_pool.data(), v_cache_pool.data(), + slot_mapping.data(), None + )) + + # Execute the custom operator + lib_paged_caching() + + if sync: + sync() + + # Verify correctness + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + print("Verifying K cache...") + debug(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol) + print("Verifying V cache...") + 
debug(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol) + + assert torch.allclose(k_cache_pool.actual_tensor(), k_cache_ref, atol=atol, rtol=rtol) + assert torch.allclose(v_cache_pool.actual_tensor(), v_cache_ref, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: ref_paged_caching( + k.torch_tensor(), v.torch_tensor(), + k_cache_pool.torch_tensor(), v_cache_pool.torch_tensor(), + slot_mapping.torch_tensor()), + device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lib_paged_caching, device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + + # Clean up resources + check_error(LIBINFINIOP.infiniopDestroyPagedCachingDescriptor(descriptor)) + +if __name__ == "__main__": + args = get_args() + + # Configure testing options from command line arguments + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") \ No newline at end of file diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 797edcb5e..6f1a1f067 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -8,6 +8,12 @@ target("infiniop-nvidia") add_deps("infini-utils") on_install(function (target) end) + -- if is_mode("debug") then + -- print("Enabling CUDA debug flags (-g -G)") + -- add_cuflags("-g", "-G") + -- set_optimize("none") + -- end + set_policy("build.cuda.devlink", true) set_toolchains("cuda") add_links("cudart", "cublas") From 19ef94d775599c8d5423e22fe5fda1a0e0fe5b0d Mon Sep 17 00:00:00 2001 From: suss <1152623206@qq.com> Date: Tue, 26 Aug 2025 10:40:13 +0800 Subject: [PATCH 2/4] check paged attn --- .../ops/paged_attention/cuda/kernel.cuh | 43 +------------------ .../nvidia/paged_attention_nvidia.cu | 7 ++- test/infiniop/paged_attention.py | 18 ++++---- 3 files changed, 14 insertions(+), 54 deletions(-) diff --git a/src/infiniop/ops/paged_attention/cuda/kernel.cuh b/src/infiniop/ops/paged_attention/cuda/kernel.cuh index 1f50e68ea..837261088 100644 --- a/src/infiniop/ops/paged_attention/cuda/kernel.cuh +++ b/src/infiniop/ops/paged_attention/cuda/kernel.cuh @@ -55,14 +55,10 @@ __device__ void pagedAttentionKernel( //================================================================================ // 1. Setup & Query Loading (No changes in this section) //================================================================================ - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("pagedAttentionKernel;\n"); - } - __syncthreads(); const int seq_idx = blockIdx.y; const int head_idx = blockIdx.x; const int num_heads = gridDim.x; - const int batch_size = gridDim.y; + // const int batch_size = gridDim.y; const int32_t seq_len = seq_lens_[seq_idx]; if (seq_len == 0) return; @@ -85,33 +81,14 @@ __device__ void pagedAttentionKernel( q_shared[i] = static_cast(q_ptr[i]); } __syncthreads(); - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("q_shared over;\n"); - } - __syncthreads(); - - const int32_t physical_block_num1 = block_table[62]; - - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("k_shared start;%d\n", physical_block_num1); - } - __syncthreads(); - //================================================================================ // 2. 
Compute QK Dot Product & Find Max Logit //================================================================================ for (size_t token_idx = threadIdx.x; token_idx < seq_len; token_idx += NUM_THREADS) { const int32_t block_idx = token_idx / block_size; const int32_t token_in_block_idx = token_idx % block_size; - if (token_idx == seq_len-1 || token_idx == 0) { - printf("block_idx: %d\n", block_idx); - } - // const int32_t physical_block_num = 0 const int32_t physical_block_num = block_table[block_idx]; - if (token_idx == seq_len-1 || token_idx == 0) { - printf("physical_block_num: %d\n", physical_block_num); - } - + const Tdata* k_vec_ptr = k_cache_ + physical_block_num * kv_block_stride + kv_head_idx * kv_head_stride + token_in_block_idx * HEAD_SIZE; //================================================================================ @@ -144,18 +121,10 @@ __device__ void pagedAttentionKernel( logits[token_idx] = qk; } - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("compute global_qk_max start;\n"); - } - __syncthreads(); __shared__ Tcompute global_qk_max; Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max(logits, seq_len); - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("Cmmon_cuda::reduce_op over;\n"); - } - __syncthreads(); if (threadIdx.x == 0) { global_qk_max = global_qk_max_0; } @@ -183,10 +152,6 @@ __device__ void pagedAttentionKernel( logits[i] *= inv_sum; } __syncthreads(); - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("gregate Values (V) weighted by probabil start;\n"); - } - __syncthreads(); //================================================================================ // 4. Aggregate Values (V) weighted by probabilities @@ -211,10 +176,6 @@ __device__ void pagedAttentionKernel( } out_ptr[h_dim] = static_cast(acc); } - if (blockIdx.x == 0 && threadIdx.x == 0) { - printf("gregate Values (V) weighted by probabil over;\n"); - } - __syncthreads(); } } // namespace op::paged_attention::cuda diff --git a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu index 9a9679f71..02ec499fa 100644 --- a/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu +++ b/src/infiniop/ops/paged_attention/nvidia/paged_attention_nvidia.cu @@ -61,13 +61,12 @@ infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size, ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, cudaStream_t stream) { - printf("launchKernel; num_heads: %zu, num_seqs: %zu, num_kv_heads: %zu, scale: %f, max_num_blocks_per_seq: %zu, block_size: %zu, q_stride: %zu, kv_block_stride: %zu, kv_head_stride: %zu\n", num_heads, num_seqs, num_kv_heads, scale, max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride); dim3 grid(uint32_t(num_heads), uint32_t(num_seqs), 1); dim3 block(NUM_THREADS); + size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); // size_t shared_mem_size = 16; if (dtype == INFINI_DTYPE_F16) { - size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(half); // size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(Tcompute); pagedAttention <<>>( @@ -78,7 +77,7 @@ infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const q_stride, kv_block_stride, kv_head_stride ); } else if (dtype == INFINI_DTYPE_BF16) { - size_t 
shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(__nv_bfloat16); + // size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); pagedAttention<__nv_bfloat16, float, HEAD_SIZE, NUM_THREADS> <<>>( (__nv_bfloat16*)out, (const __nv_bfloat16*)q, (const __nv_bfloat16*)k_cache, (const __nv_bfloat16*)v_cache, @@ -87,7 +86,7 @@ infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const q_stride, kv_block_stride, kv_head_stride ); } else if (dtype == INFINI_DTYPE_F32) { - size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); + // size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float); pagedAttention <<>>( (float*)out, (const float*)q, (const float*)k_cache, (const float*)v_cache, diff --git a/test/infiniop/paged_attention.py b/test/infiniop/paged_attention.py index 212d625f2..e5f2033ee 100644 --- a/test/infiniop/paged_attention.py +++ b/test/infiniop/paged_attention.py @@ -94,10 +94,10 @@ def ref_single_query_cached_kv_attention(query, key_cache, value_cache, block_ta _TEST_CASES_ = [ # (7, 40, 40, 128, 16, 1024, True), # (1, 1, 1, 128, 16, 1024, False), - (1, 40, 40, 128, 16, 1024, False), + (5, 40, 40, 128, 16, 1024, False), # (5, 8, 8, 128, 16, 1024, True), # (5, 64, 8, 80, 16, 2048, True), - # (5, 64, 8, 80, 16, 2048, False), + (5, 64, 8, 128, 16, 2048, False), ] # Data types for testing @@ -146,11 +146,10 @@ def test( k_cache = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device) v_cache = TestTensor((num_blocks, num_kv_heads, block_size, head_size), None, dtype, device) - - seq_lens_direct = 724 + seq_lens_direct = 1023 + # seq_lens_direct = 725 seq_lens_torch = torch.randint(seq_lens_direct, seq_lens_direct+1, (num_seqs,), dtype=torch.int32) # seq_lens_torch = torch.randint(max_seq_len-1, max_seq_len, (num_seqs,), dtype=torch.int32) - print(f"seq_lens_torch: {seq_lens_torch}") seq_lens = TestTensor.from_torch(seq_lens_torch, InfiniDtype.I32, device) @@ -222,10 +221,10 @@ def test( out.destroy_desc() k_cache.destroy_desc() v_cache.destroy_desc() - # block_tables.destroy_desc() - # seq_lens.destroy_desc() - # if use_alibi: - # alibi_slopes.destroy_desc() + block_tables.destroy_desc() + seq_lens.destroy_desc() + if use_alibi: + alibi_slopes.destroy_desc() # Define the library call as a lambda for profiling def lib_paged_attention(): @@ -248,6 +247,7 @@ def lib_paged_attention(): if DEBUG: debug(out.actual_tensor(), ans, atol=atol, rtol=rtol) assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol) + # print(f"out.actual_tensor() : {out.actual_tensor()}, ans: {ans}") # Profiling workflow if PROFILE: From 9e035d142fb72ceada93dd7d2c11c6bdf0beaf73 Mon Sep 17 00:00:00 2001 From: suss <1152623206@qq.com> Date: Mon, 8 Sep 2025 11:40:23 +0800 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=AE=8C?= =?UTF-8?q?=E6=95=B4=E7=89=88Paged=20Attention=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ops/paged_attention/cuda/kernel.cuh | 1 + .../ops/paged_caching/cuda/kernel.cuh | 78 ++--------- src/infiniop/ops/paged_caching/info.h | 34 +++-- .../nvidia/paged_caching_nvidia.cu | 125 +++++++++++------- test/infiniop/paged_attention.py | 2 +- 5 files changed, 113 insertions(+), 127 deletions(-) diff --git a/src/infiniop/ops/paged_attention/cuda/kernel.cuh b/src/infiniop/ops/paged_attention/cuda/kernel.cuh index 
837261088..3bfd79f3f 100644 --- a/src/infiniop/ops/paged_attention/cuda/kernel.cuh +++ b/src/infiniop/ops/paged_attention/cuda/kernel.cuh @@ -121,6 +121,7 @@ __device__ void pagedAttentionKernel( logits[token_idx] = qk; } + __syncthreads(); __shared__ Tcompute global_qk_max; Tcompute global_qk_max_0 = op::common_cuda::reduce_op::max(logits, seq_len); diff --git a/src/infiniop/ops/paged_caching/cuda/kernel.cuh b/src/infiniop/ops/paged_caching/cuda/kernel.cuh index 8418aaf03..e933be1ea 100644 --- a/src/infiniop/ops/paged_caching/cuda/kernel.cuh +++ b/src/infiniop/ops/paged_caching/cuda/kernel.cuh @@ -27,16 +27,15 @@ template < > __device__ void pagedCachingKernel( // ----- Output Tensors ----- - Tdata* k_cache_ptr, // Pointer to the destination K cache pool - Tdata* v_cache_ptr, // Pointer to the destination V cache pool + Tdata* k_cache_ptr, // Pointer to the destination K cache pool [num_blocks, nkvh, block_size, dh] + Tdata* v_cache_ptr, // Pointer to the destination V cache pool [num_blocks, nkvh, block_size, dh] // ----- Input Tensors ----- const Tdata* k_ptr, // Pointer to the source Keys, shape [ntok, nkvh, dh] const Tdata* v_ptr, // Pointer to the source Values, shape [ntok, nkvh, dh] - const int* slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok] + const int32_t* slot_mapping_ptr, // Pointer to the slot mapping, shape [ntok] // ----- Metadata ----- - const int num_heads, // Number of key/value heads (nkvh) - const int head_size, // Dimension of each head (dh) - const int block_size, // Number of tokens per block in the KV cache + const size_t head_size, // Dimension of each head (dh) + const size_t block_size, // Number of tokens per block in the KV cache // ----- Stride Information ----- const ptrdiff_t k_src_stride, // Stride between tokens in the source K tensor const ptrdiff_t v_src_stride, // Stride between tokens in the source V tensor @@ -48,27 +47,20 @@ __device__ void pagedCachingKernel( //================================================================================ // Each block processes one token. - const int token_idx = blockIdx.x; - const int head_idx = blockIdx.y; + const int token_idx = blockIdx.y; + const int head_idx = blockIdx.x; + // const int num_kv_heads = gridDim.y; // Retrieve the destination slot for the current token. - const int slot_idx = slot_mapping_ptr[token_idx]; + const int32_t slot_idx = slot_mapping_ptr[token_idx]; // Handle padding: if slot_idx is negative, this token is padding and should be ignored. if (slot_idx < 0) { return; } - - // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("[Block %d, Thread %d] Debug Start\n", blockIdx.x, threadIdx.x); - // printf(" - token_idx: %d\n", token_idx); - // printf(" - slot_idx from mapping: %d\n", slot_idx); - // printf(" - Metadata: num_heads=%d, head_size=%d, block_size=%d\n", num_heads, head_size, block_size); - // } - // Calculate the physical block index and the offset within that block. - const int physical_block_idx = slot_idx / block_size; - const int block_offset = slot_idx % block_size; + const int32_t physical_block_idx = slot_idx / block_size; + const int32_t block_offset = slot_idx % block_size; // Calculate base pointers for source and destination for this specific token. 
const Tdata* k_src_head_ptr = k_ptr + token_idx * k_src_stride + head_idx * head_size; @@ -85,54 +77,6 @@ __device__ void pagedCachingKernel( Tdata* v_cache_block_base_ptr = v_cache_ptr + physical_block_idx * v_cache_block_stride; Tdata* v_dst_head_ptr = v_cache_block_base_ptr + head_idx * cache_head_stride + block_offset * head_size; - - // if (blockIdx.x == 0 && threadIdx.x == 0) { - // printf("[Block %d, Thread %d] Address Calculation\n", blockIdx.x, threadIdx.x); - // printf(" - physical_block_idx: %d\n", physical_block_idx); - // printf(" - block_offset: %d\n", block_offset); - // printf(" - k_src_stride: %ld, v_src_stride: %ld\n", k_src_stride, v_src_stride); - // printf(" - k_cache_block_stride: %ld, v_cache_block_stride: %ld\n", k_cache_block_stride, v_cache_block_stride); - // printf(" - Calculated dst_token_offset_k: %d\n", dst_token_offset_k); - // printf(" - Source Ptr (k): %p\n", k_src_token_ptr); - // printf(" - Dest Ptr (k_cache): %p\n", k_cache_dst_ptr); - // } - - - // //================================================================================ - // // 2. Perform Vectorized Data Copy - // //================================================================================ - - // // Total number of elements to copy for one token (all heads). - // const int total_elements_per_token = num_heads * head_size; - - // // Use vectorization to copy data more efficiently. - // // For Tdata=half (2 bytes), float4 (16 bytes) can process 8 elements at once. - // constexpr int VEC_SIZE = sizeof(float4) / sizeof(Tdata); - - // // Cast pointers to the vectorized type. - // const float4* k_src_vec_ptr = reinterpret_cast(k_src_token_ptr); - // const float4* v_src_vec_ptr = reinterpret_cast(v_src_token_ptr); - // float4* k_cache_dst_vec_ptr = reinterpret_cast(k_cache_dst_ptr); - // float4* v_cache_dst_vec_ptr = reinterpret_cast(v_cache_dst_ptr); - - // // if (blockIdx.x == 0 && threadIdx.x == 0) { - // // printf("[Block %d, Thread %d] Vectorized Copy Start\n", blockIdx.x, threadIdx.x); - // // printf(" - Total elements per token: %d\n", total_elements_per_token); - // // printf(" - Vector size: %d\n", VEC_SIZE); - // // printf(" - float4 size: %d\n", sizeof(float4)); - // // printf(" - Tdata size: %d\n", sizeof(Tdata)); - // // // printf(" - Vector size: %d\n", k_src_vec_ptr[i]); - // // } - - // // Each thread copies one vector (VEC_SIZE elements) per iteration. - // // The loop iterates over the vectorized chunks of data for the token. - // for (int i = threadIdx.x; i < total_elements_per_token / VEC_SIZE; i += NUM_THREADS) { - - // k_cache_dst_vec_ptr[i] = k_src_vec_ptr[i]; - // v_cache_dst_vec_ptr[i] = v_src_vec_ptr[i]; - // } - // } - //================================================================================ // 2. 
Perform Element-wise Data Copy (Safe, Non-Vectorized) //================================================================================ diff --git a/src/infiniop/ops/paged_caching/info.h b/src/infiniop/ops/paged_caching/info.h index 23a4da0ac..f17072a8d 100644 --- a/src/infiniop/ops/paged_caching/info.h +++ b/src/infiniop/ops/paged_caching/info.h @@ -19,7 +19,7 @@ class PagedCachingInfo { // --- Shape Dimensions --- size_t num_tokens; - size_t num_heads; + size_t num_kv_heads; size_t head_size; size_t block_size; @@ -42,6 +42,7 @@ class PagedCachingInfo { return INFINI_STATUS_BAD_TENSOR_DTYPE; } if (slot_mapping_desc->dtype() != INFINI_DTYPE_I32) { + printf("slot_mapping must be int32_t.\n"); return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -50,25 +51,34 @@ class PagedCachingInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - PagedCachingInfo info; - info.dtype = dtype; + // PagedCachingInfo info; // --- Extract shape dimensions --- auto k_shape = k_desc->shape(); auto k_cache_shape = k_cache_desc->shape(); - info.num_tokens = slot_mapping_desc->shape()[0]; - info.num_heads = k_shape[1]; - info.head_size = k_shape[2]; - info.block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout + size_t num_tokens = slot_mapping_desc->shape()[0]; + size_t num_kv_heads = k_shape[1]; + size_t head_size = k_shape[2]; + size_t block_size = k_cache_shape[2]; // Assuming [num_blocks, num_heads, block_size, head_size] layout // --- Extract strides for memory access --- - info.k_src_stride = k_desc->stride(0); - info.v_src_stride = v_desc->stride(0); - info.k_cache_block_stride = k_cache_desc->stride(0); - info.v_cache_block_stride = v_cache_desc->stride(0); + ptrdiff_t k_src_stride = k_desc->stride(0); + ptrdiff_t v_src_stride = v_desc->stride(0); + ptrdiff_t k_cache_block_stride = k_cache_desc->stride(0); + ptrdiff_t v_cache_block_stride = v_cache_desc->stride(0); - return utils::Result(info); + return utils::Result(PagedCachingInfo{ + dtype, + num_tokens, + num_kv_heads, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride + }); } }; diff --git a/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu b/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu index fb63557db..1e203c843 100644 --- a/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu +++ b/src/infiniop/ops/paged_caching/nvidia/paged_caching_nvidia.cu @@ -9,13 +9,14 @@ template INFINIOP_CUDA_KERNEL pagedCaching( Tdata *k_cache, Tdata *v_cache, const Tdata *k, const Tdata *v, - const int *slot_mapping, - const int num_heads, const int head_size, const int block_size, + const int32_t *slot_mapping, + const size_t head_size, const size_t block_size, const ptrdiff_t k_src_stride, const ptrdiff_t v_src_stride, const ptrdiff_t k_cache_block_stride, const ptrdiff_t v_cache_block_stride ) { op::paged_caching::cuda::pagedCachingKernel( - k_cache, v_cache, k, v, slot_mapping, num_heads, head_size, block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); + k_cache, v_cache, k, v, slot_mapping, head_size, + block_size, k_src_stride, v_src_stride, k_cache_block_stride, v_cache_block_stride); } namespace op::paged_caching::nvidia { @@ -56,15 +57,19 @@ infiniStatus_t Descriptor::create( // The launchKernel function is a templated helper to encapsulate the CUDA kernel launch. // It sets up grid/block dimensions and calls the device-side kernel. 
-template -void launchKernel(const PagedCachingInfo& info, +template +infiniStatus_t launchKernel(const PagedCachingInfo& info, void *k_cache, void *v_cache, + infiniDtype_t dtype, const void *k, const void *v, const void *slot_mapping, + size_t num_tokens, size_t num_kv_heads, size_t head_size, size_t block_size, + ptrdiff_t k_src_stride, ptrdiff_t v_src_stride, + ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride, cudaStream_t stream) { // Grid dimension is 1D, with one block per token, as we decided. - dim3 grid(uint32_t(info.num_tokens), uint32_t(info.num_heads), 1); + dim3 grid(uint32_t(num_kv_heads), uint32_t(num_tokens), 1); // Block dimension is 1D, using the number of threads specified at compile time. dim3 block(NUM_THREADS); @@ -72,23 +77,57 @@ void launchKernel(const PagedCachingInfo& info, size_t shared_mem_size = 0; // Launch the device-side CUDA kernel. - pagedCaching + if (dtype == INFINI_DTYPE_F16) { + pagedCaching <<>>( - (Tdata*)k_cache, - (Tdata*)v_cache, - (const Tdata*)k, - (const Tdata*)v, - (const int*)slot_mapping, - info.num_heads, - info.head_size, - info.block_size, - info.k_src_stride, - info.v_src_stride, - info.k_cache_block_stride, - info.v_cache_block_stride + (half*)k_cache, + (half*)v_cache, + (const half*)k, + (const half*)v, + (const int32_t*)slot_mapping, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride ); -} + } else if (dtype == INFINI_DTYPE_BF16) { + pagedCaching<__nv_bfloat16, NUM_THREADS> + <<>>( + (__nv_bfloat16*)k_cache, + (__nv_bfloat16*)v_cache, + (const __nv_bfloat16*)k, + (const __nv_bfloat16*)v, + (const int32_t*)slot_mapping, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride + ); + } else if (dtype == INFINI_DTYPE_F32) { + pagedCaching + <<>>( + (float*)k_cache, + (float*)v_cache, + (const float*)k, + (const float*)v, + (const int32_t*)slot_mapping, + head_size, + block_size, + k_src_stride, + v_src_stride, + k_cache_block_stride, + v_cache_block_stride + ); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} // Execution method implementation infiniStatus_t Descriptor::calculate( @@ -104,39 +143,31 @@ infiniStatus_t Descriptor::calculate( // This allows selecting the largest, most efficient block size the hardware supports. if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) { // Dispatch based on data type for a 1024-thread block. - if (_info.dtype == INFINI_DTYPE_F16) { - launchKernel( - _info, k_cache, v_cache, k, v, slot_mapping, stream); - } else if (_info.dtype == INFINI_DTYPE_BF16) { - launchKernel<__nv_bfloat16, CUDA_BLOCK_SIZE_1024>( - _info, k_cache, v_cache, k, v, slot_mapping, stream); - } else if (_info.dtype == INFINI_DTYPE_F32) { - launchKernel( - _info, k_cache, v_cache, k, v, slot_mapping, stream); - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } + launchKernel( + _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping, + _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size, + _info.k_src_stride, _info.v_src_stride, + _info.k_cache_block_stride, _info.v_cache_block_stride, + stream); } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) { - // Dispatch based on data type for a 512-thread block. 
-        if (_info.dtype == INFINI_DTYPE_F16) {
-            launchKernel<half, CUDA_BLOCK_SIZE_512>(
-                _info, k_cache, v_cache, k, v, slot_mapping, stream);
-        } else if (_info.dtype == INFINI_DTYPE_BF16) {
-            launchKernel<__nv_bfloat16, CUDA_BLOCK_SIZE_512>(
-                _info, k_cache, v_cache, k, v, slot_mapping, stream);
-        } else if (_info.dtype == INFINI_DTYPE_F32) {
-            launchKernel<float, CUDA_BLOCK_SIZE_512>(
-                _info, k_cache, v_cache, k, v, slot_mapping, stream);
-        } else {
-            return INFINI_STATUS_BAD_TENSOR_DTYPE;
-        }
+        launchKernel<CUDA_BLOCK_SIZE_512>(
+            _info, k_cache, v_cache, _info.dtype, k, v, slot_mapping,
+            _info.num_tokens, _info.num_kv_heads, _info.head_size, _info.block_size,
+            _info.k_src_stride, _info.v_src_stride,
+            _info.k_cache_block_stride, _info.v_cache_block_stride,
+            stream);
     } else {
         // If the GPU is older and supports fewer threads, return an error.
         return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
     }

-    // Check for any asynchronous errors launched by the kernel.
-    // return CHECK_CUDA(cudaGetLastError());
     return INFINI_STATUS_SUCCESS;
 }
diff --git a/test/infiniop/paged_attention.py b/test/infiniop/paged_attention.py
index e5f2033ee..beeaf4f67 100644
--- a/test/infiniop/paged_attention.py
+++ b/test/infiniop/paged_attention.py
@@ -201,7 +201,7 @@ def test(
         handle, ctypes.byref(descriptor), out.descriptor, q.descriptor,
         k_cache.descriptor, v_cache.descriptor, block_tables.descriptor,
         seq_lens.descriptor, alibi_slopes_desc,
-        scale, max_blocks_per_seq
+        scale
     ))

     # block_tables_ptr, seq_lens_ptr, alibi_slopes_ptr, c_float(scale)

From 084c896b14fddb45badbcb172a959ffe2b586a96 Mon Sep 17 00:00:00 2001
From: suss <1152623206@qq.com>
Date: Thu, 18 Sep 2025 16:39:32 +0800
Subject: [PATCH 4/4] Optimize the paged addressing process

---
 a.out                                         | Bin 0 -> 15464 bytes
 .../ops/paged_attention/cuda/kernel.cuh       |   3 ++-
 src/infiniop/ops/paged_attention/info.h       |   2 ++
 3 files changed, 4 insertions(+), 1 deletion(-)
 create mode 100755 a.out

diff --git a/a.out b/a.out
new file mode 100755
index 0000000000000000000000000000000000000000..f03b854279893b0376114d0217ab9278fc486516
GIT binary patch
literal 15464
[binary payload not reproduced]
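The paged_caching changes above settle on a 2D launch, one thread block per (kv head, token) pair, with the dtype switch folded into launchKernel. The device-side kernel is not part of this excerpt, so the following is only a minimal sketch of the addressing that launch geometry and parameter list imply, assuming the [num_blocks, num_heads, block_size, head_size] cache layout noted in info.h; the kernel name and the slot < 0 padding convention are illustrative, not the actual pagedCachingKernel.

    #include <cstddef>
    #include <cstdint>

    // Sketch: one thread block copies one (kv head, token) head vector into the paged cache.
    template <typename Tdata, unsigned int NUM_THREADS>
    __global__ void pagedCachingSketch(
        Tdata *k_cache, Tdata *v_cache,
        const Tdata *k, const Tdata *v,
        const int32_t *slot_mapping,
        size_t head_size, size_t block_size,
        ptrdiff_t k_src_stride, ptrdiff_t v_src_stride,
        ptrdiff_t k_cache_block_stride, ptrdiff_t v_cache_block_stride) {
        const unsigned head_idx = blockIdx.x;   // grid.x == num_kv_heads
        const unsigned token_idx = blockIdx.y;  // grid.y == num_tokens
        const int32_t slot = slot_mapping[token_idx];
        if (slot < 0) return;                   // assumed convention for padded tokens

        const size_t block_id = static_cast<size_t>(slot) / block_size;     // physical cache block
        const size_t block_offset = static_cast<size_t>(slot) % block_size; // row inside the block

        // Source rows: [num_tokens, num_kv_heads, head_size], row stride given by *_src_stride.
        const Tdata *k_src = k + token_idx * k_src_stride + head_idx * head_size;
        const Tdata *v_src = v + token_idx * v_src_stride + head_idx * head_size;

        // Cache blocks: [num_blocks, num_kv_heads, block_size, head_size].
        Tdata *k_dst = k_cache + block_id * k_cache_block_stride
                     + (head_idx * block_size + block_offset) * head_size;
        Tdata *v_dst = v_cache + block_id * v_cache_block_stride
                     + (head_idx * block_size + block_offset) * head_size;

        // Element-wise copy of one head vector, strided across the thread block.
        for (size_t i = threadIdx.x; i < head_size; i += NUM_THREADS) {
            k_dst[i] = k_src[i];
            v_dst[i] = v_src[i];
        }
    }

Under this layout consecutive threads touch consecutive elements of one head vector, so the copy is coalesced and the launch only needs grid = {num_kv_heads, num_tokens} plus a compile-time NUM_THREADS.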
diff --git a/src/infiniop/ops/paged_attention/cuda/kernel.cuh b/src/infiniop/ops/paged_attention/cuda/kernel.cuh
index 3bfd79f3f..9724dd4ba 100644
--- a/src/infiniop/ops/paged_attention/cuda/kernel.cuh
+++ b/src/infiniop/ops/paged_attention/cuda/kernel.cuh
@@ -58,6 +58,7 @@ __device__ void pagedAttentionKernel(
     const int seq_idx = blockIdx.y;
     const int head_idx = blockIdx.x;
     const int num_heads = gridDim.x;
+    const ptrdiff_t o_stride = q_stride / 3; // q_stride spans a fused QKV row; out only spans the Q part
     // const int batch_size = gridDim.y;
     const int32_t seq_len = seq_lens_[seq_idx];
     if (seq_len == 0) return;
@@ -69,7 +70,7 @@ __device__ void pagedAttentionKernel(
     const int32_t* block_table = block_tables_ + seq_idx * max_num_blocks_per_seq;
     const Tdata* q_ptr = q_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
-    Tdata* out_ptr = out_ + seq_idx * q_stride + head_idx * HEAD_SIZE;
+    Tdata* out_ptr = out_ + seq_idx * o_stride + head_idx * HEAD_SIZE;

     extern __shared__ char shared_mem_char[];
     Tcompute* shared_mem = reinterpret_cast<Tcompute*>(shared_mem_char);
diff --git a/src/infiniop/ops/paged_attention/info.h b/src/infiniop/ops/paged_attention/info.h
index db63a649f..723affe56 100644
--- a/src/infiniop/ops/paged_attention/info.h
+++ b/src/infiniop/ops/paged_attention/info.h
@@ -5,6 +5,7 @@
 #include "../../tensor.h"
 #include
 #include
+#include

 namespace op::paged_attention {

@@ -71,6 +72,7 @@ class PagedAttentionInfo {

     // --- Extract strides for memory access ---
     ptrdiff_t q_stride = q_desc->stride(0);
+    // ptrdiff_t q_stride = 3584;
     ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
     ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
     // Note: We assume k_cache and v_cache have compatible strides.
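The o_stride = q_stride / 3 shortcut in kernel.cuh only recovers the output row stride when q is a strided view into a fused QKV projection and the model has as many KV heads as query heads. Below is a small host-side check of that arithmetic; the sizes are illustrative and not taken from the patch.

    #include <cassert>
    #include <cstddef>

    int main() {
        const std::ptrdiff_t head_size = 128;
        const std::ptrdiff_t num_q_heads = 28;

        // MHA: the fused QKV row is exactly 3x the Q width, so q_stride / 3
        // equals the dense output row stride that out_ptr needs.
        std::ptrdiff_t num_kv_heads = num_q_heads;
        std::ptrdiff_t q_stride = (num_q_heads + 2 * num_kv_heads) * head_size;
        std::ptrdiff_t o_stride = num_q_heads * head_size;
        assert(q_stride / 3 == o_stride);

        // GQA: fewer KV heads make the fused row narrower than 3x the Q width,
        // so dividing by 3 would under-shoot the real output stride.
        num_kv_heads = 4;
        q_stride = (num_q_heads + 2 * num_kv_heads) * head_size;
        assert(q_stride / 3 != o_stride);
        return 0;
    }

If the output tensor's own stride(0) is available in PagedAttentionInfo, passing it through to the kernel would make out_ptr independent of how q is packed; the division is only safe under the fused, equal-head-count assumption noted in the comment.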