Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added a.out
Binary file not shown.
2 changes: 2 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "infiniop/ops/dequantize.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/paged_attention.h"
#include "infiniop/ops/paged_caching.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
Expand Down
88 changes: 88 additions & 0 deletions include/infiniop/ops/paged_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#ifndef __INFINIOP_PAGED_ATTENTION_API_H__
#define __INFINIOP_PAGED_ATTENTION_API_H__

#include "../operator_descriptor.h"

// Opaque handle for the Paged Attention descriptor; the concrete layout is
// backend-private and only ever accessed through the functions below.
typedef struct InfiniopDescriptor *infiniopPagedAttentionDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Attention v1 operation.
 *
 * This function initializes a descriptor that holds all the metadata needed
 * for the paged attention computation.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param out_desc Descriptor for the output tensor.
 * @param q_desc Descriptor for the query tensor.
 * @param k_cache_desc Descriptor for the key cache tensor.
 * @param v_cache_desc Descriptor for the value cache tensor.
 * @param block_tables_desc Descriptor for the block tables tensor.
 * @param seq_lens_desc Descriptor for the sequence lengths tensor.
 * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
 * @param scale The attention scaling factor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopPagedAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t block_tables_desc,
    infiniopTensorDescriptor_t seq_lens_desc,
    infiniopTensorDescriptor_t alibi_slopes_desc,
    float scale);

/**
 * @brief Retrieves the workspace size required for the Paged Attention operation.
 *
 * @param desc The Paged Attention descriptor.
 * @param size A pointer to store the required workspace size in bytes.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
    infiniopPagedAttentionDescriptor_t desc, size_t *size);

/**
 * @brief Executes the Paged Attention v1 operation.
 *
 * All tensor pointers must match the descriptors the operator was created
 * with; @p alibi_slopes may be NULL only if the descriptor was created with
 * a NULL ALiBi slopes descriptor.
 *
 * @param desc The Paged Attention descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param out Pointer to the output tensor data.
 * @param q Pointer to the query tensor data.
 * @param k_cache Pointer to the key cache data.
 * @param v_cache Pointer to the value cache data.
 * @param block_tables Pointer to the block tables data.
 * @param seq_lens Pointer to the sequence lengths data.
 * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
 * @param stream The backend stream (e.g. a CUDA stream) for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedAttention(
    infiniopPagedAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *q,
    const void *k_cache,
    const void *v_cache,
    const void *block_tables,
    const void *seq_lens,
    const void *alibi_slopes,
    void *stream);

/**
 * @brief Destroys a Paged Attention descriptor.
 *
 * @param desc The descriptor to be destroyed.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
    infiniopPagedAttentionDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_API_H__
77 changes: 77 additions & 0 deletions include/infiniop/ops/paged_caching.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#ifndef __INFINIOP_PAGED_CACHING_API_H__
#define __INFINIOP_PAGED_CACHING_API_H__

#include "../operator_descriptor.h"

// Opaque handle for the Paged Caching descriptor; the concrete layout is
// backend-private and only ever accessed through the functions below.
typedef struct InfiniopDescriptor *infiniopPagedCachingDescriptor_t;

/**
 * @brief Creates a descriptor for the Paged Caching operation.
 *
 * This function initializes a descriptor that holds all the metadata needed
 * to copy key/value vectors into their respective cache pools.
 *
 * @param handle The handle to the InfiniOP library context.
 * @param desc_ptr A pointer to store the created descriptor.
 * @param k_desc Descriptor for the source key tensor.
 * @param v_desc Descriptor for the source value tensor.
 * @param k_cache_desc Descriptor for the key cache pool tensor.
 * @param v_cache_desc Descriptor for the value cache pool tensor.
 * @param slot_mapping_desc Descriptor for the slot mapping tensor.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopCreatePagedCachingDescriptor(
    infiniopHandle_t handle,
    infiniopPagedCachingDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t k_cache_desc,
    infiniopTensorDescriptor_t v_cache_desc,
    infiniopTensorDescriptor_t slot_mapping_desc);

/**
 * @brief Retrieves the workspace size required for the Paged Caching operation.
 *
 * @param desc The Paged Caching descriptor.
 * @param size A pointer to store the required workspace size in bytes (typically 0).
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopGetPagedCachingWorkspaceSize(
    infiniopPagedCachingDescriptor_t desc, size_t *size);

/**
 * @brief Executes the Paged Caching operation.
 *
 * Scatters the source key/value vectors into the cache pools at the
 * positions given by the slot mapping; the caches are read-write, the
 * sources and mapping are read-only.
 *
 * @param desc The Paged Caching descriptor.
 * @param workspace Pointer to the workspace memory.
 * @param workspace_size The size of the workspace.
 * @param k Pointer to the source key tensor data.
 * @param v Pointer to the source value tensor data.
 * @param k_cache Pointer to the key cache pool data.
 * @param v_cache Pointer to the value cache pool data.
 * @param slot_mapping Pointer to the slot mapping data.
 * @param stream The backend stream (e.g. a CUDA stream) for the operation. Can be NULL.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopPagedCaching(
    infiniopPagedCachingDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    const void *k,
    const void *v,
    void *k_cache,
    void *v_cache,
    const void *slot_mapping,
    void *stream);

/**
 * @brief Destroys a Paged Caching descriptor.
 *
 * @param desc The descriptor to be destroyed.
 * @return infiniStatus_t Status code of the operation.
 */
