2 changes: 1 addition & 1 deletion .gitignore
@@ -20,5 +20,5 @@ cache/
# JSON
*.json

#GGUF
# GGUF
*.gguf
6 changes: 6 additions & 0 deletions include/infinicore.h
@@ -72,4 +72,10 @@ typedef enum {
INFINI_DTYPE_BF16 = 19,
} infiniDtype_t;

typedef enum {
INFINIOP_ATTENTION_MASK_TYPE_NONE = 0,
INFINIOP_ATTENTION_MASK_TYPE_FULL = 1,
INFINIOP_ATTENTION_MASK_TYPE_CAUSAL = 2,
} infiniopAttentionMaskType_t;

#endif // __INFINICORE_API_H__
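The new enum distinguishes no masking, an explicit (full) mask tensor, and an implicit causal mask. The diff does not spell out the layout the kernel expects for an explicit mask, so the following is only an illustrative sketch: it assumes an additive `[seq_q, seq_k]` float mask where masked positions hold `-inf`, and `make_causal_mask` is a hypothetical helper, not part of this API.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical helper: builds a [seq_q, seq_k] additive mask whose entries
// above the diagonal are -inf, i.e. a causal (lower-triangular) pattern.
// The dtype/layout actually expected by the kernel is an assumption here.
std::vector<float> make_causal_mask(size_t seq_q, size_t seq_k) {
    std::vector<float> mask(seq_q * seq_k, 0.0f);
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for (size_t i = 0; i < seq_q; ++i) {
        for (size_t j = i + 1; j < seq_k; ++j) {
            mask[i * seq_k + j] = neg_inf; // mask out future positions
        }
    }
    return mask;
}
```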
2 changes: 2 additions & 0 deletions include/infiniop.h
@@ -8,6 +8,8 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/flash_attention_backward.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/random_sample.h"
37 changes: 37 additions & 0 deletions include/infiniop/ops/flash_attention.h
@@ -0,0 +1,37 @@
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t l_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t mask_desc,
infiniopAttentionMaskType_t mask_type);

__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
infiniopFlashAttentionDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopFlashAttention(
infiniopFlashAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
void *l,
const void *q,
const void *k,
const void *v,
const void *mask,
void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(infiniopFlashAttentionDescriptor_t desc);

#endif
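The header follows the usual create / query-workspace / execute / destroy flow, with `out` receiving the attention result and `l` presumably the per-row softmax statistic (log-sum-exp) that FlashAttention keeps for the backward pass. A minimal sketch of that call sequence, assuming the tensor descriptors and device buffers are prepared elsewhere and that `INFINI_STATUS_SUCCESS` is the success code:

```cpp
// Sketch only, not the reference implementation: create the descriptor,
// query the workspace requirement, launch, then destroy the descriptor.
#include <infiniop.h>

infiniStatus_t run_flash_attention(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t l_desc,
    infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t mask_desc,
    infiniopAttentionMaskType_t mask_type,
    void *out, void *l, const void *q, const void *k, const void *v,
    const void *mask, void *workspace, size_t workspace_capacity, void *stream) {
    infiniopFlashAttentionDescriptor_t desc;
    infiniStatus_t status = infiniopCreateFlashAttentionDescriptor(
        handle, &desc, out_desc, l_desc, q_desc, k_desc, v_desc, mask_desc, mask_type);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    size_t workspace_size = 0;
    status = infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size);
    if (status == INFINI_STATUS_SUCCESS && workspace_size <= workspace_capacity) {
        // out receives the attention result; l presumably holds the per-row
        // log-sum-exp saved for the backward pass.
        status = infiniopFlashAttention(desc, workspace, workspace_size,
                                        out, l, q, k, v, mask, stream);
    }
    infiniopDestroyFlashAttentionDescriptor(desc);
    return status;
}
```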
42 changes: 42 additions & 0 deletions include/infiniop/ops/flash_attention_backward.h
@@ -0,0 +1,42 @@
#ifndef __INFINIOP_FLASH_ATTENTION_BACKWARD_API_H__
#define __INFINIOP_FLASH_ATTENTION_BACKWARD_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionBackwardDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionBackwardDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionBackwardDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t grad_q_desc,
infiniopTensorDescriptor_t grad_k_desc,
infiniopTensorDescriptor_t grad_v_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t grad_out_desc,
infiniopTensorDescriptor_t mask_desc,
infiniopAttentionMaskType_t mask_type);

__C __export infiniStatus_t infiniopGetFlashAttentionBackwardWorkspaceSize(
infiniopFlashAttentionBackwardDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopFlashAttentionBackward(
infiniopFlashAttentionBackwardDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *grad_q,
void *grad_k,
void *grad_v,
const void *q,
const void *k,
const void *v,
const void *grad_out,
const void *mask,
void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionBackwardDescriptor(
infiniopFlashAttentionBackwardDescriptor_t desc);

#endif
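The backward descriptor takes the gradient outputs first, then the saved forward inputs, then `grad_out` and the mask. A compact sketch of the backward flow, under the same assumptions as the forward example above (descriptors and buffers prepared elsewhere; `run_flash_attention_backward` is just an illustrative wrapper):

```cpp
// Sketch only: mirrors the forward call sequence for the backward op.
#include <infiniop.h>

infiniStatus_t run_flash_attention_backward(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t grad_q_desc, infiniopTensorDescriptor_t grad_k_desc,
    infiniopTensorDescriptor_t grad_v_desc, infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t grad_out_desc, infiniopTensorDescriptor_t mask_desc,
    infiniopAttentionMaskType_t mask_type,
    void *grad_q, void *grad_k, void *grad_v,
    const void *q, const void *k, const void *v,
    const void *grad_out, const void *mask,
    void *workspace, size_t workspace_capacity, void *stream) {
    infiniopFlashAttentionBackwardDescriptor_t desc;
    infiniStatus_t status = infiniopCreateFlashAttentionBackwardDescriptor(
        handle, &desc,
        grad_q_desc, grad_k_desc, grad_v_desc, // gradients to be written
        q_desc, k_desc, v_desc,                // saved forward inputs
        grad_out_desc, mask_desc, mask_type);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }
    size_t workspace_size = 0;
    status = infiniopGetFlashAttentionBackwardWorkspaceSize(desc, &workspace_size);
    if (status == INFINI_STATUS_SUCCESS && workspace_size <= workspace_capacity) {
        status = infiniopFlashAttentionBackward(
            desc, workspace, workspace_size,
            grad_q, grad_k, grad_v, q, k, v, grad_out, mask, stream);
    }
    infiniopDestroyFlashAttentionBackwardDescriptor(desc);
    return status;
}
```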
2 changes: 2 additions & 0 deletions scripts/python_test.py
@@ -25,6 +25,8 @@ def run_tests(args):
"sub.py",
"swiglu.py",
"softplus.py",
"flash_attention.py",
"flash_attention_backward.py",
]:
result = subprocess.run(
f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True
30 changes: 17 additions & 13 deletions src/infiniop-test/include/ops.hpp
@@ -16,6 +16,8 @@ DECLARE_INFINIOP_TEST(add)
DECLARE_INFINIOP_TEST(causal_softmax)
DECLARE_INFINIOP_TEST(rearrange)
DECLARE_INFINIOP_TEST(sub)
DECLARE_INFINIOP_TEST(flash_attention)
DECLARE_INFINIOP_TEST(flash_attention_backward)

#define REGISTER_INFINIOP_TEST(name) \
{ \
@@ -30,19 +32,21 @@ DECLARE_INFINIOP_TEST(sub)
/*
* Register all the tests here
*/
#define TEST_BUILDER_MAPPINGS \
{ \
REGISTER_INFINIOP_TEST(gemm) \
REGISTER_INFINIOP_TEST(random_sample) \
REGISTER_INFINIOP_TEST(add) \
REGISTER_INFINIOP_TEST(mul) \
REGISTER_INFINIOP_TEST(clip) \
REGISTER_INFINIOP_TEST(swiglu) \
REGISTER_INFINIOP_TEST(rope) \
REGISTER_INFINIOP_TEST(rms_norm) \
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
#define TEST_BUILDER_MAPPINGS \
{ \
REGISTER_INFINIOP_TEST(gemm) \
REGISTER_INFINIOP_TEST(random_sample) \
REGISTER_INFINIOP_TEST(add) \
REGISTER_INFINIOP_TEST(mul) \
REGISTER_INFINIOP_TEST(clip) \
REGISTER_INFINIOP_TEST(swiglu) \
REGISTER_INFINIOP_TEST(rope) \
REGISTER_INFINIOP_TEST(rms_norm) \
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
REGISTER_INFINIOP_TEST(flash_attention) \
REGISTER_INFINIOP_TEST(flash_attention_backward) \
}

namespace infiniop_test {
158 changes: 158 additions & 0 deletions src/infiniop-test/src/ops/flash_attention.cpp
@@ -0,0 +1,158 @@
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::flash_attention {
struct Test::Attributes {
int mask_type;
std::shared_ptr<Tensor> q;
std::shared_ptr<Tensor> k;
std::shared_ptr<Tensor> v;
std::shared_ptr<Tensor> mask;
std::shared_ptr<Tensor> out;
std::shared_ptr<Tensor> l;
std::shared_ptr<Tensor> ans;
};

std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {

auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();

if (attributes.find("mask_type") == attributes.end()
|| tensors.find("q") == tensors.end()
|| tensors.find("k") == tensors.end()
|| tensors.find("v") == tensors.end()
|| tensors.find("out") == tensors.end()
|| tensors.find("l") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test: Missing attributes or tensors");
}

if (tensors.find("mask") == tensors.end()) {
test->_attributes->mask = nullptr;
} else {
test->_attributes->mask = tensors["mask"];
}

test->_attributes->mask_type = *reinterpret_cast<int *>(attributes["mask_type"].data());

test->_attributes->q = tensors["q"];
test->_attributes->k = tensors["k"];
test->_attributes->v = tensors["v"];
test->_attributes->out = tensors["out"];
test->_attributes->l = tensors["l"];
test->_attributes->ans = tensors["ans"];

return test;
}

std::shared_ptr<infiniop_test::Result> Test::run(
infiniopHandle_t handle, infiniDevice_t device, int device_id,
size_t warm_ups, size_t iterations) {

infiniopFlashAttentionDescriptor_t op_desc;
infiniopAttentionMaskType_t mask_type = static_cast<infiniopAttentionMaskType_t>(_attributes->mask_type);
CHECK_OR(infiniopCreateFlashAttentionDescriptor(
handle, &op_desc,
_attributes->out->desc(),
_attributes->l->desc(),
_attributes->q->desc(),
_attributes->k->desc(),
_attributes->v->desc(),
_attributes->mask ? _attributes->mask->desc() : nullptr, // mask may be absent
mask_type),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create FlashAttention descriptor"));

auto out = _attributes->out->to(device, device_id);
auto l = _attributes->l->to(device, device_id);
auto q = _attributes->q->to(device, device_id);
auto k = _attributes->k->to(device, device_id);
auto v = _attributes->v->to(device, device_id);
auto mask = _attributes->mask ? _attributes->mask->to(device, device_id) : nullptr;

size_t workspace_size;
CHECK_OR(infiniopGetFlashAttentionWorkspaceSize(op_desc, &workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size"));
void *workspace = nullptr;
if (workspace_size > 0) {
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace"));
}

CHECK_OR(infiniopFlashAttention(op_desc,
workspace, workspace_size,
out->data(),
l->data(),
q->data(),
k->data(),
v->data(),
mask ? mask->data() : nullptr,
nullptr),
return TEST_FAILED(OP_EXECUTION_FAILED, "FlashAttention execution failed"));

try {
allClose(out, _attributes->ans, _rtol, _atol);
} catch (const std::exception &e) {
return TEST_FAILED(RESULT_INCORRECT, e.what());
}

double elapsed_time = 0;

elapsed_time = benchmark(
[=]() {
infiniopFlashAttention(op_desc,
workspace, workspace_size,
out->data(),
l->data(),
q->data(),
k->data(),
v->data(),
mask ? mask->data() : nullptr,
nullptr);
},
warm_ups, iterations);

if (workspace != nullptr) {
infinirtFree(workspace);
}

return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
return {"mask_type"};
}

std::vector<std::string> Test::tensor_names() {
return {"q", "k", "v", "mask", "out", "l", "ans"};
}

std::vector<std::string> Test::output_names() {
return {"out", "l"};
}

std::string Test::toString() const {
std::ostringstream oss;
oss << op_name() << std::endl;
oss << "- masktype=" << static_cast<infiniopAttentionMaskType_t>(_attributes->mask_type) << std::endl;
oss << "- q: " << _attributes->q->info() << std::endl;
oss << "- k: " << _attributes->k->info() << std::endl;
oss << "- v: " << _attributes->v->info() << std::endl;
oss << "- mask: " << (_attributes->mask ? _attributes->mask->info() : "none") << std::endl;
oss << "- out: " << _attributes->out->info() << std::endl;
oss << std::scientific << std::setprecision(2);
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
return oss.str();
}

Test::~Test() {
delete _attributes;
}

} // namespace infiniop_test::flash_attention