2 changes: 1 addition & 1 deletion .gitignore
@@ -20,5 +20,5 @@ cache/
# JSON
*.json

#GGUF
# GGUF
*.gguf
2 changes: 2 additions & 0 deletions include/infiniop.h
@@ -8,6 +8,8 @@
#include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize.h"
#include "infiniop/ops/flash_attention.h"
#include "infiniop/ops/flash_attention_backward.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/random_sample.h"
43 changes: 43 additions & 0 deletions include/infiniop/ops/flash_attention.h
@@ -0,0 +1,43 @@
#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef enum {
INFINIOP_ATTENTION_MASK_TYPE_NONE = 0,
INFINIOP_ATTENTION_MASK_TYPE_FULL = 1,
INFINIOP_ATTENTION_MASK_TYPE_CAUSAL = 2,
} infiniopAttentionMaskType_t;

typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t l_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t mask_desc,
infiniopAttentionMaskType_t mask_type);

__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
infiniopFlashAttentionDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopFlashAttention(
infiniopFlashAttentionDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
void *l,
const void *q,
const void *k,
const void *v,
const void *mask,
void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(infiniopFlashAttentionDescriptor_t desc);

#endif
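
For orientation, a minimal C++ usage sketch of the forward API above, mirroring the call sequence in the new infiniop-test source further down (create descriptor, query and allocate workspace, launch, destroy). This is an editorial sketch, not code from the PR: the handle, tensor descriptors, and device buffers are assumed to be prepared by the caller, the status constant INFINI_STATUS_SUCCESS and the infinirtMalloc/infinirtFree runtime calls are taken from the rest of the library, and the reading of l as the saved per-row softmax statistics is an assumption, not stated in this diff.

// Hedged sketch of the forward FlashAttention call sequence (assumptions noted above).
#include <infiniop.h>
#include <infinirt.h>

infiniStatus_t run_flash_attention(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t l_desc,
    infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t mask_desc,
    void *out, void *l, const void *q, const void *k, const void *v,
    const void *mask, void *stream) {
    infiniopFlashAttentionDescriptor_t desc;
    // 1. Create the descriptor (causal masking chosen as an example).
    infiniStatus_t status = infiniopCreateFlashAttentionDescriptor(
        handle, &desc, out_desc, l_desc, q_desc, k_desc, v_desc,
        mask_desc, INFINIOP_ATTENTION_MASK_TYPE_CAUSAL);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    // 2. Query and allocate the device workspace the kernel needs.
    size_t workspace_size = 0;
    status = infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size);
    if (status != INFINI_STATUS_SUCCESS) {
        infiniopDestroyFlashAttentionDescriptor(desc);
        return status;
    }
    void *workspace = nullptr;
    if (workspace_size > 0) {
        infinirtMalloc(&workspace, workspace_size);
    }

    // 3. Launch: writes the attention output and the auxiliary tensor `l`.
    status = infiniopFlashAttention(desc, workspace, workspace_size,
                                    out, l, q, k, v, mask, stream);

    // 4. Release resources regardless of launch status.
    if (workspace != nullptr) {
        infinirtFree(workspace);
    }
    infiniopDestroyFlashAttentionDescriptor(desc);
    return status;
}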
43 changes: 43 additions & 0 deletions include/infiniop/ops/flash_attention_backward.h
@@ -0,0 +1,43 @@
#ifndef __INFINIOP_FLASH_ATTENTION_BACKWARD_H__
#define __INFINIOP_FLASH_ATTENTION_BACKWARD_H__

#include "../operator_descriptor.h"
#include "flash_attention.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionBackwardDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionBackwardDescriptor(
infiniopHandle_t handle,
infiniopFlashAttentionBackwardDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t grad_q_desc,
infiniopTensorDescriptor_t grad_k_desc,
infiniopTensorDescriptor_t grad_v_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t grad_out_desc,
infiniopTensorDescriptor_t mask_desc,
infiniopAttentionMaskType_t mask_type);

__C __export infiniStatus_t infiniopGetFlashAttentionBackwardWorkspaceSize(
infiniopFlashAttentionBackwardDescriptor_t desc,
size_t *size);

__C __export infiniStatus_t infiniopFlashAttentionBackward(
infiniopFlashAttentionBackwardDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *grad_q,
void *grad_k,
void *grad_v,
const void *q,
const void *k,
const void *v,
const void *grad_out,
const void *mask,
void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionBackwardDescriptor(
infiniopFlashAttentionBackwardDescriptor_t desc);

#endif
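
The backward API follows the same descriptor/workspace lifecycle. A condensed sketch under the same assumptions as the forward example (caller-prepared handle, descriptors, and device buffers; INFINI_STATUS_SUCCESS and infinirt allocation calls assumed from the rest of the library):

// Hedged sketch of the backward call: gradients for q, k, v are produced from grad_out.
infiniStatus_t run_flash_attention_backward(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t grad_q_desc, infiniopTensorDescriptor_t grad_k_desc,
    infiniopTensorDescriptor_t grad_v_desc, infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc, infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t grad_out_desc, infiniopTensorDescriptor_t mask_desc,
    void *grad_q, void *grad_k, void *grad_v,
    const void *q, const void *k, const void *v,
    const void *grad_out, const void *mask, void *stream) {
    infiniopFlashAttentionBackwardDescriptor_t desc;
    infiniStatus_t status = infiniopCreateFlashAttentionBackwardDescriptor(
        handle, &desc, grad_q_desc, grad_k_desc, grad_v_desc,
        q_desc, k_desc, v_desc, grad_out_desc, mask_desc,
        INFINIOP_ATTENTION_MASK_TYPE_CAUSAL);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    // Workspace query and allocation, as in the forward sketch.
    size_t workspace_size = 0;
    status = infiniopGetFlashAttentionBackwardWorkspaceSize(desc, &workspace_size);
    if (status != INFINI_STATUS_SUCCESS) {
        infiniopDestroyFlashAttentionBackwardDescriptor(desc);
        return status;
    }
    void *workspace = nullptr;
    if (workspace_size > 0) {
        infinirtMalloc(&workspace, workspace_size);
    }

    status = infiniopFlashAttentionBackward(desc, workspace, workspace_size,
                                            grad_q, grad_k, grad_v,
                                            q, k, v, grad_out, mask, stream);

    if (workspace != nullptr) {
        infinirtFree(workspace);
    }
    infiniopDestroyFlashAttentionBackwardDescriptor(desc);
    return status;
}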
2 changes: 2 additions & 0 deletions scripts/python_test.py
@@ -25,6 +25,8 @@ def run_tests(args):
"sub.py",
"swiglu.py",
"softplus.py",
"flash_attention.py",
"flash_attention_backward.py",
]:
result = subprocess.run(
f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True
30 changes: 17 additions & 13 deletions src/infiniop-test/include/ops.hpp
@@ -16,6 +16,8 @@ DECLARE_INFINIOP_TEST(add)
DECLARE_INFINIOP_TEST(causal_softmax)
DECLARE_INFINIOP_TEST(rearrange)
DECLARE_INFINIOP_TEST(sub)
DECLARE_INFINIOP_TEST(flash_attention)
DECLARE_INFINIOP_TEST(flash_attention_backward)

#define REGISTER_INFINIOP_TEST(name) \
{ \
@@ -30,19 +32,21 @@ DECLARE_INFINIOP_TEST(sub)
/*
* Register all the tests here
*/
#define TEST_BUILDER_MAPPINGS \
{ \
REGISTER_INFINIOP_TEST(gemm) \
REGISTER_INFINIOP_TEST(random_sample) \
REGISTER_INFINIOP_TEST(add) \
REGISTER_INFINIOP_TEST(mul) \
REGISTER_INFINIOP_TEST(clip) \
REGISTER_INFINIOP_TEST(swiglu) \
REGISTER_INFINIOP_TEST(rope) \
REGISTER_INFINIOP_TEST(rms_norm) \
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
#define TEST_BUILDER_MAPPINGS \
{ \
REGISTER_INFINIOP_TEST(gemm) \
REGISTER_INFINIOP_TEST(random_sample) \
REGISTER_INFINIOP_TEST(add) \
REGISTER_INFINIOP_TEST(mul) \
REGISTER_INFINIOP_TEST(clip) \
REGISTER_INFINIOP_TEST(swiglu) \
REGISTER_INFINIOP_TEST(rope) \
REGISTER_INFINIOP_TEST(rms_norm) \
REGISTER_INFINIOP_TEST(causal_softmax) \
REGISTER_INFINIOP_TEST(rearrange) \
REGISTER_INFINIOP_TEST(sub) \
REGISTER_INFINIOP_TEST(flash_attention) \
REGISTER_INFINIOP_TEST(flash_attention_backward) \
}

namespace infiniop_test {
158 changes: 158 additions & 0 deletions src/infiniop-test/src/ops/flash_attention.cpp
@@ -0,0 +1,158 @@
#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::flash_attention {
struct Test::Attributes {
int mask_type;
std::shared_ptr<Tensor> q;
std::shared_ptr<Tensor> k;
std::shared_ptr<Tensor> v;
std::shared_ptr<Tensor> mask;
std::shared_ptr<Tensor> out;
std::shared_ptr<Tensor> l;
std::shared_ptr<Tensor> ans;
};

std::shared_ptr<Test> Test::build(
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
double rtol, double atol) {

auto test = std::shared_ptr<Test>(new Test(rtol, atol));
test->_attributes = new Attributes();

if (attributes.find("mask_type") == attributes.end()
|| tensors.find("q") == tensors.end()
|| tensors.find("k") == tensors.end()
|| tensors.find("v") == tensors.end()
|| tensors.find("out") == tensors.end()
|| tensors.find("l") == tensors.end()
|| tensors.find("ans") == tensors.end()) {
throw std::runtime_error("Invalid Test: Missing attributes or tensors");
}

if (tensors.find("mask") == tensors.end()) {
test->_attributes->mask = nullptr;
} else {
test->_attributes->mask = tensors["mask"];
}

test->_attributes->mask_type = *reinterpret_cast<int *>(attributes["mask_type"].data());

test->_attributes->q = tensors["q"];
test->_attributes->k = tensors["k"];
test->_attributes->v = tensors["v"];
test->_attributes->out = tensors["out"];
test->_attributes->l = tensors["l"];
test->_attributes->ans = tensors["ans"];

return test;
}

std::shared_ptr<infiniop_test::Result> Test::run(
infiniopHandle_t handle, infiniDevice_t device, int device_id,
size_t warm_ups, size_t iterations) {

infiniopFlashAttentionDescriptor_t op_desc;
infiniopAttentionMaskType_t mask_type = static_cast<infiniopAttentionMaskType_t>(_attributes->mask_type);
CHECK_OR(infiniopCreateFlashAttentionDescriptor(
handle, &op_desc,
_attributes->out->desc(),
_attributes->l->desc(),
_attributes->q->desc(),
_attributes->k->desc(),
_attributes->v->desc(),
_attributes->mask ? _attributes->mask->desc() : nullptr,
mask_type),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to create FlashAttention descriptor"));

auto out = _attributes->out->to(device, device_id);
auto l = _attributes->l->to(device, device_id);
auto q = _attributes->q->to(device, device_id);
auto k = _attributes->k->to(device, device_id);
auto v = _attributes->v->to(device, device_id);
auto mask = _attributes->mask ? _attributes->mask->to(device, device_id) : nullptr;

size_t workspace_size;
CHECK_OR(infiniopGetFlashAttentionWorkspaceSize(op_desc, &workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size"));
void *workspace = nullptr;
if (workspace_size > 0) {
CHECK_OR(infinirtMalloc(&workspace, workspace_size),
return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace"));
}

CHECK_OR(infiniopFlashAttention(op_desc,
workspace, workspace_size,
out->data(),
l->data(),
q->data(),
k->data(),
v->data(),
mask ? mask->data() : nullptr,
nullptr),
return TEST_FAILED(OP_EXECUTION_FAILED, "FlashAttention execution failed"));

try {
allClose(out, _attributes->ans, _rtol, _atol);
} catch (const std::exception &e) {
return TEST_FAILED(RESULT_INCORRECT, e.what());
}

double elapsed_time = 0;

elapsed_time = benchmark(
[=]() {
infiniopFlashAttention(op_desc,
workspace, workspace_size,
out->data(),
l->data(),
q->data(),
k->data(),
v->data(),
mask ? mask->data() : nullptr,
nullptr);
},
warm_ups, iterations);

if (workspace != nullptr) {
infinirtFree(workspace);
}

return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
return {"mask_type"};
}

std::vector<std::string> Test::tensor_names() {
return {"q", "k", "v", "mask", "out", "l", "ans"};
}

std::vector<std::string> Test::output_names() {
return {"out", "l"};
}

std::string Test::toString() const {
std::ostringstream oss;
oss << op_name() << std::endl;
oss << "- masktype=" << static_cast<infiniopAttentionMaskType_t>(_attributes->mask_type) << std::endl;
oss << "- q: " << _attributes->q->info() << std::endl;
oss << "- k: " << _attributes->k->info() << std::endl;
oss << "- v: " << _attributes->v->info() << std::endl;
oss << "- mask: " << (_attributes->mask ? _attributes->mask->info() : "none") << std::endl;
oss << "- out: " << _attributes->out->info() << std::endl;
oss << std::scientific << std::setprecision(2);
oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
return oss.str();
}

Test::~Test() {
delete _attributes;
}

} // namespace infiniop_test::flash_attention