__C __export infiniStatus_t infiniopDestroyPagedCachingDescriptor(
    infiniopPagedCachingDescriptor_t desc);

#endif // __INFINIOP_PAGED_CACHING_API_H__
2 changes: 2 additions & 0 deletions scripts/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def run_tests(args):
"clip.py",
"gemm.py",
"mul.py",
"paged_attention.py",
"paged_caching.py",
"random_sample.py",
"rearrange.py",
"rms_norm.py",
Expand Down
2 changes: 2 additions & 0 deletions src/infiniop-test/include/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add)
DECLARE_INFINIOP_TEST(causal_softmax)
DECLARE_INFINIOP_TEST(rearrange)
DECLARE_INFINIOP_TEST(sub)
DECLARE_INFINIOP_TEST(paged_attention)

#define REGISTER_INFINIOP_TEST(name) \
{ \
Expand Down Expand Up @@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
REGISTER_INFINIOP_TEST(paged_attention) \
}

namespace infiniop_test {
Expand Down
163 changes: 163 additions & 0 deletions src/infiniop-test/src/ops/paged_attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::paged_attention {

// Attribute storage for one paged_attention test case loaded from GGUF.
struct Test::Attributes {
    // Attention scaling factor; delivered as a tensor in the test file and
    // read out as a host-side scalar in run().
    std::shared_ptr<Tensor> scale;

    // Input/output tensors of the operation.
    std::shared_ptr<Tensor> q;
    std::shared_ptr<Tensor> k_cache;
    std::shared_ptr<Tensor> v_cache;
    std::shared_ptr<Tensor> block_tables;
    std::shared_ptr<Tensor> seq_lens;
    std::shared_ptr<Tensor> alibi_slopes; // Optional; nullptr when absent from the test file
    std::shared_ptr<Tensor> ans;          // Reference answer to compare against
    std::shared_ptr<Tensor> out;          // Output buffer written by the operator

    // Note: the operator descriptor and workspace are deliberately NOT stored
    // here; run() manages them as locals so their lifetime is per-invocation.
};

// Factory method to build a Test object from GGUF data.
std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();
if (!check_names(tensors, Test::tensor_names())) {
throw std::runtime_error("Invalid Test: Missing tensors.");
}

test->_attributes->scale = tensors["scale"];
test->_attributes->q = tensors["q"];
test->_attributes->k_cache = tensors["k_cache"];
test->_attributes->v_cache = tensors["v_cache"];
test->_attributes->block_tables = tensors["block_tables"];
test->_attributes->seq_lens = tensors["seq_lens"];
if (tensors.count("alibi_slopes")) {
test->_attributes->alibi_slopes = tensors["alibi_slopes"];
} else {
test->_attributes->alibi_slopes = nullptr;
}
test->_attributes->ans = tensors["ans"];
test->_attributes->out = tensors["out"];

return test;
}

// Executes the test case: creates the operator descriptor, runs the operator
// once for a correctness check against the reference answer, then benchmarks
// it. All device resources are released on every exit path.
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {

    // Descriptor and workspace are locals so each invocation owns its own.
    infiniopPagedAttentionDescriptor_t op_desc = nullptr;
    void *workspace = nullptr;
    // Shared cleanup for every return path; the previous version leaked
    // op_desc/workspace when the size query, allocation, execution, or the
    // result comparison failed.
    auto cleanup = [&]() {
        if (op_desc) {
            infiniopDestroyPagedAttentionDescriptor(op_desc);
            op_desc = nullptr;
        }
        if (workspace) {
            infinirtFree(workspace);
            workspace = nullptr;
        }
    };

    // Move tensors to the target device.
    auto q = _attributes->q->to(device, device_id);
    auto k_cache = _attributes->k_cache->to(device, device_id);
    auto v_cache = _attributes->v_cache->to(device, device_id);
    auto block_tables = _attributes->block_tables->to(device, device_id);
    auto seq_lens = _attributes->seq_lens->to(device, device_id);
    auto out = _attributes->out->to(device, device_id);
    std::shared_ptr<Tensor> alibi_slopes = nullptr;
    if (_attributes->alibi_slopes) {
        alibi_slopes = _attributes->alibi_slopes->to(device, device_id);
    }

    // Read the scalar scale from the host-side tensor. The previous code
    // dereferenced the data pointer of the copy moved to the device, which
    // is invalid on backends whose device memory is not host-dereferenceable.
    // (Assumes GGUF-loaded tensors start host-resident — consistent with the
    // other dereference-free host tensors in this file.)
    float scale_val = *reinterpret_cast<const float *>(_attributes->scale->data());

    // Create operator descriptor.
    CHECK_OR(infiniopCreatePagedAttentionDescriptor(
                 handle, &op_desc, out->desc(), q->desc(), k_cache->desc(), v_cache->desc(),
                 block_tables->desc(), seq_lens->desc(),
                 alibi_slopes ? alibi_slopes->desc() : nullptr, scale_val),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));

