Commit 4dde162

T1-3-1: FlashAttention and FlashAttentionBackward
1 parent 9b758b9 commit 4dde162

38 files changed: +4937 −15 lines

.gitignore (1 addition, 1 deletion)

@@ -20,5 +20,5 @@ cache/
 # JSON
 *.json

-#GGUF
+# GGUF
 *.gguf

include/infinicore.h (6 additions, 0 deletions)

@@ -72,4 +72,10 @@ typedef enum {
     INFINI_DTYPE_BF16 = 19,
 } infiniDtype_t;

+typedef enum {
+    INFINIOP_ATTENTION_MASK_TYPE_NONE = 0,
+    INFINIOP_ATTENTION_MASK_TYPE_FULL = 1,
+    INFINIOP_ATTENTION_MASK_TYPE_CAUSAL = 2,
+} infiniopAttentionMaskType_t;
+
 #endif // __INFINICORE_API_H__
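
The new mask-type enum controls how attention scores are masked: NONE applies no masking, FULL expects an explicit mask tensor supplied by the caller, and CAUSAL restricts each query position to keys at or before it. A minimal sketch of how a causal mask could be materialized as an additive bias for reference checks follows; the helper name and the 0 / -inf convention are illustrative assumptions, not part of this commit's API, and equal query and key lengths are assumed.

#include <cstddef>
#include <limits>
#include <vector>

// Illustrative helper (not from this commit): builds a [seq_q, seq_k] additive
// mask where allowed positions contribute 0 and disallowed positions -inf,
// i.e. query row i may attend to key columns j <= i (equal-length convention).
std::vector<float> make_causal_mask(size_t seq_q, size_t seq_k) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    std::vector<float> mask(seq_q * seq_k, 0.0f);
    for (size_t i = 0; i < seq_q; ++i) {
        for (size_t j = i + 1; j < seq_k; ++j) {
            mask[i * seq_k + j] = neg_inf;
        }
    }
    return mask;
}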

include/infiniop.h (2 additions, 0 deletions)

@@ -7,6 +7,8 @@
 #include "infiniop/ops/causal_softmax.h"
 #include "infiniop/ops/clip.h"
 #include "infiniop/ops/conv.h"
+#include "infiniop/ops/flash_attention.h"
+#include "infiniop/ops/flash_attention_backward.h"
 #include "infiniop/ops/gemm.h"
 #include "infiniop/ops/mul.h"
 #include "infiniop/ops/random_sample.h"
include/infiniop/ops/flash_attention.h (new file, 37 additions)

#ifndef __INFINIOP_FLASH_ATTENTION_API_H__
#define __INFINIOP_FLASH_ATTENTION_API_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t l_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t mask_desc,
    infiniopAttentionMaskType_t mask_type);

__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
    infiniopFlashAttentionDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopFlashAttention(
    infiniopFlashAttentionDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    void *l,
    const void *q,
    const void *k,
    const void *v,
    const void *mask,
    void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(infiniopFlashAttentionDescriptor_t desc);

#endif
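
The forward operator follows the descriptor/workspace pattern used by the other InfiniOP operators: create a descriptor from tensor descriptors and a mask type, query the workspace size, launch, destroy. A hedged sketch of the call sequence is shown below; CHECK is a placeholder error-handling macro, all descriptors and d_* device buffers are assumed to exist already, and passing a null mask descriptor together with the causal mask type is an assumption rather than a documented guarantee.

// Illustrative call sequence only; handle, *_desc and d_* buffers are assumed.
infiniopFlashAttentionDescriptor_t desc;
CHECK(infiniopCreateFlashAttentionDescriptor(
    handle, &desc,
    out_desc, l_desc, q_desc, k_desc, v_desc,
    /*mask_desc=*/nullptr, INFINIOP_ATTENTION_MASK_TYPE_CAUSAL));

size_t workspace_size = 0;
CHECK(infiniopGetFlashAttentionWorkspaceSize(desc, &workspace_size));

void *workspace = nullptr;
if (workspace_size > 0) {
    CHECK(infinirtMalloc(&workspace, workspace_size));
}

// out receives the attention result; l receives the per-row softmax statistics
// (log-sum-exp in the published FlashAttention formulation; an assumption here).
CHECK(infiniopFlashAttention(desc, workspace, workspace_size,
                             d_out, d_l, d_q, d_k, d_v,
                             /*mask=*/nullptr, /*stream=*/nullptr));

if (workspace != nullptr) {
    infinirtFree(workspace);
}
CHECK(infiniopDestroyFlashAttentionDescriptor(desc));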

include/infiniop/ops/flash_attention_backward.h (new file, 42 additions)

#ifndef __INFINIOP_FLASH_ATTENTION_BACKWARD_H__
#define __INFINIOP_FLASH_ATTENTION_BACKWARD_H__

#include "../operator_descriptor.h"

typedef struct InfiniopDescriptor *infiniopFlashAttentionBackwardDescriptor_t;

__C __export infiniStatus_t infiniopCreateFlashAttentionBackwardDescriptor(
    infiniopHandle_t handle,
    infiniopFlashAttentionBackwardDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t grad_q_desc,
    infiniopTensorDescriptor_t grad_k_desc,
    infiniopTensorDescriptor_t grad_v_desc,
    infiniopTensorDescriptor_t q_desc,
    infiniopTensorDescriptor_t k_desc,
    infiniopTensorDescriptor_t v_desc,
    infiniopTensorDescriptor_t grad_out_desc,
    infiniopTensorDescriptor_t mask_desc,
    infiniopAttentionMaskType_t mask_type);

__C __export infiniStatus_t infiniopGetFlashAttentionBackwardWorkspaceSize(
    infiniopFlashAttentionBackwardDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopFlashAttentionBackward(
    infiniopFlashAttentionBackwardDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *grad_q,
    void *grad_k,
    void *grad_v,
    const void *q,
    const void *k,
    const void *v,
    const void *grad_out,
    const void *mask,
    void *stream);

__C __export infiniStatus_t infiniopDestroyFlashAttentionBackwardDescriptor(
    infiniopFlashAttentionBackwardDescriptor_t desc);

#endif
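
The backward interface mirrors the forward one: the same descriptor/workspace pattern, consuming q, k, v, the upstream gradient grad_out, and the optional mask, and writing grad_q, grad_k, and grad_v. A short sketch under the same assumptions as the forward example above:

// Illustrative only; descriptors and d_* device buffers are assumed to exist.
infiniopFlashAttentionBackwardDescriptor_t bwd_desc;
CHECK(infiniopCreateFlashAttentionBackwardDescriptor(
    handle, &bwd_desc,
    grad_q_desc, grad_k_desc, grad_v_desc,
    q_desc, k_desc, v_desc, grad_out_desc,
    /*mask_desc=*/nullptr, INFINIOP_ATTENTION_MASK_TYPE_CAUSAL));

size_t bwd_workspace_size = 0;
CHECK(infiniopGetFlashAttentionBackwardWorkspaceSize(bwd_desc, &bwd_workspace_size));

void *bwd_workspace = nullptr;
if (bwd_workspace_size > 0) {
    CHECK(infinirtMalloc(&bwd_workspace, bwd_workspace_size));
}

CHECK(infiniopFlashAttentionBackward(bwd_desc, bwd_workspace, bwd_workspace_size,
                                     d_grad_q, d_grad_k, d_grad_v,
                                     d_q, d_k, d_v, d_grad_out,
                                     /*mask=*/nullptr, /*stream=*/nullptr));

if (bwd_workspace != nullptr) {
    infinirtFree(bwd_workspace);
}
CHECK(infiniopDestroyFlashAttentionBackwardDescriptor(bwd_desc));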

src/infiniop-test/include/ops.hpp (17 additions, 13 deletions)

@@ -16,6 +16,8 @@ DECLARE_INFINIOP_TEST(add)
 DECLARE_INFINIOP_TEST(causal_softmax)
 DECLARE_INFINIOP_TEST(rearrange)
 DECLARE_INFINIOP_TEST(sub)
+DECLARE_INFINIOP_TEST(flash_attention)
+DECLARE_INFINIOP_TEST(flash_attention_backward)

 #define REGISTER_INFINIOP_TEST(name) \
     { \
@@ -30,19 +32,21 @@ DECLARE_INFINIOP_TEST(sub)
 /*
  * Register all the tests here
  */
-#define TEST_BUILDER_MAPPINGS                    \
-    {                                            \
-        REGISTER_INFINIOP_TEST(gemm)             \
-        REGISTER_INFINIOP_TEST(random_sample)    \
-        REGISTER_INFINIOP_TEST(add)              \
-        REGISTER_INFINIOP_TEST(mul)              \
-        REGISTER_INFINIOP_TEST(clip)             \
-        REGISTER_INFINIOP_TEST(swiglu)           \
-        REGISTER_INFINIOP_TEST(rope)             \
-        REGISTER_INFINIOP_TEST(rms_norm)         \
-        REGISTER_INFINIOP_TEST(causal_softmax)   \
-        REGISTER_INFINIOP_TEST(rearrange)        \
-        REGISTER_INFINIOP_TEST(sub)              \
+#define TEST_BUILDER_MAPPINGS                             \
+    {                                                     \
+        REGISTER_INFINIOP_TEST(gemm)                      \
+        REGISTER_INFINIOP_TEST(random_sample)             \
+        REGISTER_INFINIOP_TEST(add)                       \
+        REGISTER_INFINIOP_TEST(mul)                       \
+        REGISTER_INFINIOP_TEST(clip)                      \
+        REGISTER_INFINIOP_TEST(swiglu)                    \
+        REGISTER_INFINIOP_TEST(rope)                      \
+        REGISTER_INFINIOP_TEST(rms_norm)                  \
+        REGISTER_INFINIOP_TEST(causal_softmax)            \
+        REGISTER_INFINIOP_TEST(rearrange)                 \
+        REGISTER_INFINIOP_TEST(sub)                       \
+        REGISTER_INFINIOP_TEST(flash_attention)           \
+        REGISTER_INFINIOP_TEST(flash_attention_backward)  \
 }

 namespace infiniop_test {
flash_attention test source (new file, 158 additions)

#include "ops.hpp"
#include "utils.hpp"
#include <infinirt.h>
#include <iomanip>
#include <iostream>

namespace infiniop_test::flash_attention {
struct Test::Attributes {
    int mask_type;
    std::shared_ptr<Tensor> q;
    std::shared_ptr<Tensor> k;
    std::shared_ptr<Tensor> v;
    std::shared_ptr<Tensor> mask;
    std::shared_ptr<Tensor> out;
    std::shared_ptr<Tensor> l;
    std::shared_ptr<Tensor> ans;
};

std::shared_ptr<Test> Test::build(
    std::unordered_map<std::string, std::vector<uint8_t>> attributes,
    std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
    double rtol, double atol) {

    auto test = std::shared_ptr<Test>(new Test(rtol, atol));
    test->_attributes = new Attributes();

    if (attributes.find("mask_type") == attributes.end()
        || tensors.find("q") == tensors.end()
        || tensors.find("k") == tensors.end()
        || tensors.find("v") == tensors.end()
        || tensors.find("out") == tensors.end()
        || tensors.find("l") == tensors.end()
        || tensors.find("ans") == tensors.end()) {
        throw std::runtime_error("Invalid Test: Missing attributes or tensors");
    }

    if (tensors.find("mask") == tensors.end()) {
        test->_attributes->mask = nullptr;
    } else {
        test->_attributes->mask = tensors["mask"];
    }

    test->_attributes->mask_type = *reinterpret_cast<int *>(attributes["mask_type"].data());

    test->_attributes->q = tensors["q"];
    test->_attributes->k = tensors["k"];
    test->_attributes->v = tensors["v"];
    test->_attributes->out = tensors["out"];
    test->_attributes->l = tensors["l"];
    test->_attributes->ans = tensors["ans"];

    return test;
}

std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id,
    size_t warm_ups, size_t iterations) {

    infiniopFlashAttentionDescriptor_t op_desc;
    infiniopAttentionMaskType_t mask_type = static_cast<infiniopAttentionMaskType_t>(_attributes->mask_type);
    CHECK_OR(infiniopCreateFlashAttentionDescriptor(
                 handle, &op_desc,
                 _attributes->out->desc(),
                 _attributes->l->desc(),
                 _attributes->q->desc(),
                 _attributes->k->desc(),
                 _attributes->v->desc(),
                 _attributes->mask->desc(),
                 mask_type),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create FlashAttention descriptor"));

    auto out = _attributes->out->to(device, device_id);
    auto l = _attributes->l->to(device, device_id);
    auto q = _attributes->q->to(device, device_id);
    auto k = _attributes->k->to(device, device_id);
    auto v = _attributes->v->to(device, device_id);
    auto mask = _attributes->mask ? _attributes->mask->to(device, device_id) : nullptr;

    size_t workspace_size;
    CHECK_OR(infiniopGetFlashAttentionWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size"));
    void *workspace = nullptr;
    if (workspace_size > 0) {
        CHECK_OR(infinirtMalloc(&workspace, workspace_size),
                 return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace"));
    }

    CHECK_OR(infiniopFlashAttention(op_desc,
                                    workspace, workspace_size,
                                    out->data(),
                                    l->data(),
                                    q->data(),
                                    k->data(),
                                    v->data(),
                                    mask ? mask->data() : nullptr,
                                    nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "FlashAttention execution failed"));

    try {
        allClose(out, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    double elapsed_time = 0;

    elapsed_time = benchmark(
        [=]() {
            infiniopFlashAttention(op_desc,
                                   workspace, workspace_size,
                                   out->data(),
                                   l->data(),
                                   q->data(),
                                   k->data(),
                                   v->data(),
                                   mask ? mask->data() : nullptr,
                                   nullptr);
        },
        warm_ups, iterations);

    if (workspace != nullptr) {
        infinirtFree(workspace);
    }

    return TEST_PASSED(elapsed_time);
}

std::vector<std::string> Test::attribute_names() {
    return {"mask_type"};
}

std::vector<std::string> Test::tensor_names() {
    return {"q", "k", "v", "mask", "out", "l", "ans"};
}

std::vector<std::string> Test::output_names() {
    return {"out", "l"};
}

std::string Test::toString() const {
    std::ostringstream oss;
    oss << op_name() << std::endl;
    oss << "- masktype=" << static_cast<infiniopAttentionMaskType_t>(_attributes->mask_type) << std::endl;
    oss << "- q: " << _attributes->q->info() << std::endl;
    oss << "- k: " << _attributes->k->info() << std::endl;
    oss << "- v: " << _attributes->v->info() << std::endl;
    oss << "- mask: " << (_attributes->mask ? _attributes->mask->info() : "none") << std::endl;
    oss << "- out: " << _attributes->out->info() << std::endl;
    oss << std::scientific << std::setprecision(2);
    oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl;
    return oss.str();
}

Test::~Test() {
    delete _attributes;
}

} // namespace infiniop_test::flash_attention
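
For context, the ans tensor the harness compares out against corresponds to ordinary scaled dot-product attention. A naive single-head reference that could produce such ground truth is sketched below; the row-major layout, the 1/sqrt(d) scaling, and the reading of l as the per-row log-sum-exp follow the published FlashAttention formulation and are assumptions about this implementation rather than facts taken from the commit.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive reference for scaled dot-product attention on a single head.
// q, out: [seq_q, d]; k, v: [seq_k, d]; mask: optional [seq_q, seq_k]
// additive mask (0 or -inf); l: [seq_q] per-row log-sum-exp (assumed meaning).
void attention_reference(const std::vector<float> &q,
                         const std::vector<float> &k,
                         const std::vector<float> &v,
                         const float *mask, // may be nullptr
                         size_t seq_q, size_t seq_k, size_t d,
                         std::vector<float> &out,
                         std::vector<float> &l) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    out.assign(seq_q * d, 0.0f);
    l.assign(seq_q, 0.0f);
    std::vector<float> scores(seq_k);
    for (size_t i = 0; i < seq_q; ++i) {
        // Scaled scores s_ij = (q_i . k_j) * scale, plus the additive mask.
        float row_max = -INFINITY;
        for (size_t j = 0; j < seq_k; ++j) {
            float s = 0.0f;
            for (size_t t = 0; t < d; ++t) {
                s += q[i * d + t] * k[j * d + t];
            }
            s *= scale;
            if (mask) {
                s += mask[i * seq_k + j];
            }
            scores[j] = s;
            row_max = std::max(row_max, s);
        }
        // Numerically stable softmax and per-row log-sum-exp.
        float denom = 0.0f;
        for (size_t j = 0; j < seq_k; ++j) {
            scores[j] = std::exp(scores[j] - row_max);
            denom += scores[j];
        }
        l[i] = row_max + std::log(denom);
        // Weighted sum of value rows.
        for (size_t j = 0; j < seq_k; ++j) {
            const float p = scores[j] / denom;
            for (size_t t = 0; t < d; ++t) {
                out[i * d + t] += p * v[j * d + t];
            }
        }
    }
}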