    // Get workspace size and allocate memory.
    size_t workspace_size = 0;
    CHECK_OR(infiniopGetPagedAttentionWorkspaceSize(op_desc, &workspace_size),
             cleanup();
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));
    if (workspace_size > 0) {
        CHECK_OR(infinirtMalloc(&workspace, workspace_size),
                 cleanup();
                 return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
    }

    // Execute once for the correctness check.
    CHECK_OR(infiniopPagedAttention(op_desc, workspace, workspace_size,
                                    out->data(), q->data(), k_cache->data(), v_cache->data(),
                                    block_tables->data(), seq_lens->data(),
                                    alibi_slopes ? alibi_slopes->data() : nullptr, nullptr),
             cleanup();
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    // Verify the result against the reference answer.
    try {
        allClose(out, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        cleanup();
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    // Benchmark the operation. Capture by reference so the shared_ptrs and
    // descriptor are not copied on every invocation.
    double elapsed_time = benchmark(
        [&]() {
            infiniopPagedAttention(op_desc, workspace, workspace_size,
                                   out->data(), q->data(), k_cache->data(), v_cache->data(),
                                   block_tables->data(), seq_lens->data(),
                                   alibi_slopes ? alibi_slopes->data() : nullptr, nullptr);
        },
        warm_ups, iterations);

    cleanup();
    return TEST_PASSED(elapsed_time);
}

// Expected attribute/tensor names used by the GGUF loader for validation.
std::vector<std::string> Test::attribute_names() {
    return std::vector<std::string>{};
}
std::vector<std::string> Test::tensor_names() {
    // "alibi_slopes" is optional and therefore intentionally not required here.
    std::vector<std::string> required{"scale", "q", "k_cache", "v_cache", "block_tables", "seq_lens", "ans", "out"};
    return required;
}
std::vector<std::string> Test::output_names() {
    return std::vector<std::string>{"out"};
}

// MODIFIED: Added a toString() method for better debugging and logging, mimicking the reference file.
std::string Test::toString() const {
std::ostringstream oss;
oss << op_name() << std::endl;
oss << "- q: " << _attributes->q->info() << std::endl;
oss << "- k_cache: " << _attributes->k_cache->info() << std::endl;
oss << "- v_cache: " << _attributes->v_cache->info() << std::endl;
oss << "- block_tables: " << _attributes->block_tables->info() << std::endl;
oss << "- seq_lens: " << _attributes->seq_lens->info() << std::endl;
if (_attributes->alibi_slopes) {
oss << "- alibi_slopes: " << _attributes->alibi_slopes->info() << std::endl;
}
oss << "- out: " << _attributes->out->info() << std::endl;
oss << "- ans: " << _attributes->ans->info() << std::endl;
oss << std::scientific << std::setprecision(2);
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
return oss.str();
}

// Destructor: frees the attribute struct owned by this test case.
Test::~Test() {
    delete _attributes; // delete on nullptr is a no-op, so no guard is needed
}

} // namespace infiniop_test::paged_attention
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include <cub/block/block_reduce.cuh>

#include "../../../reduce/cuda/reduce.cuh"

#include "../cuda/kernel.cuh"
Expand Down
Loading
Loading