diff --git a/.gitignore b/.gitignore index 3024ac452..84d01353a 100644 --- a/.gitignore +++ b/.gitignore @@ -20,5 +20,5 @@ cache/ # JSON *.json -#GGUF +# GGUF *.gguf diff --git a/include/infiniop.h b/include/infiniop.h index 0acad83f9..e2013d4dd 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -8,6 +8,8 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize.h" +#include "infiniop/ops/flash_attention.h" +#include "infiniop/ops/flash_attention_backward.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/mul.h" #include "infiniop/ops/random_sample.h" diff --git a/include/infiniop/ops/flash_attention.h b/include/infiniop/ops/flash_attention.h new file mode 100644 index 000000000..2c7ab0ec5 --- /dev/null +++ b/include/infiniop/ops/flash_attention.h @@ -0,0 +1,43 @@ +#ifndef __INFINIOP_FLASH_ATTENTION_API_H__ +#define __INFINIOP_FLASH_ATTENTION_API_H__ + +#include "../operator_descriptor.h" + +typedef enum { + INFINIOP_ATTENTION_MASK_TYPE_NONE = 0, + INFINIOP_ATTENTION_MASK_TYPE_FULL = 1, + INFINIOP_ATTENTION_MASK_TYPE_CAUSAL = 2, +} infiniopAttentionMaskType_t; + +typedef struct InfiniopDescriptor *infiniopFlashAttentionDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t l_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type); + +__C __export infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void *l, + const void *q, + const void *k, + const void *v, + const void *mask, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFlashAttentionDescriptor(infiniopFlashAttentionDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/flash_attention_backward.h b/include/infiniop/ops/flash_attention_backward.h new file mode 100644 index 000000000..a11c1377d --- /dev/null +++ b/include/infiniop/ops/flash_attention_backward.h @@ -0,0 +1,43 @@ +#ifndef __INFINIOP_FLASH_ATTENTION_BACKWARD_H__ +#define __INFINIOP_FLASH_ATTENTION_BACKWARD_H__ + +#include "../operator_descriptor.h" +#include "flash_attention.h" + +typedef struct InfiniopDescriptor *infiniopFlashAttentionBackwardDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFlashAttentionBackwardDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionBackwardDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t grad_q_desc, + infiniopTensorDescriptor_t grad_k_desc, + infiniopTensorDescriptor_t grad_v_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t grad_out_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type); + +__C __export infiniStatus_t infiniopGetFlashAttentionBackwardWorkspaceSize( + infiniopFlashAttentionBackwardDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopFlashAttentionBackward( + infiniopFlashAttentionBackwardDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *grad_q, + void *grad_k, + void *grad_v, + const void *q, + const void *k, + const void *v, + 
const void *grad_out, + const void *mask, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFlashAttentionBackwardDescriptor( + infiniopFlashAttentionBackwardDescriptor_t desc); + +#endif diff --git a/scripts/python_test.py b/scripts/python_test.py index 5348c8c69..98015370e 100644 --- a/scripts/python_test.py +++ b/scripts/python_test.py @@ -25,6 +25,8 @@ def run_tests(args): "sub.py", "swiglu.py", "softplus.py", + "flash_attention.py", + "flash_attention_backward.py", ]: result = subprocess.run( f"python {test} {args} --debug", text=True, encoding="utf-8", shell=True diff --git a/src/infiniop-test/include/ops.hpp b/src/infiniop-test/include/ops.hpp index 3820f7cfd..086e4e5f3 100644 --- a/src/infiniop-test/include/ops.hpp +++ b/src/infiniop-test/include/ops.hpp @@ -16,6 +16,8 @@ DECLARE_INFINIOP_TEST(add) DECLARE_INFINIOP_TEST(causal_softmax) DECLARE_INFINIOP_TEST(rearrange) DECLARE_INFINIOP_TEST(sub) +DECLARE_INFINIOP_TEST(flash_attention) +DECLARE_INFINIOP_TEST(flash_attention_backward) #define REGISTER_INFINIOP_TEST(name) \ { \ @@ -30,19 +32,21 @@ DECLARE_INFINIOP_TEST(sub) /* * Register all the tests here */ -#define TEST_BUILDER_MAPPINGS \ - { \ - REGISTER_INFINIOP_TEST(gemm) \ - REGISTER_INFINIOP_TEST(random_sample) \ - REGISTER_INFINIOP_TEST(add) \ - REGISTER_INFINIOP_TEST(mul) \ - REGISTER_INFINIOP_TEST(clip) \ - REGISTER_INFINIOP_TEST(swiglu) \ - REGISTER_INFINIOP_TEST(rope) \ - REGISTER_INFINIOP_TEST(rms_norm) \ - REGISTER_INFINIOP_TEST(causal_softmax) \ - REGISTER_INFINIOP_TEST(rearrange) \ - REGISTER_INFINIOP_TEST(sub) \ +#define TEST_BUILDER_MAPPINGS \ + { \ + REGISTER_INFINIOP_TEST(gemm) \ + REGISTER_INFINIOP_TEST(random_sample) \ + REGISTER_INFINIOP_TEST(add) \ + REGISTER_INFINIOP_TEST(mul) \ + REGISTER_INFINIOP_TEST(clip) \ + REGISTER_INFINIOP_TEST(swiglu) \ + REGISTER_INFINIOP_TEST(rope) \ + REGISTER_INFINIOP_TEST(rms_norm) \ + REGISTER_INFINIOP_TEST(causal_softmax) \ + REGISTER_INFINIOP_TEST(rearrange) \ + REGISTER_INFINIOP_TEST(sub) \ + REGISTER_INFINIOP_TEST(flash_attention) \ + REGISTER_INFINIOP_TEST(flash_attention_backward) \ } namespace infiniop_test { diff --git a/src/infiniop-test/src/ops/flash_attention.cpp b/src/infiniop-test/src/ops/flash_attention.cpp new file mode 100644 index 000000000..f03cd8455 --- /dev/null +++ b/src/infiniop-test/src/ops/flash_attention.cpp @@ -0,0 +1,158 @@ +#include "ops.hpp" +#include "utils.hpp" +#include +#include +#include + +namespace infiniop_test::flash_attention { +struct Test::Attributes { + int mask_type; + std::shared_ptr q; + std::shared_ptr k; + std::shared_ptr v; + std::shared_ptr mask; + std::shared_ptr out; + std::shared_ptr l; + std::shared_ptr ans; +}; + +std::shared_ptr Test::build( + std::unordered_map> attributes, + std::unordered_map> tensors, + double rtol, double atol) { + + auto test = std::shared_ptr(new Test(rtol, atol)); + test->_attributes = new Attributes(); + + if (attributes.find("mask_type") == attributes.end() + || tensors.find("q") == tensors.end() + || tensors.find("k") == tensors.end() + || tensors.find("v") == tensors.end() + || tensors.find("out") == tensors.end() + || tensors.find("l") == tensors.end() + || tensors.find("ans") == tensors.end()) { + throw std::runtime_error("Invalid Test: Missing attributes or tensors"); + } + + if (tensors.find("mask") == tensors.end()) { + test->_attributes->mask = nullptr; + } else { + test->_attributes->mask = tensors["mask"]; + } + + test->_attributes->mask_type = *reinterpret_cast(attributes["mask_type"].data()); + + test->_attributes->q = 
tensors["q"]; + test->_attributes->k = tensors["k"]; + test->_attributes->v = tensors["v"]; + test->_attributes->out = tensors["out"]; + test->_attributes->l = tensors["l"]; + test->_attributes->ans = tensors["ans"]; + + return test; +} + +std::shared_ptr Test::run( + infiniopHandle_t handle, infiniDevice_t device, int device_id, + size_t warm_ups, size_t iterations) { + + infiniopFlashAttentionDescriptor_t op_desc; + infiniopAttentionMaskType_t mask_type = static_cast(_attributes->mask_type); + CHECK_OR(infiniopCreateFlashAttentionDescriptor( + handle, &op_desc, + _attributes->out->desc(), + _attributes->l->desc(), + _attributes->q->desc(), + _attributes->k->desc(), + _attributes->v->desc(), + _attributes->mask->desc(), + mask_type), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to create FlashAttention descriptor")); + + auto out = _attributes->out->to(device, device_id); + auto l = _attributes->l->to(device, device_id); + auto q = _attributes->q->to(device, device_id); + auto k = _attributes->k->to(device, device_id); + auto v = _attributes->v->to(device, device_id); + auto mask = _attributes->mask ? _attributes->mask->to(device, device_id) : nullptr; + + size_t workspace_size; + CHECK_OR(infiniopGetFlashAttentionWorkspaceSize(op_desc, &workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size")); + void *workspace = nullptr; + if (workspace_size > 0) { + CHECK_OR(infinirtMalloc(&workspace, workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace")); + } + + CHECK_OR(infiniopFlashAttention(op_desc, + workspace, workspace_size, + out->data(), + l->data(), + q->data(), + k->data(), + v->data(), + mask ? mask->data() : nullptr, + nullptr), + return TEST_FAILED(OP_EXECUTION_FAILED, "FlashAttention execution failed")); + + try { + allClose(out, _attributes->ans, _rtol, _atol); + } catch (const std::exception &e) { + return TEST_FAILED(RESULT_INCORRECT, e.what()); + } + + double elapsed_time = 0; + + elapsed_time = benchmark( + [=]() { + infiniopFlashAttention(op_desc, + workspace, workspace_size, + out->data(), + l->data(), + q->data(), + k->data(), + v->data(), + mask ? mask->data() : nullptr, + nullptr); + }, + warm_ups, iterations); + + if (workspace != nullptr) { + infinirtFree(workspace); + } + + return TEST_PASSED(elapsed_time); +} + +std::vector Test::attribute_names() { + return {"mask_type"}; +} + +std::vector Test::tensor_names() { + return {"q", "k", "v", "mask", "out", "l", "ans"}; +} + +std::vector Test::output_names() { + return {"out", "l"}; +} + +std::string Test::toString() const { + std::ostringstream oss; + oss << op_name() << std::endl; + oss << "- masktype=" << static_cast(_attributes->mask_type) << std::endl; + oss << "- q: " << _attributes->q->info() << std::endl; + oss << "- k: " << _attributes->k->info() << std::endl; + oss << "- v: " << _attributes->v->info() << std::endl; + oss << "- mask: " << (_attributes->mask ? 
_attributes->mask->info() : "none") << std::endl; + oss << "- out: " << _attributes->out->info() << std::endl; + oss << std::scientific << std::setprecision(2); + oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; + return oss.str(); +} + +Test::~Test() { + delete _attributes; +} + +} // namespace infiniop_test::flash_attention diff --git a/src/infiniop-test/src/ops/flash_attention_backward.cpp b/src/infiniop-test/src/ops/flash_attention_backward.cpp new file mode 100644 index 000000000..3c48c7f47 --- /dev/null +++ b/src/infiniop-test/src/ops/flash_attention_backward.cpp @@ -0,0 +1,184 @@ +#include "ops.hpp" +#include "utils.hpp" +#include +#include +#include + +namespace infiniop_test::flash_attention_backward { +struct Test::Attributes { + int mask_type; + std::shared_ptr q; + std::shared_ptr k; + std::shared_ptr v; + std::shared_ptr mask; + std::shared_ptr grad_out; + std::shared_ptr grad_q; + std::shared_ptr grad_k; + std::shared_ptr grad_v; + std::shared_ptr ans_grad_q; + std::shared_ptr ans_grad_k; + std::shared_ptr ans_grad_v; +}; + +std::shared_ptr Test::build( + std::unordered_map> attributes, + std::unordered_map> tensors, + double rtol, double atol) { + + auto test = std::shared_ptr(new Test(rtol, atol)); + test->_attributes = new Attributes(); + + if (attributes.find("mask_type") == attributes.end() + || tensors.find("q") == tensors.end() + || tensors.find("k") == tensors.end() + || tensors.find("v") == tensors.end() + || tensors.find("grad_out") == tensors.end() + || tensors.find("grad_q") == tensors.end() + || tensors.find("grad_k") == tensors.end() + || tensors.find("grad_v") == tensors.end() + || tensors.find("ans_grad_q") == tensors.end() + || tensors.find("ans_grad_k") == tensors.end() + || tensors.find("ans_grad_v") == tensors.end()) { + throw std::runtime_error("Invalid Test: Missing attributes or tensors"); + } + + if (tensors.find("mask") == tensors.end()) { + test->_attributes->mask = nullptr; + } else { + test->_attributes->mask = tensors["mask"]; + } + + test->_attributes->mask_type = *reinterpret_cast(attributes["mask_type"].data()); + + test->_attributes->q = tensors["q"]; + test->_attributes->k = tensors["k"]; + test->_attributes->v = tensors["v"]; + test->_attributes->grad_out = tensors["grad_out"]; + test->_attributes->grad_q = tensors["grad_q"]; + test->_attributes->grad_k = tensors["grad_k"]; + test->_attributes->grad_v = tensors["grad_v"]; + test->_attributes->ans_grad_q = tensors["ans_grad_q"]; + test->_attributes->ans_grad_k = tensors["ans_grad_k"]; + test->_attributes->ans_grad_v = tensors["ans_grad_v"]; + + return test; +} + +std::shared_ptr Test::run( + infiniopHandle_t handle, infiniDevice_t device, int device_id, + size_t warm_ups, size_t iterations) { + + infiniopFlashAttentionBackwardDescriptor_t op_desc; + infiniopAttentionMaskType_t mask_type = static_cast(_attributes->mask_type); + CHECK_OR(infiniopCreateFlashAttentionBackwardDescriptor( + handle, &op_desc, + _attributes->grad_q->desc(), + _attributes->grad_k->desc(), + _attributes->grad_v->desc(), + _attributes->q->desc(), + _attributes->k->desc(), + _attributes->v->desc(), + _attributes->grad_out->desc(), + _attributes->mask->desc(), + mask_type), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to create FlashAttentionBackward descriptor")); + + auto grad_q = _attributes->grad_q->to(device, device_id); + auto grad_k = _attributes->grad_k->to(device, device_id); + auto grad_v = _attributes->grad_v->to(device, device_id); + auto q = _attributes->q->to(device, device_id); + auto k = 
_attributes->k->to(device, device_id); + auto v = _attributes->v->to(device, device_id); + auto grad_out = _attributes->grad_out->to(device, device_id); + auto mask = _attributes->mask ? _attributes->mask->to(device, device_id) : nullptr; + + size_t workspace_size; + CHECK_OR(infiniopGetFlashAttentionBackwardWorkspaceSize(op_desc, &workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size")); + void *workspace = nullptr; + if (workspace_size > 0) { + CHECK_OR(infinirtMalloc(&workspace, workspace_size), + return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace")); + } + + CHECK_OR(infiniopFlashAttentionBackward(op_desc, + workspace, workspace_size, + grad_q->data(), + grad_k->data(), + grad_v->data(), + q->data(), + k->data(), + v->data(), + grad_out->data(), + mask ? mask->data() : nullptr, + nullptr), + return TEST_FAILED(OP_EXECUTION_FAILED, "Failed to execute FlashAttentionBackward")); + + try { + allClose(grad_q, _attributes->ans_grad_q, _rtol, _atol); + allClose(grad_k, _attributes->ans_grad_k, _rtol, _atol); + allClose(grad_v, _attributes->ans_grad_v, _rtol, _atol); + } catch (const std::exception &e) { + return TEST_FAILED(RESULT_INCORRECT, e.what()); + } + + double elapsed_time = 0; + + elapsed_time = benchmark( + [=]() { + infiniopFlashAttentionBackward(op_desc, + workspace, workspace_size, + grad_q->data(), + grad_k->data(), + grad_v->data(), + q->data(), + k->data(), + v->data(), + grad_out->data(), + mask ? mask->data() : nullptr, + nullptr); + }, + warm_ups, iterations); + + if (workspace != nullptr) { + infinirtFree(workspace); + } + + return TEST_PASSED(elapsed_time); +} + +std::vector Test::attribute_names() { + return {"mask_type"}; +} + +std::vector Test::tensor_names() { + return {"grad_q", "grad_k", "grad_v", "q", "k", "v", "grad_out", "mask", + "ans_grad_q", "ans_grad_k", "ans_grad_v"}; +} + +std::vector Test::output_names() { + return {"grad_q", "grad_k", "grad_v"}; +} + +std::string Test::toString() const { + std::ostringstream oss; + oss << op_name() << std::endl; + oss << "- masktype=" << static_cast(_attributes->mask_type) << std::endl; + oss << "- q: " << _attributes->q->info() << std::endl; + oss << "- k: " << _attributes->k->info() << std::endl; + oss << "- v: " << _attributes->v->info() << std::endl; + oss << "- grad_out: " << _attributes->grad_out->info() << std::endl; + oss << "- mask: " << (_attributes->mask ? 
_attributes->mask->info() : "none") << std::endl; + oss << "- grad_q: " << _attributes->grad_q->info() << std::endl; + oss << "- grad_k: " << _attributes->grad_k->info() << std::endl; + oss << "- grad_v: " << _attributes->grad_v->info() << std::endl; + oss << std::scientific << std::setprecision(2); + oss << "- rtol=" << _rtol << ", atol=" << _atol << std::endl; + return oss.str(); +} + +Test::~Test() { + delete _attributes; +} + +} // namespace infiniop_test::flash_attention_backward diff --git a/src/infiniop/ops/flash_attention/cpu/flash_attention_cpu.cc b/src/infiniop/ops/flash_attention/cpu/flash_attention_cpu.cc new file mode 100644 index 000000000..506278f87 --- /dev/null +++ b/src/infiniop/ops/flash_attention/cpu/flash_attention_cpu.cc @@ -0,0 +1,252 @@ +#include "flash_attention_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "math.h" + +namespace op::flash_attention::cpu { + +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t l_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + + auto info = FlashAttentionInfo::create(out_desc, l_desc, q_desc, k_desc, v_desc, mask_desc, mask_type); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t flashAttention( + T *out, T *l, const T *q, const T *k, const T *v, const float *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + size_t qo_stride_b, size_t qo_stride_s, size_t qo_stride_n, + size_t kv_stride_b, size_t kv_stride_s, size_t kv_stride_n, + size_t l_stride_b, size_t l_stride_s, size_t l_stride_n) { + + std::memset(out, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(T)); + std::memset(l, 0, batch_size * nums_head_q * seq_len_q * sizeof(T)); + + float softmax_scale = 1.f / sqrt(float(head_dim)); + +#pragma omp parallel for + for (ptrdiff_t bx = 0; bx < ptrdiff_t(batch_size); ++bx) { + for (size_t by = 0; by < nums_head_q; ++by) { + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + std::vector q_i(B_r * head_dim); + std::vector k_j(B_c * head_dim); + std::vector v_j(B_c * head_dim); + std::vector s_i(B_r * B_c); + + for (size_t i = 0; i < T_r; ++i) { + for (size_t tx = 0; tx < B_r; ++tx) { + // skip when over q's seq_len + if (i * B_r + tx >= seq_len_q) { + break; + } + + // load q_i + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = utils::cast(q[qo_offset + (i * B_r + tx) * qo_stride_s + x]); + } + + // initial m, l + float row_m_prev = -INFINITY; + float row_l_prev = 0; + + for (size_t j = 0; j < T_c; ++j) { + // load k_j, v_j + for (size_t y = 0; y < B_c; ++y) { + for (size_t x = 0; x < head_dim; ++x) { + k_j[y * head_dim + x] = utils::cast(k[kv_offset + (y + j * B_c) * kv_stride_s + x]); + v_j[y * head_dim + x] = utils::cast(v[kv_offset + (y + j * B_c) * kv_stride_s + x]); + } + } + + float row_m = -INFINITY; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + 
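+                            // The mask is a float matrix of shape [seq_len_q, seq_len_kv]: an entry of
+                            // -INFINITY marks a masked (q, k) position, 0 keeps it. Masked scores are
+                            // skipped so they contribute neither to the row max nor to the row sum.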
if (mask != nullptr && mask[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + } + + // S_i^(j) = Q_i @ K_j^T / softmax_scale + float sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + + s_i[tx * B_c + y] = sum; + + row_m = std::max(row_m, sum); + } + + // m_i^(j) = max(m_i^(j - 1), rowmax(S_i^(j))) + float new_row_m = std::max(row_m_prev, row_m); + + // rowsum(P_i^(j)) + float row_l = 0; + for (size_t y = 0; y < B_r; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask != nullptr && mask[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + // P_i^(j) = exp(S_i^(j) - m_i^(j)) + if (new_row_m == -INFINITY) { + s_i[tx * B_c + y] = 1.0; + } else { + s_i[tx * B_c + y] = exp(s_i[tx * B_c + y] - new_row_m); + } + + row_l += s_i[tx * B_c + y]; + } + + // l_i^(j) = exp(m_i^(j - 1) - m_i^(j - 1)) * l_i^(j - 1) + rowsum(P_i^(j)) + float row_m_exp; + if (row_m_prev == -INFINITY) { + row_m_exp = 1.0; + } else { + row_m_exp = exp(row_m_prev - new_row_m); + } + float new_row_l = (row_m_exp * row_l_prev) + row_l; + + // out_i^(j) = diag(exp(m_i^(j - 1) - m_i^(y))) * O_i^(j - 1) + P_i^(j) * V_j + for (size_t x = 0; x < head_dim; ++x) { + float pv = 0; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask != nullptr && mask[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + pv += s_i[tx * B_c + y] * v_j[y * head_dim + x]; + } + + out[qo_offset + (i * B_r + tx) * qo_stride_s + x] = utils::cast(row_m_exp * utils::cast(out[qo_offset + (i * B_r + tx) * qo_stride_s + x]) + pv); + } + + row_m_prev = new_row_m; + row_l_prev = new_row_l; + } + + // O_i = O_i^(Tc) / l_i^(Tc) + for (size_t x = 0; x < head_dim; ++x) { + out[qo_offset + (i * B_r + tx) * qo_stride_s + x] = utils::cast(utils::cast(out[qo_offset + (i * B_r + tx) * qo_stride_s + x]) / row_l_prev); + } + + l[l_offset + i * B_r + tx] = utils::cast(row_m_prev + log(row_l_prev)); + } + } + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + void *stream) const { + + size_t B_r = 16; + size_t B_c = 16; + + size_t batch_size = _info.batch_size; + size_t seq_len_q = _info.seq_len_q; + size_t seq_len_kv = _info.seq_len_kv; + size_t nums_head_q = _info.num_heads_q; + size_t nums_head_kv = _info.num_heads_kv; + size_t group = nums_head_q / nums_head_kv; + size_t head_dim = _info.head_dim; + + const void *mask_input = nullptr; + if (_info.is_masked) { + if (_info.mask != nullptr) { + mask_input = _info.mask; + } else { + mask_input = mask; + } + } + + size_t T_r = CEIL_DIV(seq_len_q, B_r); + size_t T_c = CEIL_DIV(seq_len_kv, B_c); + + if (_info.dtype == INFINI_DTYPE_F32) { + CHECK_STATUS(flashAttention( + (float *)out, (float *)l, (float *)q, (float *)k, (float *)v, (float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + } else if (_info.dtype == INFINI_DTYPE_F16) { + CHECK_STATUS(flashAttention( + (fp16_t *)out, (fp16_t *)l, (fp16_t *)q, (fp16_t *)k, (fp16_t *)v, (float *)mask_input, 
+ batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + CHECK_STATUS(flashAttention( + (bf16_t *)out, (bf16_t *)l, (bf16_t *)q, (bf16_t *)k, (bf16_t *)v, (float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::flash_attention::cpu diff --git a/src/infiniop/ops/flash_attention/cpu/flash_attention_cpu.h b/src/infiniop/ops/flash_attention/cpu/flash_attention_cpu.h new file mode 100644 index 000000000..4ac76ce1d --- /dev/null +++ b/src/infiniop/ops/flash_attention/cpu/flash_attention_cpu.h @@ -0,0 +1,7 @@ +#ifndef __FLASH_ATTENTION_CPU_H__ +#define __FLASH_ATTENTION_CPU_H__ +#include "../flash_attention.h" + +DESCRIPTOR(cpu) + +#endif diff --git a/src/infiniop/ops/flash_attention/cuda/kernel.cuh b/src/infiniop/ops/flash_attention/cuda/kernel.cuh new file mode 100644 index 000000000..641a3bfd6 --- /dev/null +++ b/src/infiniop/ops/flash_attention/cuda/kernel.cuh @@ -0,0 +1,180 @@ +#ifndef __FLASH_ATTENTION_KERNEL_CUH__ +#define __FLASH_ATTENTION_KERNEL_CUH__ + +template +__device__ void flashAttentionBlock( + Tdata *out_, Tdata *l_, + const Tdata *q_, const Tdata *k_, const Tdata *v_, const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + const Tdata softmax_scale, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + size_t bx = blockIdx.x; // batch -> batch_size + size_t by = blockIdx.y; // q's head index -> num_heads_q + size_t tx = threadIdx.x; // q's row index within one block -> B_r/B_c + + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + extern __shared__ __align__(sizeof(Tdata)) char shared_mem[]; + Tdata *q_i = reinterpret_cast(shared_mem); + Tdata *k_j = reinterpret_cast(q_i + B_r * head_dim); + Tdata *v_j = reinterpret_cast(k_j + B_c * head_dim); + Tdata *s_i = reinterpret_cast(v_j + B_c * head_dim); + + for (size_t i = 0; i < T_r; ++i) { + // skip when over q's seq_len + if (i * B_r + tx >= seq_len_q) { + break; + } + + // load q_i from HBM to on-chip SRAM + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = q_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + } + // initial m, l + Tdata row_m_prev = -INFINITY; + Tdata row_l_prev = 0; + + for (size_t j = 0; j < T_c; ++j) { + __syncthreads(); + // load k_j, v_j from HBM to on-chip SRAM + for (size_t y = 0; y < B_c; ++y) { + for (size_t x = 0; x < head_dim; ++x) { + k_j[y * head_dim + x] = k_[kv_offset + (y + j * B_c) * kv_stride_s + x]; + v_j[y * head_dim + x] = v_[kv_offset + (y + j * B_c) * kv_stride_s + x]; + } + } + __syncthreads(); + + Tdata row_m = -INFINITY; + 
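+            // Online-softmax update for K/V tile j (FlashAttention recurrence), roughly:
+            //   m^(j) = max(m^(j-1), rowmax(S^(j)))
+            //   P^(j) = exp(S^(j) - m^(j))
+            //   l^(j) = exp(m^(j-1) - m^(j)) * l^(j-1) + rowsum(P^(j))
+            //   O^(j) = exp(m^(j-1) - m^(j)) * O^(j-1) + P^(j) @ V_j
+            // and after the last tile: O = O^(Tc) / l^(Tc), L = m^(Tc) + log(l^(Tc)).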
for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + }; + + // S_i^(j) = Q_i @ K_j^T / softmax_scale + Tdata sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + + s_i[tx * B_c + y] = sum; + + if constexpr (std::is_same_v) { + row_m = __float2half(max(__half2float(row_m), __half2float(sum))); + } else if constexpr (std::is_same_v) { + row_m = __hmax(row_m, sum); + } else { + row_m = max(row_m, sum); + } + } + + // m_i^(j) = max(m_i^(j - 1), rowmax(S_i^(j))) + Tdata new_row_m; + + if constexpr (std::is_same_v) { + new_row_m = __float2half(max(__half2float(row_m_prev), __half2float(row_m))); + } else if constexpr (std::is_same_v) { + new_row_m = __hmax(row_m_prev, row_m); + } else { + new_row_m = max(row_m_prev, row_m); + } + + // rowsum(P_i^(j)) + Tdata row_l = 0; + for (size_t y = 0; y < B_r; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + // P_i^(j) = exp(S_i^(j) - m_i^(j)) + if constexpr (std::is_same_v || std::is_same_v) { + if (__hisinf(new_row_m)) { + s_i[tx * B_c + y] = 1.0; + } else { + s_i[tx * B_c + y] = hexp(s_i[tx * B_c + y] - new_row_m); + } + } else { + if (isinf(new_row_m)) { + s_i[tx * B_c + y] = 1.0; + } else { + s_i[tx * B_c + y] = expf(s_i[tx * B_c + y] - new_row_m); + } + } + + row_l += s_i[tx * B_c + y]; + } + + // l_i^(j) = exp(m_i^(j - 1) - m_i^(j - 1)) * l_i^(j - 1) + rowsum(P_i^(j)) + Tdata row_m_exp; + if constexpr (std::is_same_v || std::is_same_v) { + if (__hisinf(row_m_prev)) { + row_m_exp = 1.0; + } else { + row_m_exp = hexp(row_m_prev - new_row_m); + } + } else { + if (isinf(new_row_m)) { + row_m_exp = 1.0; + } else { + row_m_exp = expf(row_m_prev - new_row_m); + } + } + Tdata new_row_l = (row_m_exp * row_l_prev) + row_l; + + // out_i^(j) = diag(exp(m_i^(j - 1) - m_i^(y))) * O_i^(j - 1) + P_i^(j) * V_j + for (size_t x = 0; x < head_dim; ++x) { + Tdata pv = 0; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + pv += s_i[tx * B_c + y] * v_j[y * head_dim + x]; + } + + out_[qo_offset + (i * B_r + tx) * qo_stride_s + x] = row_m_exp * out_[qo_offset + (i * B_r + tx) * qo_stride_s + x] + pv; + } + + row_m_prev = new_row_m; + row_l_prev = new_row_l; + } + + // O_i = O_i^(Tc) / l_i^(Tc) + for (size_t x = 0; x < head_dim; ++x) { + out_[qo_offset + (i * B_r + tx) * qo_stride_s + x] /= row_l_prev; + } + + // L_i = m_i^(Tc) + log(l_i^(Tc)) + if constexpr (std::is_same_v || std::is_same_v) { + l_[l_offset + i * B_r + tx] = row_m_prev + hlog(row_l_prev); + } else { + l_[l_offset + i * B_r + tx] = row_m_prev + logf(row_l_prev); + } + } +} + +#endif // __FLASH_ATTENTION_KERNEL_CUH__ diff --git a/src/infiniop/ops/flash_attention/flash_attention.h b/src/infiniop/ops/flash_attention/flash_attention.h new file mode 100644 index 000000000..b16f9fceb --- /dev/null +++ b/src/infiniop/ops/flash_attention/flash_attention.h @@ -0,0 +1,56 @@ +#ifndef FLASH_ATTENTION_H +#define FLASH_ATTENTION_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::flash_attention::NAMESPACE { \ + class Descriptor 
final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + FlashAttentionInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + FlashAttentionInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t l_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t mask_desc, \ + infiniopAttentionMaskType_t mask_type); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *out, \ + void *l, \ + const void *q, \ + const void *k, \ + const void *v, \ + const void *mask, \ + void *stream) const; \ + }; \ + } + +#endif // FLASH_ATTENTION_H diff --git a/src/infiniop/ops/flash_attention/info.h b/src/infiniop/ops/flash_attention/info.h new file mode 100644 index 000000000..20a4def40 --- /dev/null +++ b/src/infiniop/ops/flash_attention/info.h @@ -0,0 +1,197 @@ +#ifndef __FLASH_ATTENTION_INFO_H__ +#define __FLASH_ATTENTION_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/flash_attention.h" +#include +#include + +namespace op::flash_attention { + +class FlashAttentionInfo { +private: + FlashAttentionInfo() = default; + +public: + infiniDtype_t dtype; + size_t batch_size; + size_t seq_len_q, seq_len_kv; + size_t num_heads_q, num_heads_kv; + size_t head_dim; + + ptrdiff_t qo_stride_b; + ptrdiff_t qo_stride_s; + ptrdiff_t qo_stride_n; + ptrdiff_t qo_stride_d; + + ptrdiff_t kv_stride_b; + ptrdiff_t kv_stride_s; + ptrdiff_t kv_stride_n; + ptrdiff_t kv_stride_d; + + ptrdiff_t l_stride_b; + ptrdiff_t l_stride_s; + ptrdiff_t l_stride_n; + + ptrdiff_t mask_stride_sq; + ptrdiff_t mask_stride_sk; + + void *mask; + bool is_masked; + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t l_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + // 检查 l,o,q,k,v 数据类型是否一致 + auto dtype = out_desc->dtype(); + CHECK_OR_RETURN( + dtype == q_desc->dtype() + && dtype == k_desc->dtype() + && dtype == v_desc->dtype() + && dtype == l_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + // 检查 l,o,q,k,v 张量形状 + // q 和 out 的形状必须相同 + auto q_shape = out_desc->shape(); + CHECK_SAME_SHAPE(q_shape, q_desc->shape()); + // k 和 v 的形状必须相同 + auto kv_shape = k_desc->shape(); + CHECK_SAME_SHAPE(kv_shape, v_desc->shape()); + // 检查输入的维度 + auto ndim = q_desc->ndim(); + CHECK_OR_RETURN(ndim == k_desc->ndim(), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(ndim == 3 || ndim == 4, INFINI_STATUS_BAD_TENSOR_SHAPE); + + size_t batch_size_q = 1; + size_t seq_len_q = q_shape[ndim - 3]; + size_t num_heads_q = q_shape[ndim - 2]; + size_t head_dim_q = q_shape[ndim - 1]; + + size_t batch_size_kv = 1; + size_t seq_len_kv = kv_shape[ndim - 3]; + size_t num_heads_kv = kv_shape[ndim - 2]; + size_t head_dim_kv = kv_shape[ndim - 1]; + 
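+        // Assumed layout: q/k/v/out are [batch, seq_len, num_heads, head_dim] when ndim == 4,
+        // or [seq_len, num_heads, head_dim] when ndim == 3 (batch_size stays 1 and the batch
+        // strides below stay 0).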
+ ptrdiff_t qo_stride_b = 0, + qo_stride_s = q_desc->stride(ndim - 3), + qo_stride_n = q_desc->stride(ndim - 2), + qo_stride_d = q_desc->stride(ndim - 1); + + ptrdiff_t kv_stride_b = 0, + kv_stride_s = k_desc->stride(ndim - 3), + kv_stride_n = k_desc->stride(ndim - 2), + kv_stride_d = k_desc->stride(ndim - 1); + + if (ndim == 4) { + qo_stride_b = q_desc->stride(0); + kv_stride_b = k_desc->stride(0); + batch_size_q = q_shape[0]; + batch_size_kv = kv_shape[0]; + } + + // batch_size 和 head_dim 是否一致 + CHECK_OR_RETURN(batch_size_q == batch_size_kv, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(head_dim_q == head_dim_kv, INFINI_STATUS_BAD_TENSOR_SHAPE); + // 多头注意力是否整除 + CHECK_OR_RETURN(num_heads_q % num_heads_kv == 0, INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(out_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(q_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + + size_t batch_size = batch_size_q; + size_t head_dim = head_dim_q; + + // 检查 l (log-sum-exp) + CHECK_OR_RETURN(l_desc->ndim() == ndim - 1, INFINI_STATUS_BAD_TENSOR_SHAPE); + + auto l_shape = l_desc->shape(); + size_t batch_size_l = 1; + size_t seq_len_l = l_shape[ndim - 3]; + size_t num_heads_l = l_shape[ndim - 2]; + + ptrdiff_t l_stride_b = 0; + ptrdiff_t l_stride_s = l_desc->stride(ndim - 3); + ptrdiff_t l_stride_n = l_desc->stride(ndim - 2); + if (ndim == 4) { + l_stride_b = l_desc->stride(0); + batch_size_l = l_shape[0]; + } + + CHECK_OR_RETURN(batch_size_l == batch_size, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(seq_len_l == seq_len_q, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(num_heads_l == num_heads_q, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(l_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + + // 处理不同的 MASK_TYPE + ptrdiff_t mask_stride_sq = seq_len_kv, + mask_stride_sk = 1; + void *mask = nullptr; + bool is_masked = true; + + if (mask_type == INFINIOP_ATTENTION_MASK_TYPE_NONE) { + mask_stride_sq = 0; + mask_stride_sk = 0; + is_masked = false; + } else if (mask_type == INFINIOP_ATTENTION_MASK_TYPE_FULL) { + auto mask_dtype = mask_desc->dtype(); + CHECK_DTYPE(mask_dtype, INFINI_DTYPE_F32); + CHECK_OR_RETURN(mask_desc->ndim() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(mask_desc->dim(0) == seq_len_q && mask_desc->dim(1) == seq_len_kv, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(mask_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + } else if (mask_type == INFINIOP_ATTENTION_MASK_TYPE_CAUSAL) { + size_t mask_size = seq_len_q * seq_len_kv; + float *causal_mask = new float[mask_size]; + + for (size_t i = 0; i < seq_len_q; ++i) { + for (size_t j = 0; j < seq_len_kv; ++j) { + if (j > i) { + causal_mask[i * seq_len_kv + j] = -INFINITY; + } else { + causal_mask[i * seq_len_kv + j] = 0.0f; + } + } + } + + mask = causal_mask; + } + + return utils::Result(FlashAttentionInfo{ + dtype, + batch_size, + seq_len_q, + seq_len_kv, + num_heads_q, + num_heads_kv, + head_dim, + qo_stride_b, + qo_stride_s, + qo_stride_n, + qo_stride_d, + kv_stride_b, + kv_stride_s, + kv_stride_n, + kv_stride_d, + l_stride_b, + l_stride_s, + l_stride_n, + mask_stride_sq, + mask_stride_sk, + mask, + is_masked, + }); + } +}; + +} // namespace op::flash_attention + +#endif // __FLASH_ATTENTION_INFO_H__ diff --git a/src/infiniop/ops/flash_attention/metax/flash_attention_kernel.cuh 
b/src/infiniop/ops/flash_attention/metax/flash_attention_kernel.cuh new file mode 100644 index 000000000..c0817eefd --- /dev/null +++ b/src/infiniop/ops/flash_attention/metax/flash_attention_kernel.cuh @@ -0,0 +1,175 @@ +#ifndef __FLASH_ATTENTION_KERNEL_CUH__ +#define __FLASH_ATTENTION_KERNEL_CUH__ + +template +__device__ void flashAttentionBlock( + Tdata *out_, Tdata *l_, + const Tdata *q_, const Tdata *k_, const Tdata *v_, const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + const Tdata softmax_scale, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + size_t bx = blockIdx.x; // batch -> batch_size + size_t by = blockIdx.y; // q's head index -> num_heads_q + size_t tx = threadIdx.x; // q's row index within one block -> B_r/B_c + + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + extern __shared__ __align__(sizeof(Tdata)) char shared_mem[]; + Tdata *q_i = reinterpret_cast(shared_mem); + Tdata *k_j = reinterpret_cast(q_i + B_r * head_dim); + Tdata *v_j = reinterpret_cast(k_j + B_c * head_dim); + Tdata *s_i = reinterpret_cast(v_j + B_c * head_dim); + + for (size_t i = 0; i < T_r; ++i) { + // skip when over q's seq_len + if (i * B_r + tx >= seq_len_q) { + break; + } + + // load q_i from HBM to on-chip SRAM + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = q_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + } + // initial m, l + Tdata row_m_prev = -INFINITY; + Tdata row_l_prev = 0; + + for (size_t j = 0; j < T_c; ++j) { + __syncthreads(); + // load k_j, v_j from HBM to on-chip SRAM + for (size_t y = 0; y < B_c; ++y) { + for (size_t x = 0; x < head_dim; ++x) { + k_j[y * head_dim + x] = k_[kv_offset + (y + j * B_c) * kv_stride_s + x]; + v_j[y * head_dim + x] = v_[kv_offset + (y + j * B_c) * kv_stride_s + x]; + } + } + __syncthreads(); + + Tdata row_m = -INFINITY; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + }; + + // S_i^(j) = Q_i @ K_j^T / softmax_scale + Tdata sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + + s_i[tx * B_c + y] = sum; + + if constexpr (std::is_same_v || std::is_same_v) { + row_m = __hmax(row_m, sum); + } else { + row_m = fmaxf(row_m, sum); + } + } + + // m_i^(j) = max(m_i^(j - 1), rowmax(S_i^(j))) + Tdata new_row_m; + if constexpr (std::is_same_v || std::is_same_v) { + new_row_m = __hmax(row_m_prev, row_m); + } else { + new_row_m = fmaxf(row_m_prev, row_m); + } + + // rowsum(P_i^(j)) + Tdata row_l = 0; + for (size_t y = 0; y < B_r; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + // P_i^(j) = exp(S_i^(j) - m_i^(j)) + if constexpr (std::is_same_v || std::is_same_v) { + if (__hisinf(new_row_m)) { + s_i[tx * B_c + y] = 1.0; + } else { + s_i[tx * B_c + y] = hexp(s_i[tx * B_c + y] - new_row_m); + } + } else { + if (isinf(new_row_m)) { + 
s_i[tx * B_c + y] = 1.0; + } else { + s_i[tx * B_c + y] = expf(s_i[tx * B_c + y] - new_row_m); + } + } + + row_l += s_i[tx * B_c + y]; + } + + // l_i^(j) = exp(m_i^(j - 1) - m_i^(j - 1)) * l_i^(j - 1) + rowsum(P_i^(j)) + Tdata row_m_exp; + if constexpr (std::is_same_v || std::is_same_v) { + if (__hisinf(row_m_prev)) { + row_m_exp = 1.0; + } else { + row_m_exp = hexp(row_m_prev - new_row_m); + } + } else { + if (isinf(new_row_m)) { + row_m_exp = 1.0; + } else { + row_m_exp = expf(row_m_prev - new_row_m); + } + } + Tdata new_row_l = (row_m_exp * row_l_prev) + row_l; + + // out_i^(j) = diag(exp(m_i^(j - 1) - m_i^(y))) * O_i^(j - 1) + P_i^(j) * V_j + for (size_t x = 0; x < head_dim; ++x) { + Tdata pv = 0; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + pv += s_i[tx * B_c + y] * v_j[y * head_dim + x]; + } + + out_[qo_offset + (i * B_r + tx) * qo_stride_s + x] = row_m_exp * out_[qo_offset + (i * B_r + tx) * qo_stride_s + x] + pv; + } + + row_m_prev = new_row_m; + row_l_prev = new_row_l; + } + + // O_i = O_i^(Tc) / l_i^(Tc) + for (size_t x = 0; x < head_dim; ++x) { + out_[qo_offset + (i * B_r + tx) * qo_stride_s + x] /= row_l_prev; + } + + // L_i = m_i^(Tc) + log(l_i^(Tc)) + if constexpr (std::is_same_v || std::is_same_v) { + l_[l_offset + i * B_r + tx] = row_m_prev + hlog(row_l_prev); + } else { + l_[l_offset + i * B_r + tx] = row_m_prev + logf(row_l_prev); + } + } +} + +#endif // __FLASH_ATTENTION_KERNEL_CUH__ diff --git a/src/infiniop/ops/flash_attention/metax/flash_attention_metax.cuh b/src/infiniop/ops/flash_attention/metax/flash_attention_metax.cuh new file mode 100644 index 000000000..17795b349 --- /dev/null +++ b/src/infiniop/ops/flash_attention/metax/flash_attention_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __FLASH_ATTENTION_METAX_CUH__ +#define __FLASH_ATTENTION_METAX_CUH__ + +#include "../flash_attention.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/flash_attention/metax/flash_attention_metax.maca b/src/infiniop/ops/flash_attention/metax/flash_attention_metax.maca new file mode 100644 index 000000000..2ed4ae899 --- /dev/null +++ b/src/infiniop/ops/flash_attention/metax/flash_attention_metax.maca @@ -0,0 +1,182 @@ +#include "../../../devices/metax/metax_common.h" +#include "flash_attention_metax.cuh" + +#include "../../../devices/metax/metax_kernel_common.h" + +#include "flash_attention_kernel.cuh" + +template +INFINIOP_METAX_KERNEL flashAttentionKernel( + Tdata *__restrict__ out_, + Tdata *__restrict__ l_, + const Tdata *__restrict__ q_, + const Tdata *__restrict__ k_, + const Tdata *__restrict__ v_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n +) { + + Tdata softmax_scale = 1.0 / sqrt(head_dim); + flashAttentionBlock( + out_, l_, + q_, k_, v_, mask_, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + softmax_scale, + qo_stride_b, qo_stride_s, qo_stride_n, + kv_stride_b, kv_stride_s, kv_stride_n, + l_stride_b, l_stride_s, l_stride_n + ); +} + + +namespace op::flash_attention::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + 
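+// Opaque holds a shared reference to the Metax device-handle internals captured when the
+// descriptor is created.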
+Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t l_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + + auto handle = reinterpret_cast(handle_); + + auto info = FlashAttentionInfo::create(out_desc, l_desc, q_desc, k_desc, v_desc, mask_desc, mask_type); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchKernel( + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n, + infiniDtype_t dtype, + hcStream_t stream) { + + hcMemset(out, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(dtype)); + hcMemset(l, 0, batch_size * nums_head_q * seq_len_q * sizeof(dtype)); + + + // calculate SRAM size needed per block + const int sram_size = (2 * B_c * head_dim * sizeof(dtype)) // SRAM size for K_j, V_j + + (B_r * head_dim * sizeof(dtype)) // SRAM size for Q_i + + (B_c * B_r * sizeof(dtype)); // SRAM size for S_i + + dim3 grid_dim(batch_size, nums_head_q); + dim3 block_dim(B_r); + +#define LAUNCHI_KERNEL(Tdata) \ + flashAttentionKernel<<>>( \ + reinterpret_cast(out), \ + reinterpret_cast(l), \ + reinterpret_cast(q), \ + reinterpret_cast(k), \ + reinterpret_cast(v), \ + reinterpret_cast(mask), \ + seq_len_q, seq_len_kv, \ + head_dim, group, \ + B_r, B_c, T_r, T_c, \ + qo_stride_b, qo_stride_s, qo_stride_n, \ + kv_stride_b, kv_stride_s, kv_stride_n, \ + l_stride_b, l_stride_s, l_stride_n) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCHI_KERNEL(half); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCHI_KERNEL(float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCHI_KERNEL(__hpcc_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + void *stream) const { + + size_t B_r = 4; + size_t B_c = 4; + + size_t batch_size = _info.batch_size; + size_t seq_len_q = _info.seq_len_q; + size_t seq_len_kv = _info.seq_len_kv; + size_t nums_head_q = _info.num_heads_q; + size_t nums_head_kv = _info.num_heads_kv; + size_t group = nums_head_q / nums_head_kv; + size_t head_dim = _info.head_dim; + + const void *mask_input = nullptr; + if (_info.is_masked) { + if (_info.mask != nullptr) { + void *mask_temp; + hcMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float)); + hcMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), hcMemcpyHostToDevice); + mask_input = mask_temp; + hcFree(mask_temp); + } else { + mask_input = mask; + } + } + + size_t T_r = ceil(float(seq_len_q) / B_r); + size_t T_c = ceil(float(seq_len_kv) / B_c); + + auto hc_stream = reinterpret_cast(stream); + + 
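+    // Launch strategy: one block per (batch, q-head) pair with B_r threads, each thread
+    // handling one query row of a tile; K/V tiles are staged in dynamic shared memory
+    // sized inside launchKernel. B_r/B_c are the tile heights/widths, T_r/T_c the tile counts.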
CHECK_STATUS(launchKernel( + out, l, q, k, v, mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n, + _info.dtype, + hc_stream)); + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::flash_attention::metax diff --git a/src/infiniop/ops/flash_attention/nvidia/flash_attention_nvidia.cu b/src/infiniop/ops/flash_attention/nvidia/flash_attention_nvidia.cu new file mode 100644 index 000000000..4ba2da314 --- /dev/null +++ b/src/infiniop/ops/flash_attention/nvidia/flash_attention_nvidia.cu @@ -0,0 +1,178 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "flash_attention_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_CUDA_KERNEL flashAttentionKernel( + Tdata *__restrict__ out_, + Tdata *__restrict__ l_, + const Tdata *__restrict__ q_, + const Tdata *__restrict__ k_, + const Tdata *__restrict__ v_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + Tdata softmax_scale = 1.0 / sqrt(head_dim); + flashAttentionBlock( + out_, l_, + q_, k_, v_, mask_, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + softmax_scale, + qo_stride_b, qo_stride_s, qo_stride_n, + kv_stride_b, kv_stride_s, kv_stride_n, + l_stride_b, l_stride_s, l_stride_n); +} + +namespace op::flash_attention::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t l_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + + auto handle = reinterpret_cast(handle_); + + auto info = FlashAttentionInfo::create(out_desc, l_desc, q_desc, k_desc, v_desc, mask_desc, mask_type); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchKernel( + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n, + infiniDtype_t dtype, + cudaStream_t stream) { + + cudaMemset(out, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(dtype)); + cudaMemset(l, 0, batch_size * nums_head_q * seq_len_q * sizeof(dtype)); + + // calculate SRAM size needed per block + const int sram_size = (2 * B_c * head_dim * 
sizeof(dtype)) // SRAM size for K_j, V_j + + (B_r * head_dim * sizeof(dtype)) // SRAM size for Q_i + + (B_c * B_r * sizeof(dtype)); // SRAM size for S_i + + dim3 grid_dim(batch_size, nums_head_q); + dim3 block_dim(B_r); + +#define LAUNCHI_KERNEL(Tdata) \ + flashAttentionKernel<<>>( \ + reinterpret_cast(out), \ + reinterpret_cast(l), \ + reinterpret_cast(q), \ + reinterpret_cast(k), \ + reinterpret_cast(v), \ + reinterpret_cast(mask), \ + seq_len_q, seq_len_kv, \ + head_dim, group, \ + B_r, B_c, T_r, T_c, \ + qo_stride_b, qo_stride_s, qo_stride_n, \ + kv_stride_b, kv_stride_s, kv_stride_n, \ + l_stride_b, l_stride_s, l_stride_n) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCHI_KERNEL(half); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCHI_KERNEL(float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCHI_KERNEL(__nv_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + void *stream) const { + + size_t B_r = 32; + size_t B_c = 32; + + size_t batch_size = _info.batch_size; + size_t seq_len_q = _info.seq_len_q; + size_t seq_len_kv = _info.seq_len_kv; + size_t nums_head_q = _info.num_heads_q; + size_t nums_head_kv = _info.num_heads_kv; + size_t group = nums_head_q / nums_head_kv; + size_t head_dim = _info.head_dim; + + const void *mask_input = nullptr; + if (_info.is_masked) { + if (_info.mask != nullptr) { + void *mask_temp; + cudaMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float)); + cudaMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), cudaMemcpyHostToDevice); + mask_input = mask_temp; + cudaFree(mask_temp); + } else { + mask_input = mask; + } + } + + size_t T_r = ceil(float(seq_len_q) / B_r); + size_t T_c = ceil(float(seq_len_kv) / B_c); + + auto cuda_stream = reinterpret_cast(stream); + + CHECK_STATUS(launchKernel( + out, l, q, k, v, mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n, + _info.dtype, + cuda_stream)); + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::flash_attention::nvidia diff --git a/src/infiniop/ops/flash_attention/nvidia/flash_attention_nvidia.cuh b/src/infiniop/ops/flash_attention/nvidia/flash_attention_nvidia.cuh new file mode 100644 index 000000000..6f60d7aa8 --- /dev/null +++ b/src/infiniop/ops/flash_attention/nvidia/flash_attention_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __FLASH_ATTENTION_CUDA_H__ +#define __FLASH_ATTENTION_CUDA_H__ + +#include "../flash_attention.h" + +DESCRIPTOR(nvidia) + +#endif // __FLASH_ATTENTION_CUDA_H__ diff --git a/src/infiniop/ops/flash_attention/operator.cc b/src/infiniop/ops/flash_attention/operator.cc new file mode 100644 index 000000000..64b4ef189 --- /dev/null +++ b/src/infiniop/ops/flash_attention/operator.cc @@ -0,0 +1,156 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/flash_attention.h" + +#ifdef ENABLE_CPU_API +#include "cpu/flash_attention_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/flash_attention_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/flash_attention_metax.cuh" +#endif + +__C infiniStatus_t 
infiniopCreateFlashAttentionDescriptor( + infiniopHandle_t handle, + infiniopFlashAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t l_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::flash_attention::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + l_desc, \ + q_desc, \ + k_desc, \ + v_desc, \ + mask_desc, \ + mask_type); + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetFlashAttentionWorkspaceSize( + infiniopFlashAttentionDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef GET +} + +__C infiniStatus_t infiniopFlashAttention( + infiniopFlashAttentionDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, + void *l, + const void *q, + const void *k, + const void *v, + const void *mask, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, l, q, k, v, mask, stream); + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyFlashAttentionDescriptor(infiniopFlashAttentionDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DESTROY +} diff --git a/src/infiniop/ops/flash_attention_backward/cpu/flash_attention_backward_cpu.cc b/src/infiniop/ops/flash_attention_backward/cpu/flash_attention_backward_cpu.cc new file mode 100644 index 000000000..9fa9ed5d1 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/cpu/flash_attention_backward_cpu.cc @@ -0,0 +1,455 @@ +#include "flash_attention_backward_cpu.h" +#include 
"../../../devices/cpu/common_cpu.h" +#include "math.h" + +namespace op::flash_attention_backward::cpu { + +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t grad_q_desc, + infiniopTensorDescriptor_t grad_k_desc, + infiniopTensorDescriptor_t grad_v_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t grad_out_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + + auto info = FlashAttentionBackwardInfo::create( + grad_q_desc, grad_k_desc, grad_v_desc, + q_desc, k_desc, v_desc, + grad_out_desc, mask_desc, mask_type); + + *desc_ptr = new Descriptor( + info.take(), + 0, + nullptr, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t flashAttention( + T *out, T *l, const T *q, const T *k, const T *v, const float *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + size_t qo_stride_b, size_t qo_stride_s, size_t qo_stride_n, + size_t kv_stride_b, size_t kv_stride_s, size_t kv_stride_n, + size_t l_stride_b, size_t l_stride_s, size_t l_stride_n) { + + std::memset(out, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(T)); + std::memset(l, 0, batch_size * nums_head_q * seq_len_q * sizeof(T)); + + float softmax_scale = 1.f / sqrt(float(head_dim)); + +#pragma omp parallel for + for (ptrdiff_t bx = 0; bx < ptrdiff_t(batch_size); ++bx) { + for (size_t by = 0; by < nums_head_q; ++by) { + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + std::vector q_i(B_r * head_dim); + std::vector k_j(B_c * head_dim); + std::vector v_j(B_c * head_dim); + std::vector s_i(B_r * B_c); + + for (size_t i = 0; i < T_r; ++i) { + for (size_t tx = 0; tx < B_r; ++tx) { + // skip when over q's seq_len + if (i * B_r + tx >= seq_len_q) { + break; + } + + // load q_i + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = utils::cast(q[qo_offset + (i * B_r + tx) * qo_stride_s + x]); + } + + // initial m, l + float row_m_prev = -INFINITY; + float row_l_prev = 0; + + for (size_t j = 0; j < T_c; ++j) { + // load k_j, v_j + for (size_t y = 0; y < B_c; ++y) { + for (size_t x = 0; x < head_dim; ++x) { + k_j[y * head_dim + x] = utils::cast(k[kv_offset + (y + j * B_c) * kv_stride_s + x]); + v_j[y * head_dim + x] = utils::cast(v[kv_offset + (y + j * B_c) * kv_stride_s + x]); + } + } + + float row_m = -INFINITY; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask != nullptr && mask[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + } + + // S_i^(j) = Q_i @ K_j^T / softmax_scale + float sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + + s_i[tx * B_c + y] = sum; + + row_m = std::max(row_m, sum); + } + + // m_i^(j) = max(m_i^(j - 1), rowmax(S_i^(j))) + float new_row_m = std::max(row_m_prev, row_m); + + // rowsum(P_i^(j)) + float row_l = 0; + for (size_t y = 0; y < B_r; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask != nullptr && mask[(i * B_r + tx) * 
seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + // P_i^(j) = exp(S_i^(j) - m_i^(j)) + if (new_row_m == -INFINITY) { + s_i[tx * B_c + y] = 1.0; + } else { + s_i[tx * B_c + y] = exp(s_i[tx * B_c + y] - new_row_m); + } + + row_l += s_i[tx * B_c + y]; + } + + // l_i^(j) = exp(m_i^(j - 1) - m_i^(j - 1)) * l_i^(j - 1) + rowsum(P_i^(j)) + float row_m_exp; + if (row_m_prev == -INFINITY) { + row_m_exp = 1.0; + } else { + row_m_exp = exp(row_m_prev - new_row_m); + } + float new_row_l = (row_m_exp * row_l_prev) + row_l; + + // out_i^(j) = diag(exp(m_i^(j - 1) - m_i^(y))) * O_i^(j - 1) + P_i^(j) * V_j + for (size_t x = 0; x < head_dim; ++x) { + float pv = 0; + for (size_t y = 0; y < B_c; ++y) { + if (j * B_c + y >= seq_len_kv) { + break; + } + + // mask + if (mask != nullptr && mask[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + continue; + } + + pv += s_i[tx * B_c + y] * v_j[y * head_dim + x]; + } + + out[qo_offset + (i * B_r + tx) * qo_stride_s + x] = utils::cast(row_m_exp * utils::cast(out[qo_offset + (i * B_r + tx) * qo_stride_s + x]) + pv); + } + + row_m_prev = new_row_m; + row_l_prev = new_row_l; + } + + // O_i = O_i^(Tc) / l_i^(Tc) + for (size_t x = 0; x < head_dim; ++x) { + out[qo_offset + (i * B_r + tx) * qo_stride_s + x] = utils::cast(utils::cast(out[qo_offset + (i * B_r + tx) * qo_stride_s + x]) / row_l_prev); + } + + l[l_offset + i * B_r + tx] = utils::cast(row_m_prev + log(row_l_prev)); + } + } + } + } + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t flashAttentionBackward( + T *grad_q, T *grad_k, T *grad_v, + const T *q, const T *k, const T *v, const T *out, const T *grad_out, const T *l, + const float *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + std::memset(grad_q, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(T)); + std::memset(grad_k, 0, batch_size * nums_head_kv * seq_len_kv * head_dim * sizeof(T)); + std::memset(grad_v, 0, batch_size * nums_head_kv * seq_len_kv * head_dim * sizeof(T)); + + float softmax_scale = 1.f / sqrt(float(head_dim)); + +#pragma omp parallel for + for (ptrdiff_t bx = 0; bx < ptrdiff_t(batch_size); ++bx) { + for (size_t by = 0; by < nums_head_q; ++by) { + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + std::vector q_i(B_r * head_dim); + std::vector out_i(B_r * head_dim); + std::vector grad_out_i(B_r * head_dim); + + std::vector k_j(B_c * head_dim); + std::vector v_j(B_c * head_dim); + std::vector grad_k_j(B_c * head_dim); + std::vector grad_v_j(B_c * head_dim); + + std::vector s_i(B_r * B_c); + std::vector grad_s_i(B_r * B_c); + + std::vector D_i(B_r); + + for (size_t j = 0; j < T_c; ++j) { + for (size_t tx = 0; tx < B_r; ++tx) { + // load k_j, v_j and initialize grad_k_j, grad_q_j to 0 + for (size_t x = 0; x < head_dim; ++x) { + k_j[tx * head_dim + x] = utils::cast(k[kv_offset + (j * B_c + tx) * kv_stride_s + x]); + v_j[tx * head_dim + x] = utils::cast(v[kv_offset + (j * B_c + tx) * kv_stride_s + x]); + grad_k_j[tx * head_dim + x] = 0; + grad_v_j[tx * head_dim + x] = 0; + } + } + + for (size_t i = 0; i < T_r; ++i) { + 
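+                    // Backward tile recurrences computed below for block (i, j), mirroring the
+                    // FlashAttention-2 backward pass (L_i is the logsumexp saved by the forward):
+                    //   D_i      = rowsum(dO_i * O_i)                      (elementwise product)
+                    //   S_i^(j)  = Q_i @ K_j^T * softmax_scale
+                    //   P_i^(j)  = exp(S_i^(j) - L_i)
+                    //   dV_j    += P_i^(j)^T @ dO_i
+                    //   dP_i^(j) = dO_i @ V_j^T
+                    //   dS_i^(j) = P_i^(j) * (dP_i^(j) - D_i)
+                    //   dQ_i    += dS_i^(j) @ K_j * softmax_scale
+                    //   dK_j    += dS_i^(j)^T @ Q_i * softmax_scale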
for (size_t tx = 0; tx < B_r; ++tx) { + // load q_i, out_i, grad_out_i + D_i[tx] = 0; + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = utils::cast(q[qo_offset + (i * B_r + tx) * qo_stride_s + x]); + out_i[tx * head_dim + x] = utils::cast(out[qo_offset + (i * B_r + tx) * qo_stride_s + x]); + grad_out_i[tx * head_dim + x] = utils::cast(grad_out[qo_offset + (i * B_r + tx) * qo_stride_s + x]); + D_i[tx] += grad_out_i[tx * head_dim + x] * out_i[tx * head_dim + x]; + } + float l_curr = utils::cast(l[l_offset + i * B_r + tx]); + + // S_i^(j) = Q_i @ K_j^T * softmax_scale + for (size_t y = 0; y < B_c; ++y) { + // mask + if (mask != nullptr && mask[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + }; + + float sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + s_i[tx * B_c + y] = sum; + } + + // P_i^(j) = exp(S_ij - L_i) + for (size_t y = 0; y < B_c; ++y) { + s_i[tx * B_c + y] = exp(s_i[tx * B_c + y] - l_curr); + } + } + + for (size_t tx = 0; tx < B_r; ++tx) { + // dV_j = dV_j + P_i^(j)^T @ dO_i + for (size_t x = 0; x < head_dim; ++x) { + float sum = 0; + for (size_t y = 0; y < B_r; ++y) { + sum += s_i[y * B_c + tx] * grad_out_i[y * head_dim + x]; + } + grad_v_j[tx * head_dim + x] += sum; + } + + // dP_i^(j) = dO_i @ V_j^T + for (size_t y = 0; y < B_c; ++y) { + float sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += grad_out_i[tx * head_dim + x] * v_j[y * head_dim + x]; + } + grad_s_i[tx * B_c + y] = sum; + } + + // dS_i^(j) = P_i^(j) * (dP_i^(j) - D_i) + for (size_t y = 0; y < B_c; ++y) { + grad_s_i[tx * B_c + y] = s_i[tx * B_c + y] * (grad_s_i[tx * B_c + y] - D_i[tx]); + } + } + + for (size_t tx = 0; tx < B_r; ++tx) { + // dQ_i = dQ_i + dS_i^(j) @ K_j + for (size_t x = 0; x < head_dim; ++x) { + float sum = 0; + for (size_t y = 0; y < B_c; ++y) { + sum += grad_s_i[tx * B_c + y] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + grad_q[qo_offset + (i * B_r + tx) * qo_stride_s + x] = utils::cast(utils::cast(grad_q[qo_offset + (i * B_r + tx) * qo_stride_s + x]) + sum); + } + + // dK_j = dK_j + dS_i^(j)^T @ Q_i + for (size_t x = 0; x < head_dim; ++x) { + float sum = 0; + for (size_t y = 0; y < B_r; ++y) { + sum += grad_s_i[y * B_c + tx] * q_i[y * head_dim + x]; + } + sum *= softmax_scale; + grad_k_j[tx * head_dim + x] += sum; + } + } + } + + for (size_t tx = 0; tx < B_r; ++tx) { + // write dK_j, dV_j + for (size_t x = 0; x < head_dim; ++x) { + grad_k[kv_offset + (j * B_c + tx) * kv_stride_s + x] = utils::cast(utils::cast(grad_k[kv_offset + (j * B_c + tx) * kv_stride_s + x]) + grad_k_j[tx * head_dim + x]); + grad_v[kv_offset + (j * B_c + tx) * kv_stride_s + x] = utils::cast(utils::cast(grad_v[kv_offset + (j * B_c + tx) * kv_stride_s + x]) + grad_v_j[tx * head_dim + x]); + } + } + } + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *grad_q, void *grad_k, void *grad_v, + const void *q, const void *k, const void *v, const void *grad_out, + const void *mask, + void *stream) const { + + size_t B_r = 2; + size_t B_c = 2; + + size_t batch_size = _info.batch_size; + size_t seq_len_q = _info.seq_len_q; + size_t seq_len_kv = _info.seq_len_kv; + size_t nums_head_q = _info.num_heads_q; + size_t nums_head_kv = _info.num_heads_kv; + size_t group = nums_head_q / nums_head_kv; + size_t head_dim = _info.head_dim; + + const void *mask_input = nullptr; + if 
(_info.is_masked) { + if (_info.mask != nullptr) { + mask_input = _info.mask; + } else { + mask_input = mask; + } + } + + size_t T_r = CEIL_DIV(seq_len_q, B_r); + size_t T_c = CEIL_DIV(seq_len_kv, B_c); + + if (_info.dtype == INFINI_DTYPE_F32) { + float *out = new float[batch_size * nums_head_q * seq_len_q * head_dim]; + float *l = new float[batch_size * nums_head_q * seq_len_q]; + + CHECK_STATUS(flashAttention( + (float *)out, (float *)l, (const float *)q, (const float *)k, (const float *)v, + (const float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + + CHECK_STATUS(flashAttentionBackward( + (float *)grad_q, (float *)grad_k, (float *)grad_v, + (const float *)q, (const float *)k, (const float *)v, (const float *)out, (const float *)grad_out, (const float *)l, + (const float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + + } else if (_info.dtype == INFINI_DTYPE_F16) { + fp16_t *out = new fp16_t[batch_size * nums_head_q * seq_len_q * head_dim]; + fp16_t *l = new fp16_t[batch_size * nums_head_q * seq_len_q]; + + CHECK_STATUS(flashAttention( + (fp16_t *)out, (fp16_t *)l, (fp16_t *)q, (fp16_t *)k, (fp16_t *)v, + (float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + + CHECK_STATUS(flashAttentionBackward( + (fp16_t *)grad_q, (fp16_t *)grad_k, (fp16_t *)grad_v, + (fp16_t *)q, (fp16_t *)k, (fp16_t *)v, (fp16_t *)out, (fp16_t *)grad_out, (fp16_t *)l, + (float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + + } else if (_info.dtype == INFINI_DTYPE_BF16) { + bf16_t *out = new bf16_t[batch_size * nums_head_q * seq_len_q * head_dim]; + bf16_t *l = new bf16_t[batch_size * nums_head_q * seq_len_q]; + + CHECK_STATUS(flashAttention( + (bf16_t *)out, (bf16_t *)l, (bf16_t *)q, (bf16_t *)k, (bf16_t *)v, + (float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + + CHECK_STATUS(flashAttentionBackward( + (bf16_t *)grad_q, (bf16_t *)grad_k, (bf16_t *)grad_v, + (bf16_t *)q, (bf16_t *)k, (bf16_t *)v, (bf16_t *)out, (bf16_t *)grad_out, (bf16_t *)l, + (float *)mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n)); + } + + return 
INFINI_STATUS_SUCCESS; +} +} // namespace op::flash_attention_backward::cpu diff --git a/src/infiniop/ops/flash_attention_backward/cpu/flash_attention_backward_cpu.h b/src/infiniop/ops/flash_attention_backward/cpu/flash_attention_backward_cpu.h new file mode 100644 index 000000000..f341383a8 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/cpu/flash_attention_backward_cpu.h @@ -0,0 +1,7 @@ +#ifndef __FLASH_ATTENTION_BACKWARD_CPU_H__ +#define __FLASH_ATTENTION_BACKWARD_CPU_H__ +#include "../flash_attention_backward.h" + +DESCRIPTOR(cpu) + +#endif diff --git a/src/infiniop/ops/flash_attention_backward/cuda/kernel.cuh b/src/infiniop/ops/flash_attention_backward/cuda/kernel.cuh new file mode 100644 index 000000000..09bf0c230 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/cuda/kernel.cuh @@ -0,0 +1,140 @@ +#ifndef __FLASH_ATTENTION_BACKWARD_KERNEL_CUH__ +#define __FLASH_ATTENTION_BACKWARD_KERNEL_CUH__ + +template +__device__ void flashAttentionBackwardBlock( + Tdata *grad_q_, Tdata *grad_k_, Tdata *grad_v_, + const Tdata *q_, const Tdata *k_, const Tdata *v_, + const Tdata *out_, const Tdata *grad_out_, const Tdata *l_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + const Tdata softmax_scale, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + size_t bx = blockIdx.x; // batch -> batch_size + size_t by = blockIdx.y; // q's head index -> num_heads + size_t tx = threadIdx.x; // k's row index within one block -> B_r/B_c + + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + extern __shared__ __align__(sizeof(Tdata)) char shared_mem[]; + Tdata *q_i = reinterpret_cast(shared_mem); + Tdata *out_i = reinterpret_cast(q_i + B_r * head_dim); + Tdata *grad_out_i = reinterpret_cast(out_i + B_r * head_dim); + + Tdata *k_j = reinterpret_cast(grad_out_i + B_r * head_dim); + Tdata *v_j = reinterpret_cast(k_j + B_c * head_dim); + Tdata *grad_k_j = reinterpret_cast(v_j + B_c * head_dim); + Tdata *grad_v_j = reinterpret_cast(grad_k_j + B_c * head_dim); + + Tdata *s_i = reinterpret_cast(grad_v_j + B_c * head_dim); + Tdata *grad_s_i = reinterpret_cast(s_i + B_r * B_c); + + for (size_t j = 0; j < T_c; ++j) { + // load k_j, v_j and initialize grad_k_j, grad_q_j to 0 + for (size_t x = 0; x < head_dim; ++x) { + k_j[tx * head_dim + x] = k_[kv_offset + (j * B_c + tx) * kv_stride_s + x]; + v_j[tx * head_dim + x] = v_[kv_offset + (j * B_c + tx) * kv_stride_s + x]; + grad_k_j[tx * head_dim + x] = 0; + grad_v_j[tx * head_dim + x] = 0; + } + + for (size_t i = 0; i < T_r; ++i) { + __syncthreads(); + // load q_i, out_i, grad_out_i + Tdata D_i = 0; + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = q_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + out_i[tx * head_dim + x] = out_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + grad_out_i[tx * head_dim + x] = grad_out_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + D_i += grad_out_i[tx * head_dim + x] * out_i[tx * head_dim + x]; + } + Tdata l_curr = l_[l_offset + i * B_r + tx]; + + // S_i^(j) = Q_i @ K_j^T * softmax_scale + for (size_t y = 0; y < B_c; ++y) { + // mask + if (mask_ != nullptr && 
mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + }; + + Tdata sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + s_i[tx * B_c + y] = sum; + } + + // P_i^(j) = exp(S_ij - L_i) + for (size_t y = 0; y < B_c; ++y) { + if constexpr (std::is_same_v || std::is_same_v) { + s_i[tx * B_c + y] = hexp(s_i[tx * B_c + y] - l_curr); + } else { + s_i[tx * B_c + y] = expf(s_i[tx * B_c + y] - l_curr); + } + } + __syncthreads(); + + // dV_j = dV_j + P_i^(j)^T @ dO_i + for (size_t x = 0; x < head_dim; ++x) { + Tdata sum = 0; + for (size_t y = 0; y < B_r; ++y) { + sum += s_i[y * B_c + tx] * grad_out_i[y * head_dim + x]; + } + grad_v_j[tx * head_dim + x] += sum; + } + + // dP_i^(j) = dO_i @ V_j^T + for (size_t y = 0; y < B_c; ++y) { + Tdata sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += grad_out_i[tx * head_dim + x] * v_j[y * head_dim + x]; + } + grad_s_i[tx * B_c + y] = sum; + } + + // dS_i^(j) = P_i^(j) * (dP_i^(j) - D_i) + for (size_t y = 0; y < B_c; ++y) { + grad_s_i[tx * B_c + y] = s_i[tx * B_c + y] * (grad_s_i[tx * B_c + y] - D_i); + } + + // dQ_i = dQ_i + dS_i^(j) @ K_j + for (size_t x = 0; x < head_dim; ++x) { + Tdata sum = 0; + for (size_t y = 0; y < B_c; ++y) { + sum += grad_s_i[tx * B_c + y] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + grad_q_[qo_offset + (i * B_r + tx) * qo_stride_s + x] += sum; + } + __syncthreads(); + + // dK_j = dK_j + dS_i^(j)^T @ Q_i + for (size_t x = 0; x < head_dim; ++x) { + Tdata sum = 0; + for (size_t y = 0; y < B_r; ++y) { + sum += grad_s_i[y * B_c + tx] * q_i[y * head_dim + x]; + } + sum *= softmax_scale; + grad_k_j[tx * head_dim + x] += sum; + } + } + + // write dK_j, dV_j to HBM + for (size_t x = 0; x < head_dim; ++x) { + size_t offset = bx * kv_stride_b * group + by * kv_stride_n; + grad_k_[offset + (j * B_c + tx) * kv_stride_s * group + x] = grad_k_j[tx * head_dim + x]; + grad_v_[offset + (j * B_c + tx) * kv_stride_s * group + x] = grad_v_j[tx * head_dim + x]; + } + } +} + +#endif // __FLASH_ATTENTION_BACKWARD_KERNEL_CUH__ diff --git a/src/infiniop/ops/flash_attention_backward/flash_attention_backward.h b/src/infiniop/ops/flash_attention_backward/flash_attention_backward.h new file mode 100644 index 000000000..8830db69c --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/flash_attention_backward.h @@ -0,0 +1,60 @@ +#ifndef FLASH_ATTENTION_BACKWARD_H +#define FLASH_ATTENTION_BACKWARD_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::flash_attention_backward::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + FlashAttentionBackwardInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + FlashAttentionBackwardInfo info, \ + size_t workspace_size_, \ + Opaque *opaque, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t grad_q_desc, \ + infiniopTensorDescriptor_t grad_k_desc, \ + infiniopTensorDescriptor_t grad_v_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_desc, \ + 
infiniopTensorDescriptor_t v_desc, \ + infiniopTensorDescriptor_t grad_out_desc, \ + infiniopTensorDescriptor_t mask_desc, \ + infiniopAttentionMaskType_t mask_type); \ + \ + infiniStatus_t calculate( \ + void *workspace, \ + size_t workspace_size, \ + void *grad_q, \ + void *grad_k, \ + void *grad_v, \ + const void *q, \ + const void *k, \ + const void *v, \ + const void *grad_out, \ + const void *mask, \ + void *stream) const; \ + }; \ + } + +#endif // FLASH_ATTENTION_BACKWARD_H diff --git a/src/infiniop/ops/flash_attention_backward/info.h b/src/infiniop/ops/flash_attention_backward/info.h new file mode 100644 index 000000000..fc7b0a510 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/info.h @@ -0,0 +1,193 @@ +#ifndef __FLASH_ATTENTION_BACKWARD_INFO_H__ +#define __FLASH_ATTENTION_BACKWARD_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include "infiniop/ops/flash_attention.h" +#include +#include + +namespace op::flash_attention_backward { + +class FlashAttentionBackwardInfo { +private: + FlashAttentionBackwardInfo() = default; + +public: + infiniDtype_t dtype; + size_t batch_size; + size_t seq_len_q, seq_len_kv; + size_t num_heads_q, num_heads_kv; + size_t head_dim; + + ptrdiff_t qo_stride_b; + ptrdiff_t qo_stride_s; + ptrdiff_t qo_stride_n; + ptrdiff_t qo_stride_d; + + ptrdiff_t kv_stride_b; + ptrdiff_t kv_stride_s; + ptrdiff_t kv_stride_n; + ptrdiff_t kv_stride_d; + + ptrdiff_t l_stride_b; + ptrdiff_t l_stride_s; + ptrdiff_t l_stride_n; + + ptrdiff_t mask_stride_sq; + ptrdiff_t mask_stride_sk; + + void *mask; + bool is_masked; + + static utils::Result create( + infiniopTensorDescriptor_t grad_q_desc, + infiniopTensorDescriptor_t grad_k_desc, + infiniopTensorDescriptor_t grad_v_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t grad_out_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + // 检查数据类型是否一致 + auto dtype = grad_out_desc->dtype(); + CHECK_OR_RETURN( + dtype == grad_q_desc->dtype() + && dtype == grad_k_desc->dtype() + && dtype == grad_v_desc->dtype() + && dtype == q_desc->dtype() + && dtype == k_desc->dtype() + && dtype == v_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + + // 检查张量形状 + // grad_q, q, grad_out 形状相同 + auto q_shape = q_desc->shape(); + CHECK_SAME_SHAPE(q_shape, grad_q_desc->shape()); + CHECK_SAME_SHAPE(q_shape, grad_out_desc->shape()); + // grad_k, grad_v, k, v 形状相同 + auto kv_shape = k_desc->shape(); + CHECK_SAME_SHAPE(kv_shape, grad_k_desc->shape()); + CHECK_SAME_SHAPE(kv_shape, grad_v_desc->shape()); + CHECK_SAME_SHAPE(kv_shape, v_desc->shape()); + // 检查输入的纬度 + auto ndim = q_desc->ndim(); + CHECK_OR_RETURN(ndim == k_desc->ndim(), INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(ndim == 3 || ndim == 4, INFINI_STATUS_BAD_TENSOR_SHAPE); + + size_t batch_size_q = 1; + size_t seq_len_q = q_shape[ndim - 3]; + size_t num_heads_q = q_shape[ndim - 2]; + size_t head_dim_q = q_shape[ndim - 1]; + + size_t batch_size_kv = 1; + size_t seq_len_kv = kv_shape[ndim - 3]; + size_t num_heads_kv = kv_shape[ndim - 2]; + size_t head_dim_kv = kv_shape[ndim - 1]; + + ptrdiff_t qo_stride_b = 0, + qo_stride_s = q_desc->stride(ndim - 3), + qo_stride_n = q_desc->stride(ndim - 2), + qo_stride_d = q_desc->stride(ndim - 1); + + ptrdiff_t kv_stride_b = 0, + kv_stride_s = k_desc->stride(ndim - 3), + kv_stride_n = 
k_desc->stride(ndim - 2), + kv_stride_d = k_desc->stride(ndim - 1); + + ptrdiff_t l_stride_b = 0, + l_stride_s = head_dim_q, + l_stride_n = 1; + + if (ndim == 4) { + qo_stride_b = q_desc->stride(0); + kv_stride_b = k_desc->stride(0); + batch_size_q = q_shape[0]; + batch_size_kv = kv_shape[0]; + + l_stride_b = seq_len_q * head_dim_q; + } + + // batch_size 和 head_dim 是否一致 + CHECK_OR_RETURN(batch_size_q == batch_size_kv, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(head_dim_q == head_dim_kv, INFINI_STATUS_BAD_TENSOR_SHAPE); + // 多头注意力是否整除 + CHECK_OR_RETURN(num_heads_q % num_heads_kv == 0, INFINI_STATUS_BAD_TENSOR_SHAPE); + + CHECK_OR_RETURN(grad_q_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(grad_k_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(grad_v_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(q_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(k_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(v_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(grad_out_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + + size_t batch_size = batch_size_q; + size_t head_dim = head_dim_q; + + // 处理不同的 MASK_TYPE + ptrdiff_t mask_stride_sq = seq_len_kv, + mask_stride_sk = 1; + void *mask = nullptr; + bool is_masked = true; + + if (mask_type == INFINIOP_ATTENTION_MASK_TYPE_NONE) { + mask_stride_sq = 0; + mask_stride_sk = 0; + is_masked = false; + } else if (mask_type == INFINIOP_ATTENTION_MASK_TYPE_FULL) { + auto mask_dtype = mask_desc->dtype(); + CHECK_DTYPE(mask_dtype, INFINI_DTYPE_F32); + CHECK_OR_RETURN(mask_desc->ndim() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(mask_desc->dim(0) == seq_len_q && mask_desc->dim(1) == seq_len_kv, INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(mask_desc->isContiguous(), INFINI_STATUS_BAD_TENSOR_STRIDES); + } else if (mask_type == INFINIOP_ATTENTION_MASK_TYPE_CAUSAL) { + size_t mask_size = seq_len_q * seq_len_kv; + float *causal_mask = new float[mask_size]; + + for (size_t i = 0; i < seq_len_q; ++i) { + for (size_t j = 0; j < seq_len_kv; ++j) { + if (j > i) { + causal_mask[i * seq_len_kv + j] = -INFINITY; + } else { + causal_mask[i * seq_len_kv + j] = 0.0f; + } + } + } + + mask = causal_mask; + } + + return utils::Result(FlashAttentionBackwardInfo{ + dtype, + batch_size, + seq_len_q, + seq_len_kv, + num_heads_q, + num_heads_kv, + head_dim, + qo_stride_b, + qo_stride_s, + qo_stride_n, + qo_stride_d, + kv_stride_b, + kv_stride_s, + kv_stride_n, + kv_stride_d, + l_stride_b, + l_stride_s, + l_stride_n, + mask_stride_sq, + mask_stride_sk, + mask, + is_masked, + }); + } +}; + +} // namespace op::flash_attention_backward + +#endif // __FLASH_ATTENTION_BACKWARD_INFO_H__ diff --git a/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_kernel.cuh b/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_kernel.cuh new file mode 100644 index 000000000..db12ccd46 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_kernel.cuh @@ -0,0 +1,140 @@ +#ifndef __FLASH_ATTENTION_BACKWARD_KERNEL_CUH__ +#define __FLASH_ATTENTION_BACKWARD_KERNEL_CUH__ + +template +__device__ void flashAttentionBackwardBlock( + Tdata *grad_q_, Tdata *grad_k_, Tdata *grad_v_, + const Tdata *q_, const Tdata *k_, const Tdata *v_, + const Tdata *out_, const Tdata *grad_out_, const Tdata *l_, + const float *mask_, + const size_t seq_len_q, const size_t 
seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + const Tdata softmax_scale, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + size_t bx = blockIdx.x; // batch -> batch_size + size_t by = blockIdx.y; // q's head index -> num_heads + size_t tx = threadIdx.x; // k's row index within one block -> B_r/B_c + + size_t qo_offset = bx * qo_stride_b + by * qo_stride_n; + size_t kv_offset = bx * kv_stride_b + by / group * kv_stride_n; + size_t l_offset = bx * l_stride_b + by * seq_len_q; + + extern __shared__ __align__(sizeof(Tdata)) char shared_mem[]; + Tdata *q_i = reinterpret_cast(shared_mem); + Tdata *out_i = reinterpret_cast(q_i + B_r * head_dim); + Tdata *grad_out_i = reinterpret_cast(out_i + B_r * head_dim); + + Tdata *k_j = reinterpret_cast(grad_out_i + B_r * head_dim); + Tdata *v_j = reinterpret_cast(k_j + B_c * head_dim); + Tdata *grad_k_j = reinterpret_cast(v_j + B_c * head_dim); + Tdata *grad_v_j = reinterpret_cast(grad_k_j + B_c * head_dim); + + Tdata *s_i = reinterpret_cast(grad_v_j + B_c * head_dim); + Tdata *grad_s_i = reinterpret_cast(s_i + B_r * B_c); + + for (size_t j = 0; j < T_c; ++j) { + // load k_j, v_j and initialize grad_k_j, grad_q_j to 0 + for (size_t x = 0; x < head_dim; ++x) { + k_j[tx * head_dim + x] = k_[kv_offset + (j * B_c + tx) * kv_stride_s + x]; + v_j[tx * head_dim + x] = v_[kv_offset + (j * B_c + tx) * kv_stride_s + x]; + grad_k_j[tx * head_dim + x] = 0; + grad_v_j[tx * head_dim + x] = 0; + } + + for (size_t i = 0; i < T_r; ++i) { + __syncthreads(); + // load q_i, out_i, grad_out_i + Tdata D_i = 0; + for (size_t x = 0; x < head_dim; ++x) { + q_i[tx * head_dim + x] = q_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + out_i[tx * head_dim + x] = out_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + grad_out_i[tx * head_dim + x] = grad_out_[qo_offset + (i * B_r + tx) * qo_stride_s + x]; + D_i += grad_out_i[tx * head_dim + x] * out_i[tx * head_dim + x]; + } + Tdata l_curr = l_[l_offset + i * B_r + tx]; + + // S_i^(j) = Q_i @ K_j^T * softmax_scale + for (size_t y = 0; y < B_c; ++y) { + // mask + if (mask_ != nullptr && mask_[(i * B_r + tx) * seq_len_kv + j * B_c + y] == -INFINITY) { + s_i[tx * B_c + y] = -INFINITY; + continue; + }; + + Tdata sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += q_i[tx * head_dim + x] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + s_i[tx * B_c + y] = sum; + } + + // P_i^(j) = exp(S_ij - L_i) + for (size_t y = 0; y < B_c; ++y) { + if constexpr (std::is_same_v || std::is_same_v) { + s_i[tx * B_c + y] = hexp(s_i[tx * B_c + y] - l_curr); + } else { + s_i[tx * B_c + y] = expf(s_i[tx * B_c + y] - l_curr); + } + } + __syncthreads(); + + // dV_j = dV_j + P_i^(j)^T @ dO_i + for (size_t x = 0; x < head_dim; ++x) { + Tdata sum = 0; + for (size_t y = 0; y < B_r; ++y) { + sum += s_i[y * B_c + tx] * grad_out_i[y * head_dim + x]; + } + grad_v_j[tx * head_dim + x] += sum; + } + + // dP_i^(j) = dO_i @ V_j^T + for (size_t y = 0; y < B_c; ++y) { + Tdata sum = 0; + for (size_t x = 0; x < head_dim; ++x) { + sum += grad_out_i[tx * head_dim + x] * v_j[y * head_dim + x]; + } + grad_s_i[tx * B_c + y] = sum; + } + + // dS_i^(j) = P_i^(j) * (dP_i^(j) - D_i) + for (size_t y = 0; y < B_c; ++y) { + grad_s_i[tx * B_c + y] = s_i[tx * B_c + y] * (grad_s_i[tx * B_c + y] - D_i); + } + + // dQ_i = dQ_i + dS_i^(j) 
@ K_j + for (size_t x = 0; x < head_dim; ++x) { + Tdata sum = 0; + for (size_t y = 0; y < B_c; ++y) { + sum += grad_s_i[tx * B_c + y] * k_j[y * head_dim + x]; + } + sum *= softmax_scale; + grad_q_[qo_offset + (i * B_r + tx) * qo_stride_s + x] += sum; + } + __syncthreads(); + + // dK_j = dK_j + dS_i^(j)^T @ Q_i + for (size_t x = 0; x < head_dim; ++x) { + Tdata sum = 0; + for (size_t y = 0; y < B_r; ++y) { + sum += grad_s_i[y * B_c + tx] * q_i[y * head_dim + x]; + } + sum *= softmax_scale; + grad_k_j[tx * head_dim + x] += sum; + } + } + + // write dK_j, dV_j to HBM + for (size_t x = 0; x < head_dim; ++x) { + size_t offset = bx * kv_stride_b * group + by * kv_stride_n; + grad_k_[offset + (j * B_c + tx) * kv_stride_s * group + x] = grad_k_j[tx * head_dim + x]; + grad_v_[offset + (j * B_c + tx) * kv_stride_s * group + x] = grad_v_j[tx * head_dim + x]; + } + } +} + +#endif // __FLASH_ATTENTION_BACKWARD_KERNEL_CUH__ diff --git a/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_metax.cuh b/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_metax.cuh new file mode 100644 index 000000000..9d0bc579b --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __FLASH_ATTENTION_BACKWARD_METAX_CUH__ +#define __FLASH_ATTENTION_BACKWARD_METAX_CUH__ + +#include "../flash_attention_backward.h" + +DESCRIPTOR(metax) + +#endif diff --git a/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_metax.maca b/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_metax.maca new file mode 100644 index 000000000..03a71c877 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/metax/flash_attention_backward_metax.maca @@ -0,0 +1,389 @@ +#include "../../../devices/metax/metax_common.h" +#include "flash_attention_backward_metax.cuh" + +#include "../../../devices/metax/metax_kernel_common.h" + +#include "flash_attention_backward_kernel.cuh" +#include "../../flash_attention/metax/flash_attention_kernel.cuh" + +template +__global__ void reduce_gradients_kernel( + const Tdata* grad_k_expanded, + const Tdata* grad_v_expanded, + Tdata* grad_k, + Tdata* grad_v, + size_t total_seq_len, + size_t head_dim, + size_t group +) { + size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < total_seq_len) { + for (size_t j = 0; j < head_dim; ++j) { + Tdata sum_grad_k = 0; + Tdata sum_grad_v = 0; + for (size_t k = 0; k < group; ++k) { + sum_grad_k += grad_k_expanded[i * group * head_dim + k * head_dim + j]; + sum_grad_v += grad_v_expanded[i * group * head_dim + k * head_dim + j]; + } + grad_k[i * head_dim + j] = sum_grad_k; + grad_v[i * head_dim + j] = sum_grad_v; + } + } +} + +template +INFINIOP_METAX_KERNEL flashAttentionKernel( + Tdata *__restrict__ out_, + Tdata *__restrict__ l_, + const Tdata *__restrict__ q_, + const Tdata *__restrict__ k_, + const Tdata *__restrict__ v_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n +) { + + Tdata softmax_scale = 1.0 / sqrt(head_dim); + flashAttentionBlock( + out_, l_, + q_, k_, v_, mask_, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + softmax_scale, + qo_stride_b, 
qo_stride_s, qo_stride_n, + kv_stride_b, kv_stride_s, kv_stride_n, + l_stride_b, l_stride_s, l_stride_n + ); +} + +template +INFINIOP_METAX_KERNEL flashAttentionBackwardKernel( + Tdata *__restrict__ grad_q_, + Tdata *__restrict__ grad_k_, + Tdata *__restrict__ grad_v_, + const Tdata *__restrict__ q_, + const Tdata *__restrict__ k_, + const Tdata *__restrict__ v_, + const Tdata *__restrict__ out_, + const Tdata *__restrict__ grad_out_, + const Tdata *__restrict__ l_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n +) { + Tdata softmax_scale = 1.0 / sqrt(head_dim); + flashAttentionBackwardBlock( + grad_q_, grad_k_, grad_v_, + q_, k_, v_, out_, grad_out_, l_, mask_, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + softmax_scale, + qo_stride_b, qo_stride_s, qo_stride_n, + kv_stride_b, kv_stride_s, kv_stride_n, + l_stride_b, l_stride_s, l_stride_n + ); +} + + +namespace op::flash_attention_backward::metax { + +struct Descriptor::Opaque { + std::shared_ptr internel; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t grad_q_desc, + infiniopTensorDescriptor_t grad_k_desc, + infiniopTensorDescriptor_t grad_v_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t grad_out_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + + auto handle = reinterpret_cast(handle_); + + auto info = FlashAttentionBackwardInfo::create(grad_q_desc, grad_k_desc, grad_v_desc, + q_desc, k_desc, v_desc, + grad_out_desc, mask_desc, mask_type); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchForwardKernel( + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n, + infiniDtype_t dtype, + hcStream_t stream) { + hcMemset(out, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(dtype)); + hcMemset(l, 0, batch_size * nums_head_q * seq_len_q * sizeof(dtype)); + + + // Calculate SRAM size needed per block + const int sram_size = (2 * B_c * head_dim * sizeof(dtype)) // SRAM size for Kj, Vj + + (B_r * head_dim * sizeof(dtype)) // SRAM size for Qi + + (B_c * B_r * sizeof(dtype)); // SRAM size for S + + dim3 grid_dim(batch_size, nums_head_q); + dim3 block_dim(B_r); + +#define LAUNCHI_FORWARD_KERNEL(Tdata) \ + flashAttentionKernel<<>>( \ + reinterpret_cast(out), \ + reinterpret_cast(l), \ + reinterpret_cast(q), \ + reinterpret_cast(k), \ + reinterpret_cast(v), \ + reinterpret_cast(mask), \ + seq_len_q, seq_len_kv, \ + 
head_dim, group, \ + B_r, B_c, T_r, T_c, \ + qo_stride_b, qo_stride_s, qo_stride_n, \ + kv_stride_b, kv_stride_s, kv_stride_n, \ + l_stride_b, l_stride_s, l_stride_n) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCHI_FORWARD_KERNEL(half); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCHI_FORWARD_KERNEL(float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCHI_FORWARD_KERNEL(__hpcc_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchBackwardKernel( + void *grad_q, void *grad_k, void *grad_v, + const void *q, const void *k, const void *v, + const void *out, const void *grad_out, const void *l, + const void *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n, + infiniDtype_t dtype, + hcStream_t stream) { + + // initial grad_q, grad_k, grad_v + hcMemset(grad_q, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(dtype)); + hcMemset(grad_k, 0, batch_size * nums_head_q * seq_len_kv * head_dim * sizeof(dtype)); + hcMemset(grad_v, 0, batch_size * nums_head_q * seq_len_kv * head_dim * sizeof(dtype)); + + // calculate SRAM size needed per block + const int sram_size = (4 * B_c * head_dim * sizeof(dtype)) // SRAM size for K_j, V_j, dK_j, dV_j + + (3 * B_r * head_dim * sizeof(dtype)) // SRAM size for Q_i, O_i, dO_i + + (2 * B_c * B_r * sizeof(dtype)); // SRAM size for S_i, dS_i + + dim3 grad_dim(batch_size, nums_head_q); + dim3 block_dim(B_c); + +#define LAUNCHI_BACKWARD_KERNEL(Tdata) \ + flashAttentionBackwardKernel<<>>( \ + reinterpret_cast(grad_q), \ + reinterpret_cast(grad_k), \ + reinterpret_cast(grad_v), \ + reinterpret_cast(q), \ + reinterpret_cast(k), \ + reinterpret_cast(v), \ + reinterpret_cast(out), \ + reinterpret_cast(grad_out), \ + reinterpret_cast(l), \ + reinterpret_cast(mask), \ + seq_len_q, seq_len_kv, \ + head_dim, group, \ + B_r, B_c, T_r, T_c, \ + qo_stride_b, qo_stride_s, qo_stride_n, \ + kv_stride_b, kv_stride_s, kv_stride_n, \ + l_stride_b, l_stride_s, l_stride_n) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCHI_BACKWARD_KERNEL(half); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCHI_BACKWARD_KERNEL(float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCHI_BACKWARD_KERNEL(__hpcc_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *grad_q, void *grad_k, void *grad_v, + const void *q, const void *k, const void *v, + const void *grad_out, const void *mask, + void *stream) const { + + size_t B_r = 2; + size_t B_c = 2; + + size_t batch_size = _info.batch_size; + size_t seq_len_q = _info.seq_len_q; + size_t seq_len_kv = _info.seq_len_kv; + size_t nums_head_q = _info.num_heads_q; + size_t nums_head_kv = _info.num_heads_kv; + size_t group = nums_head_q / nums_head_kv; + size_t head_dim = _info.head_dim; + + const void *mask_input = nullptr; + if (_info.is_masked) { + if (_info.mask != nullptr) { + void *mask_temp; + hcMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float)); + hcMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), hcMemcpyHostToDevice); + mask_input = mask_temp; + } else { + 
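+            // INFINIOP_ATTENTION_MASK_TYPE_FULL path: the caller-supplied `mask` is assumed to
+            // already be a device-resident float buffer of shape [seq_len_q, seq_len_kv], so it
+            // is passed through unchanged. The causal branch above stages the host-built mask
+            // with hcMalloc/hcMemcpy; that temporary is not freed in this function yet and
+            // should be released once the kernels have finished.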
mask_input = mask; + } + } + + size_t T_r = ceil(float(seq_len_q) / B_r); + size_t T_c = ceil(float(seq_len_kv) / B_c); + + auto hc_stream = reinterpret_cast(stream); + + void *out, *l; + if (_info.dtype == INFINI_DTYPE_F16) { + hcMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(half)); + hcMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(half)); + } else if (_info.dtype == INFINI_DTYPE_F32) { + hcMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(float)); + hcMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(float)); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + hcMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(__hpcc_bfloat16)); + hcMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(__hpcc_bfloat16)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_STATUS(launchForwardKernel( + out, l, q, k, v, mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n, + _info.dtype, + hc_stream)); + + void *grad_k_expanded, *grad_v_expanded; + if (_info.dtype == INFINI_DTYPE_F16) { + hcMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(half)); + hcMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(half)); + } else if (_info.dtype == INFINI_DTYPE_F32) { + hcMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(float)); + hcMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(float)); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + hcMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(__hpcc_bfloat16)); + hcMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(__hpcc_bfloat16)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_STATUS(launchBackwardKernel( + grad_q, grad_k_expanded, grad_v_expanded, + q, k, v, out, grad_out, l, + mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n, + _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n, + _info.l_stride_b, _info.l_stride_s, _info.l_stride_n, + _info.dtype, + hc_stream)); + + + size_t total_seq_len = batch_size * nums_head_kv * seq_len_kv * head_dim; + size_t threads_per_block = 256; + size_t blocks = (total_seq_len + threads_per_block - 1) / threads_per_block; + + if (_info.dtype == INFINI_DTYPE_F16) { + reduce_gradients_kernel<<>>( + reinterpret_cast(grad_k_expanded), + reinterpret_cast(grad_v_expanded), + reinterpret_cast(grad_k), + reinterpret_cast(grad_v), + total_seq_len, + head_dim, + group + ); + } else if (_info.dtype == INFINI_DTYPE_F32) { + reduce_gradients_kernel<<>>( + reinterpret_cast(grad_k_expanded), + reinterpret_cast(grad_v_expanded), + reinterpret_cast(grad_k), + reinterpret_cast(grad_v), + total_seq_len, + head_dim, + group + ); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + reduce_gradients_kernel<__hpcc_bfloat16><<>>( + reinterpret_cast(grad_k_expanded), + reinterpret_cast(grad_v_expanded), + reinterpret_cast<__hpcc_bfloat16*>(grad_k), + reinterpret_cast<__hpcc_bfloat16*>(grad_v), + total_seq_len, + head_dim, + 
group + ); + } + + hcFree(out); + hcFree(l); + hcFree(grad_k_expanded); + hcFree(grad_v_expanded); + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::flash_attention_backward::metax diff --git a/src/infiniop/ops/flash_attention_backward/nvidia/flash_attention_backward_nvidia.cu b/src/infiniop/ops/flash_attention_backward/nvidia/flash_attention_backward_nvidia.cu new file mode 100644 index 000000000..89e2c96e0 --- /dev/null +++ b/src/infiniop/ops/flash_attention_backward/nvidia/flash_attention_backward_nvidia.cu @@ -0,0 +1,383 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "flash_attention_backward_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../../flash_attention/cuda/kernel.cuh" +#include "../cuda/kernel.cuh" + +template +__global__ void reduce_gradients_kernel( + const Tdata *grad_k_expanded, + const Tdata *grad_v_expanded, + Tdata *grad_k, + Tdata *grad_v, + size_t total_seq_len, + size_t head_dim, + size_t group) { + size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < total_seq_len) { + for (size_t j = 0; j < head_dim; ++j) { + Tdata sum_grad_k = 0; + Tdata sum_grad_v = 0; + for (size_t k = 0; k < group; ++k) { + sum_grad_k += grad_k_expanded[i * group * head_dim + k * head_dim + j]; + sum_grad_v += grad_v_expanded[i * group * head_dim + k * head_dim + j]; + } + grad_k[i * head_dim + j] = sum_grad_k; + grad_v[i * head_dim + j] = sum_grad_v; + } + } +} + +template +INFINIOP_CUDA_KERNEL flashAttentionKernel( + Tdata *__restrict__ out_, + Tdata *__restrict__ l_, + const Tdata *__restrict__ q_, + const Tdata *__restrict__ k_, + const Tdata *__restrict__ v_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + + Tdata softmax_scale = 1.0 / sqrt(head_dim); + flashAttentionBlock( + out_, l_, + q_, k_, v_, mask_, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + softmax_scale, + qo_stride_b, qo_stride_s, qo_stride_n, + kv_stride_b, kv_stride_s, kv_stride_n, + l_stride_b, l_stride_s, l_stride_n); +} + +template +INFINIOP_CUDA_KERNEL flashAttentionBackwardKernel( + Tdata *__restrict__ grad_q_, + Tdata *__restrict__ grad_k_, + Tdata *__restrict__ grad_v_, + const Tdata *__restrict__ q_, + const Tdata *__restrict__ k_, + const Tdata *__restrict__ v_, + const Tdata *__restrict__ out_, + const Tdata *__restrict__ grad_out_, + const Tdata *__restrict__ l_, + const float *mask_, + const size_t seq_len_q, const size_t seq_len_kv, + const size_t head_dim, const size_t group, + const size_t B_r, const size_t B_c, const size_t T_r, const size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n) { + Tdata softmax_scale = 1.0 / sqrt(head_dim); + flashAttentionBackwardBlock( + grad_q_, grad_k_, grad_v_, + q_, k_, v_, out_, grad_out_, l_, mask_, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, + softmax_scale, + qo_stride_b, qo_stride_s, qo_stride_n, + kv_stride_b, kv_stride_s, kv_stride_n, + l_stride_b, l_stride_s, l_stride_n); +} + +namespace op::flash_attention_backward::nvidia { + +struct 
Descriptor::Opaque { + std::shared_ptr internel; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t grad_q_desc, + infiniopTensorDescriptor_t grad_k_desc, + infiniopTensorDescriptor_t grad_v_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_desc, + infiniopTensorDescriptor_t v_desc, + infiniopTensorDescriptor_t grad_out_desc, + infiniopTensorDescriptor_t mask_desc, + infiniopAttentionMaskType_t mask_type) { + + auto handle = reinterpret_cast(handle_); + + auto info = FlashAttentionBackwardInfo::create(grad_q_desc, grad_k_desc, grad_v_desc, + q_desc, k_desc, v_desc, + grad_out_desc, mask_desc, mask_type); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 0, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchForwardKernel( + void *out, void *l, + const void *q, const void *k, const void *v, + const void *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t l_stride_s, ptrdiff_t l_stride_n, + infiniDtype_t dtype, + cudaStream_t stream) { + cudaMemset(out, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(dtype)); + cudaMemset(l, 0, batch_size * nums_head_q * seq_len_q * sizeof(dtype)); + + // Calculate SRAM size needed per block + const int sram_size = (2 * B_c * head_dim * sizeof(dtype)) // SRAM size for Kj, Vj + + (B_r * head_dim * sizeof(dtype)) // SRAM size for Qi + + (B_c * B_r * sizeof(dtype)); // SRAM size for S + int max_sram_size; + cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0); + if (sram_size > max_sram_size) { + printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size); + } + + dim3 grid_dim(batch_size, nums_head_q); + dim3 block_dim(B_r); + +#define LAUNCHI_FORWARD_KERNEL(Tdata) \ + flashAttentionKernel<<>>( \ + reinterpret_cast(out), \ + reinterpret_cast(l), \ + reinterpret_cast(q), \ + reinterpret_cast(k), \ + reinterpret_cast(v), \ + reinterpret_cast(mask), \ + seq_len_q, seq_len_kv, \ + head_dim, group, \ + B_r, B_c, T_r, T_c, \ + qo_stride_b, qo_stride_s, qo_stride_n, \ + kv_stride_b, kv_stride_s, kv_stride_n, \ + l_stride_b, l_stride_s, l_stride_n) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCHI_FORWARD_KERNEL(half); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCHI_FORWARD_KERNEL(float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCHI_FORWARD_KERNEL(__nv_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchBackwardKernel( + void *grad_q, void *grad_k, void *grad_v, + const void *q, const void *k, const void *v, + const void *out, const void *grad_out, const void *l, + const void *mask, + size_t batch_size, + size_t nums_head_q, size_t nums_head_kv, + size_t seq_len_q, size_t seq_len_kv, + size_t head_dim, size_t group, + size_t B_r, size_t B_c, size_t T_r, size_t T_c, + ptrdiff_t qo_stride_b, ptrdiff_t qo_stride_s, ptrdiff_t qo_stride_n, + ptrdiff_t kv_stride_b, ptrdiff_t kv_stride_s, ptrdiff_t kv_stride_n, + ptrdiff_t l_stride_b, ptrdiff_t 
l_stride_s, ptrdiff_t l_stride_n, + infiniDtype_t dtype, + cudaStream_t stream) { + + // initial grad_q, grad_k, grad_v + cudaMemset(grad_q, 0, batch_size * nums_head_q * seq_len_q * head_dim * sizeof(dtype)); + cudaMemset(grad_k, 0, batch_size * nums_head_q * seq_len_kv * head_dim * sizeof(dtype)); + cudaMemset(grad_v, 0, batch_size * nums_head_q * seq_len_kv * head_dim * sizeof(dtype)); + + // calculate SRAM size needed per block + const int sram_size = (4 * B_c * head_dim * sizeof(dtype)) // SRAM size for K_j, V_j, dK_j, dV_j + + (3 * B_r * head_dim * sizeof(dtype)) // SRAM size for Q_i, O_i, dO_i + + (2 * B_c * B_r * sizeof(dtype)); // SRAM size for S_i, dS_i + + dim3 grad_dim(batch_size, nums_head_q); + dim3 block_dim(B_c); + +#define LAUNCHI_BACKWARD_KERNEL(Tdata) \ + flashAttentionBackwardKernel<<>>( \ + reinterpret_cast(grad_q), \ + reinterpret_cast(grad_k), \ + reinterpret_cast(grad_v), \ + reinterpret_cast(q), \ + reinterpret_cast(k), \ + reinterpret_cast(v), \ + reinterpret_cast(out), \ + reinterpret_cast(grad_out), \ + reinterpret_cast(l), \ + reinterpret_cast(mask), \ + seq_len_q, seq_len_kv, \ + head_dim, group, \ + B_r, B_c, T_r, T_c, \ + qo_stride_b, qo_stride_s, qo_stride_n, \ + kv_stride_b, kv_stride_s, kv_stride_n, \ + l_stride_b, l_stride_s, l_stride_n) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCHI_BACKWARD_KERNEL(half); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCHI_BACKWARD_KERNEL(float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCHI_BACKWARD_KERNEL(__nv_bfloat16); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *grad_q, void *grad_k, void *grad_v, + const void *q, const void *k, const void *v, + const void *grad_out, const void *mask, + void *stream) const { + + size_t B_r = 2; + size_t B_c = 2; + + size_t batch_size = _info.batch_size; + size_t seq_len_q = _info.seq_len_q; + size_t seq_len_kv = _info.seq_len_kv; + size_t nums_head_q = _info.num_heads_q; + size_t nums_head_kv = _info.num_heads_kv; + size_t group = nums_head_q / nums_head_kv; + size_t head_dim = _info.head_dim; + + const void *mask_input = nullptr; + if (_info.is_masked) { + if (_info.mask != nullptr) { + void *mask_temp; + cudaMalloc(&mask_temp, seq_len_q * seq_len_kv * sizeof(float)); + cudaMemcpy(mask_temp, _info.mask, seq_len_q * seq_len_kv * sizeof(float), cudaMemcpyHostToDevice); + mask_input = mask_temp; + } else { + mask_input = mask; + } + } + + size_t T_r = ceil(float(seq_len_q) / B_r); + size_t T_c = ceil(float(seq_len_kv) / B_c); + + auto cuda_stream = reinterpret_cast(stream); + + void *out, *l; + if (_info.dtype == INFINI_DTYPE_F16) { + cudaMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(half)); + cudaMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(half)); + } else if (_info.dtype == INFINI_DTYPE_F32) { + cudaMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(float)); + cudaMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(float)); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + cudaMalloc(&out, batch_size * seq_len_kv * nums_head_q * head_dim * sizeof(__nv_bfloat16)); + cudaMalloc(&l, batch_size * seq_len_kv * nums_head_q * sizeof(__nv_bfloat16)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_STATUS(launchForwardKernel( + out, l, q, k, v, mask_input, + batch_size, + nums_head_q, nums_head_kv, + seq_len_q, seq_len_kv, + head_dim, group, + B_r, B_c, T_r, T_c, 
+        _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n,
+        _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n,
+        _info.l_stride_b, _info.l_stride_s, _info.l_stride_n,
+        _info.dtype,
+        cuda_stream));
+
+    // per-query-head gradient buffers; reduced over the GQA group afterwards
+    void *grad_k_expanded, *grad_v_expanded;
+    if (_info.dtype == INFINI_DTYPE_F16) {
+        cudaMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(half));
+        cudaMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(half));
+    } else if (_info.dtype == INFINI_DTYPE_F32) {
+        cudaMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(float));
+        cudaMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(float));
+    } else if (_info.dtype == INFINI_DTYPE_BF16) {
+        cudaMalloc(&grad_k_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(__nv_bfloat16));
+        cudaMalloc(&grad_v_expanded, batch_size * nums_head_kv * seq_len_kv * head_dim * group * sizeof(__nv_bfloat16));
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    CHECK_STATUS(launchBackwardKernel(
+        grad_q, grad_k_expanded, grad_v_expanded,
+        q, k, v, out, grad_out, l,
+        mask_input,
+        batch_size,
+        nums_head_q, nums_head_kv,
+        seq_len_q, seq_len_kv,
+        head_dim, group,
+        B_r, B_c, T_r, T_c,
+        _info.qo_stride_b, _info.qo_stride_s, _info.qo_stride_n,
+        _info.kv_stride_b, _info.kv_stride_s, _info.kv_stride_n,
+        _info.l_stride_b, _info.l_stride_s, _info.l_stride_n,
+        _info.dtype,
+        cuda_stream));
+
+    // reduce the expanded per-query-head gradients back onto the shared KV heads
+    size_t total_seq_len = batch_size * nums_head_kv * seq_len_kv * head_dim;
+    size_t threads_per_block = 256;
+    size_t blocks = (total_seq_len + threads_per_block - 1) / threads_per_block;
+
+    if (_info.dtype == INFINI_DTYPE_F16) {
+        reduce_gradients_kernel<half><<<blocks, threads_per_block, 0, cuda_stream>>>(
+            reinterpret_cast<const half *>(grad_k_expanded),
+            reinterpret_cast<const half *>(grad_v_expanded),
+            reinterpret_cast<half *>(grad_k),
+            reinterpret_cast<half *>(grad_v),
+            total_seq_len,
+            head_dim,
+            group);
+    } else if (_info.dtype == INFINI_DTYPE_F32) {
+        reduce_gradients_kernel<float><<<blocks, threads_per_block, 0, cuda_stream>>>(
+            reinterpret_cast<const float *>(grad_k_expanded),
+            reinterpret_cast<const float *>(grad_v_expanded),
+            reinterpret_cast<float *>(grad_k),
+            reinterpret_cast<float *>(grad_v),
+            total_seq_len,
+            head_dim,
+            group);
+    } else if (_info.dtype == INFINI_DTYPE_BF16) {
+        reduce_gradients_kernel<__nv_bfloat16><<<blocks, threads_per_block, 0, cuda_stream>>>(
+            reinterpret_cast<const __nv_bfloat16 *>(grad_k_expanded),
+            reinterpret_cast<const __nv_bfloat16 *>(grad_v_expanded),
+            reinterpret_cast<__nv_bfloat16 *>(grad_k),
+            reinterpret_cast<__nv_bfloat16 *>(grad_v),
+            total_seq_len,
+            head_dim,
+            group);
+    }
+
+    cudaFree(out);
+    cudaFree(l);
+    cudaFree(grad_k_expanded);
+    cudaFree(grad_v_expanded);
+
+    return INFINI_STATUS_SUCCESS;
+}
+} // namespace op::flash_attention_backward::nvidia
diff --git a/src/infiniop/ops/flash_attention_backward/nvidia/flash_attention_backward_nvidia.cuh b/src/infiniop/ops/flash_attention_backward/nvidia/flash_attention_backward_nvidia.cuh
new file mode 100644
index 000000000..36e56085c
--- /dev/null
+++ b/src/infiniop/ops/flash_attention_backward/nvidia/flash_attention_backward_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __FLASH_ATTENTION_BACKWARD_CUDA_H__
+#define __FLASH_ATTENTION_BACKWARD_CUDA_H__
+
+#include "../flash_attention_backward.h"
+
+DESCRIPTOR(nvidia)
+
+#endif // __FLASH_ATTENTION_BACKWARD_CUDA_H__
diff --git a/src/infiniop/ops/flash_attention_backward/operator.cc b/src/infiniop/ops/flash_attention_backward/operator.cc
new file mode 100644
index 000000000..19c451616
--- /dev/null
+++ b/src/infiniop/ops/flash_attention_backward/operator.cc
@@ -0,0 +1,166 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/flash_attention_backward.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/flash_attention_backward_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
+#include "nvidia/flash_attention_backward_nvidia.cuh"
+#endif
+#ifdef ENABLE_METAX_API
+#include "metax/flash_attention_backward_metax.cuh"
+#endif
+
+__C infiniStatus_t infiniopCreateFlashAttentionBackwardDescriptor(
+    infiniopHandle_t handle,
+    infiniopFlashAttentionBackwardDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t grad_q_desc,
+    infiniopTensorDescriptor_t grad_k_desc,
+    infiniopTensorDescriptor_t grad_v_desc,
+    infiniopTensorDescriptor_t q_desc,
+    infiniopTensorDescriptor_t k_desc,
+    infiniopTensorDescriptor_t v_desc,
+    infiniopTensorDescriptor_t grad_out_desc,
+    infiniopTensorDescriptor_t mask_desc,
+    infiniopAttentionMaskType_t mask_type) {
+
+#define CREATE(CASE, NAMESPACE)                                                                  \
+    case CASE:                                                                                   \
+        return op::flash_attention_backward::NAMESPACE::Descriptor::create(                     \
+            handle,                                                                             \
+            reinterpret_cast<op::flash_attention_backward::NAMESPACE::Descriptor **>(desc_ptr), \
+            grad_q_desc,                                                                        \
+            grad_k_desc,                                                                        \
+            grad_v_desc,                                                                        \
+            q_desc,                                                                             \
+            k_desc,                                                                             \
+            v_desc,                                                                             \
+            grad_out_desc,                                                                      \
+            mask_desc,                                                                          \
+            mask_type);
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        CREATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CREATE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        CREATE(INFINI_DEVICE_METAX, metax);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetFlashAttentionBackwardWorkspaceSize(
+    infiniopFlashAttentionBackwardDescriptor_t desc,
+    size_t *size) {
+
+#define GET(CASE, NAMESPACE)                                                                   \
+    case CASE:                                                                                 \
+        *size = reinterpret_cast<op::flash_attention_backward::NAMESPACE::Descriptor *>(desc) \
+                    ->workspaceSize();                                                        \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        GET(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        GET(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        GET(INFINI_DEVICE_METAX, metax);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET
+}
+
+__C infiniStatus_t infiniopFlashAttentionBackward(
+    infiniopFlashAttentionBackwardDescriptor_t desc,
+    void *workspace, size_t workspace_size,
+    void *grad_q,
+    void *grad_k,
+    void *grad_v,
+    const void *q,
+    const void *k,
+    const void *v,
+    const void *grad_out,
+    const void *mask,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE)                                                                   \
+    case CASE:                                                                                       \
+        return reinterpret_cast<const op::flash_attention_backward::NAMESPACE::Descriptor *>(desc)  \
+            ->calculate(workspace, workspace_size,                                                  \
+                        grad_q, grad_k, grad_v, q, k, v, grad_out, mask, stream);
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_NVIDIA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        CALCULATE(INFINI_DEVICE_METAX, metax);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyFlashAttentionBackwardDescriptor(
+    infiniopFlashAttentionBackwardDescriptor_t desc) {
+
+#define DESTROY(CASE, NAMESPACE)                                                                  \
+    case CASE:                                                                                    \
+        delete reinterpret_cast<op::flash_attention_backward::NAMESPACE::Descriptor *>(desc);    \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        DESTROY(INFINI_DEVICE_CPU, cpu);
+#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DESTROY +} diff --git a/test/infiniop-test/test_generate/infiniop_test.py b/test/infiniop-test/test_generate/infiniop_test.py index 0edc1644c..58392c118 100644 --- a/test/infiniop-test/test_generate/infiniop_test.py +++ b/test/infiniop-test/test_generate/infiniop_test.py @@ -1,3 +1,4 @@ +import gguf from typing import List import gguf @@ -15,7 +16,7 @@ def np_dtype_to_ggml(tensor_dtype: np.dtype): return GGMLQuantizationType.F32 elif tensor_dtype == np.float64: return GGMLQuantizationType.F64 - elif tensor_dtype == np.bool_: + elif tensor_dtype == np.bool: return GGMLQuantizationType.Q8_K elif tensor_dtype == np.int8: return GGMLQuantizationType.I8 diff --git a/test/infiniop-test/test_generate/testcases/flash_attention.py b/test/infiniop-test/test_generate/testcases/flash_attention.py new file mode 100644 index 000000000..eb40eeb0f --- /dev/null +++ b/test/infiniop-test/test_generate/testcases/flash_attention.py @@ -0,0 +1,255 @@ +import gguf +import torch +import numpy as np +import torch.nn.functional as F +from typing import List +from ml_dtypes import bfloat16 + +from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor + +def flash_attention( + q: np.ndarray, + k: np.ndarray, + v: np.ndarray, + mask: np.ndarray | None, + mask_type: int, +): + # mask + if mask_type == 0: + mask = None + elif mask_type == 1: + mask = mask + elif mask_type == 2: + mask = np.triu(np.ones((q.shape[-3], k.shape[-3]), dtype=np.float32), k=1) + mask = np.where(mask == 1, -np.inf, mask) + + q_tensor = torch.tensor(q, requires_grad=True) + k_tensor = torch.tensor(k, requires_grad=True) + v_tensor = torch.tensor(v, requires_grad=True) + mask_tensor = torch.tensor(mask) if mask is not None else None + + if q_tensor.dim() == 3: + # from (seq_len, num_heads, dim) to (num_heads, seq_len, dim) + q_shaped = q_tensor.permute(1, 0, 2) + k_shaped = k_tensor.permute(1, 0, 2) + v_shaped = v_tensor.permute(1, 0, 2) + elif q_tensor.dim() == 4: + # from (batch_size, seq_len, num_heads, head_dim) to (batch_size, num_heads, seq_len, head_dim) + q_shaped = q_tensor.permute(0, 2, 1, 3) + k_shaped = k_tensor.permute(0, 2, 1, 3) + v_shaped = v_tensor.permute(0, 2, 1, 3) + + out = F.scaled_dot_product_attention( + query=q_shaped, + key=k_shaped, + value=v_shaped, + attn_mask=mask_tensor, + enable_gqa=True, + ) + # Permute back to original shape + out = out.permute(1, 0, 2) if q_tensor.dim() == 3 else out.permute(0, 2, 1, 3) + + return out.detach().numpy() + + +class FlashAttentionTestCase(InfiniopTestCase): + def __init__( + self, + q: np.ndarray, + shape_q: List[int] | None, + stride_q: List[int] | None, + k: np.ndarray, + shape_k: List[int] | None, + stride_k: List[int] | None, + v: np.ndarray, + shape_v: List[int] | None, + stride_v: List[int] | None, + out: np.ndarray, + shape_out: List[int] | None, + stride_out: List[int] | None, + l: np.ndarray, + shape_l: List[int] | None, + stride_l: List[int] | None, + mask: np.ndarray | None, + shape_mask: List[int] | None, + stride_mask: List[int] | None, + mask_type: int, + ): + super().__init__("flash_attention") + # input + self.q = q + self.shape_q = shape_q + self.stride_q = stride_q + self.k = k + 
self.shape_k = shape_k + self.stride_k = stride_k + self.v = v + self.shape_v = shape_v + self.stride_v = stride_v + self.mask = mask + self.shape_mask = shape_mask + self.stride_mask = stride_mask + self.mask_type = mask_type + # output + self.out = out + self.shape_out = shape_out + self.stride_out = stride_out + self.l = l + self.shape_l = shape_l + self.stride_l = stride_l + + def write_test(self, test_writer: InfiniopTestWriter): + super().write_test(test_writer) + # test_writer.add_array(test_writer.gguf_key("mask_type"), [self.mask_type,]) + test_writer.add_int32(test_writer.gguf_key("mask_type"), self.mask_type) + + if self.shape_q is not None: + test_writer.add_array(test_writer.gguf_key("q.shape"), self.shape_q) + if self.shape_k is not None: + test_writer.add_array(test_writer.gguf_key("k.shape"), self.shape_k) + if self.shape_v is not None: + test_writer.add_array(test_writer.gguf_key("v.shape"), self.shape_v) + if self.shape_mask is not None: + test_writer.add_array(test_writer.gguf_key("mask.shape"), self.shape_mask) + if self.shape_out is not None: + test_writer.add_array(test_writer.gguf_key("out.shape"), self.shape_out) + if self.shape_l is not None: + test_writer.add_array(test_writer.gguf_key("l.shape"), self.shape_l) + + if self.stride_q is not None: + test_writer.add_array(test_writer.gguf_key("q.strides"), gguf_strides(*self.stride_q)) + if self.stride_k is not None: + test_writer.add_array(test_writer.gguf_key("k.strides"), gguf_strides(*self.stride_k)) + if self.stride_v is not None: + test_writer.add_array(test_writer.gguf_key("v.strides"), gguf_strides(*self.stride_v)) + if self.stride_mask is not None: + test_writer.add_array(test_writer.gguf_key("mask.strides"), gguf_strides(*self.stride_mask)) + test_writer.add_array( + test_writer.gguf_key("out.strides"), + gguf_strides(*self.stride_out if self.stride_out is not None else contiguous_gguf_strides(self.shape_out)) + ) + test_writer.add_array( + test_writer.gguf_key("l.strides"), + gguf_strides(*self.stride_l if self.stride_l is not None else contiguous_gguf_strides(self.shape_l)) + ) + + test_writer.add_tensor( + test_writer.gguf_key("q"), self.q, raw_dtype=np_dtype_to_ggml(self.q.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("k"), self.k, raw_dtype=np_dtype_to_ggml(self.k.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("v"), self.v, raw_dtype=np_dtype_to_ggml(self.v.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("mask"), self.mask, raw_dtype=np_dtype_to_ggml(self.mask.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("out"), self.out, raw_dtype=np_dtype_to_ggml(self.out.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("l"), self.l, raw_dtype=np_dtype_to_ggml(self.l.dtype) + ) + + ans = flash_attention( + self.q.astype(np.float64), + self.k.astype(np.float64), + self.v.astype(np.float64), + self.mask.astype(np.float64) if self.mask is not None else None, + self.mask_type, + ) + test_writer.add_tensor( + test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64 + ) + + +def gen_gguf(dtype: np.dtype, filename: str): + test_writer = InfiniopTestWriter(filename) + test_cases = [] + + # ============================================================================== + # Configuration + # ============================================================================== + # These are not meant to be imported from other modules + _TEST_CASES_ = [ + # shape_qo, shape_kv, mask_type(0: None, 1: Full, 2: Causal) + # inputLayout -> ((batch_size), seq_len, 
num_heads, head_dim) + ((4, 2, 2), (4, 2, 2), 0), + ((4, 4, 4), (10, 4, 4), 0), + ((10, 4, 4), (4, 4, 4), 2), + ((1, 10, 2, 4), (1, 10, 2, 4), 0), + ((4, 10, 8, 4), (4, 10, 2, 4), 1), + ((16, 1024, 8, 64), (16, 1024, 2, 64), 0), + ] + for shape_qo, shape_kv, mask_type in _TEST_CASES_: + q = (np.random.rand(*shape_qo) * 0.1).astype(dtype) + k = (np.random.rand(*shape_kv) * 0.1).astype(dtype) + v = (np.random.rand(*shape_kv) * 0.1).astype(dtype) + + shape_mask = (q.shape[-3], k.shape[-3]) + mask = np.random.randint(0, 2, size=shape_mask).astype(np.float32) + mask = np.where(mask == 1, -np.inf, mask) + + out = np.zeros(shape_qo, dtype=dtype) + # out = np.empty(tuple(0 for _ in shape_out), dtype=dtype) + + shape_l = shape_qo[:-1] + l = np.zeros(shape_l, dtype=dtype) + + stride_q = None + stride_k = None + stride_v = None + stride_out = None + stride_l = None + stride_mask = None + + q = process_zero_stride_tensor(q, stride_q) + k = process_zero_stride_tensor(k, stride_k) + v = process_zero_stride_tensor(v, stride_v) + out = process_zero_stride_tensor(out, stride_out) + l = process_zero_stride_tensor(l, stride_l) + mask = process_zero_stride_tensor(mask, stride_mask) + + test_case = FlashAttentionTestCase( + q=q, + shape_q=shape_qo, + stride_q=stride_q, + k=k, + shape_k=shape_kv, + stride_k=stride_k, + v=v, + shape_v=shape_kv, + stride_v=stride_v, + out=out, + shape_out=shape_qo, + stride_out=stride_out, + l=l, + shape_l=shape_l, + stride_l=stride_l, + mask=mask, + shape_mask=shape_mask, + stride_mask=stride_mask, + mask_type=mask_type, + ) + test_cases.append(test_case) + + test_writer.add_tests(test_cases) + test_writer.save() + + +if __name__ == "__main__": + _TENSOR_DTYPES_ = [ + np.float32, + np.float16, + bfloat16, + ] + dtype_filename_map = { + np.float32: "flash_attention_f32.gguf", + np.float16: "flash_attention_f16.gguf", + bfloat16: "flash_attention_bf16.gguf", + } + + for dtype in _TENSOR_DTYPES_: + filename = dtype_filename_map[dtype] + gen_gguf(dtype, filename) diff --git a/test/infiniop-test/test_generate/testcases/flash_attention_backward.py b/test/infiniop-test/test_generate/testcases/flash_attention_backward.py new file mode 100644 index 000000000..379102004 --- /dev/null +++ b/test/infiniop-test/test_generate/testcases/flash_attention_backward.py @@ -0,0 +1,308 @@ +import gguf +import torch +import numpy as np +import torch.nn.functional as F +from typing import List +from ml_dtypes import bfloat16 + +from .. 
import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor + +def flash_attention_backward( + q: np.ndarray, + k: np.ndarray, + v: np.ndarray, + grad_out: np.ndarray, + mask: np.ndarray | None, + mask_type: int, +): + # mask + if mask_type == 0: + mask = None + elif mask_type == 1: + mask = mask + elif mask_type == 2: + mask = np.triu(np.ones((q.shape[-3], k.shape[-3]), dtype=np.float32), k=1) + mask = np.where(mask == 1, -np.inf, mask) + + q_tensor = torch.tensor(q, requires_grad=True) + k_tensor = torch.tensor(k, requires_grad=True) + v_tensor = torch.tensor(v, requires_grad=True) + q_tensor.grad = torch.zeros_like(q_tensor) + k_tensor.grad = torch.zeros_like(k_tensor) + v_tensor.grad = torch.zeros_like(v_tensor) + grad_out_tensor = torch.tensor(grad_out) + mask_tensor = torch.tensor(mask) if mask is not None else None + + if q_tensor.dim() == 3: + # from (seq_len, num_heads, dim) to (num_heads, seq_len, dim) + q_shaped = q_tensor.permute(1, 0, 2) + k_shaped = k_tensor.permute(1, 0, 2) + v_shaped = v_tensor.permute(1, 0, 2) + grad_out_shaped = grad_out_tensor.permute(1, 0, 2) + elif q_tensor.dim() == 4: + # from (batch_size, seq_len, num_heads, head_dim) to (batch_size, num_heads, seq_len, head_dim) + q_shaped = q_tensor.permute(0, 2, 1, 3) + k_shaped = k_tensor.permute(0, 2, 1, 3) + v_shaped = v_tensor.permute(0, 2, 1, 3) + grad_out_shaped = grad_out_tensor.permute(0, 2, 1, 3) + + out = F.scaled_dot_product_attention( + query=q_shaped, + key=k_shaped, + value=v_shaped, + attn_mask=mask_tensor, + enable_gqa=True, + ) + out.backward(grad_out_shaped) + + return ( + q_tensor.grad.numpy(), + k_tensor.grad.numpy(), + v_tensor.grad.numpy(), + ) + + +class FlashAttentionBackwardTest(InfiniopTestCase): + def __init__( + self, + q: np.ndarray, + shape_q: List[int] | None, + stride_q: List[int] | None, + k: np.ndarray, + shape_k: List[int] | None, + stride_k: List[int] | None, + v: np.ndarray, + shape_v: List[int] | None, + stride_v: List[int] | None, + grad_out: np.ndarray, + shape_grad_out: List[int] | None, + stride_grad_out: List[int] | None, + grad_q: np.ndarray, + shape_grad_q: List[int] | None, + stride_grad_q: List[int] | None, + grad_k: np.ndarray, + shape_grad_k: List[int] | None, + stride_grad_k: List[int] | None, + grad_v: np.ndarray, + shape_grad_v: List[int] | None, + stride_grad_v: List[int] | None, + mask: np.ndarray | None, + shape_mask: List[int] | None, + stride_mask: List[int] | None, + mask_type: int, + ): + super().__init__("flash_attention_backward") + # input + self.q = q + self.shape_q = shape_q + self.stride_q = stride_q + self.k = k + self.shape_k = shape_k + self.stride_k = stride_k + self.v = v + self.shape_v = shape_v + self.stride_v = stride_v + self.grad_out = grad_out + self.shape_grad_out = shape_grad_out + self.stride_grad_out = stride_grad_out + self.mask = mask + self.shape_mask = shape_mask + self.stride_mask = stride_mask + self.mask_type = mask_type + # output + self.grad_q = grad_q + self.shape_grad_q = shape_grad_q + self.stride_grad_q = stride_grad_q + self.grad_k = grad_k + self.shape_grad_k = shape_grad_k + self.stride_grad_k = stride_grad_k + self.grad_v = grad_v + self.shape_grad_v = shape_grad_v + self.stride_grad_v = stride_grad_v + + def write_test(self, test_writer: InfiniopTestWriter): + super().write_test(test_writer) + test_writer.add_int32(test_writer.gguf_key("mask_type"), self.mask_type) + + if self.shape_q is not None: + 
test_writer.add_array(test_writer.gguf_key("q.shape"), self.shape_q) + if self.shape_k is not None: + test_writer.add_array(test_writer.gguf_key("k.shape"), self.shape_k) + if self.shape_v is not None: + test_writer.add_array(test_writer.gguf_key("v.shape"), self.shape_v) + if self.shape_mask is not None: + test_writer.add_array(test_writer.gguf_key("mask.shape"), self.shape_mask) + if self.shape_grad_out is not None: + test_writer.add_array(test_writer.gguf_key("grad_out.shape"), self.shape_grad_out) + if self.shape_grad_q is not None: + test_writer.add_array(test_writer.gguf_key("grad_q.shape"), self.shape_grad_q) + if self.shape_grad_k is not None: + test_writer.add_array(test_writer.gguf_key("grad_k.shape"), self.shape_grad_k) + if self.shape_grad_v is not None: + test_writer.add_array(test_writer.gguf_key("grad_v.shape"), self.shape_grad_v) + + if self.stride_q is not None: + test_writer.add_array(test_writer.gguf_key("q.strides"), gguf_strides(*self.stride_q)) + if self.stride_k is not None: + test_writer.add_array(test_writer.gguf_key("k.strides"), gguf_strides(*self.stride_k)) + if self.stride_v is not None: + test_writer.add_array(test_writer.gguf_key("v.strides"), gguf_strides(*self.stride_v)) + if self.stride_mask is not None: + test_writer.add_array(test_writer.gguf_key("mask.strides"), gguf_strides(*self.stride_mask)) + if self.stride_grad_out is not None: + test_writer.add_array(test_writer.gguf_key("grad_out.strides"), gguf_strides(*self.stride_grad_out)) + test_writer.add_array( + test_writer.gguf_key("grad_q.strides"), + gguf_strides(*self.stride_grad_q if self.stride_grad_q is not None else contiguous_gguf_strides(self.shape_grad_q)) + ) + test_writer.add_array( + test_writer.gguf_key("grad_k.strides"), + gguf_strides(*self.stride_grad_k if self.stride_grad_k is not None else contiguous_gguf_strides(self.shape_grad_k)) + ) + test_writer.add_array( + test_writer.gguf_key("grad_v.strides"), + gguf_strides(*self.stride_grad_v if self.stride_grad_v is not None else contiguous_gguf_strides(self.shape_grad_v)) + ) + + test_writer.add_tensor( + test_writer.gguf_key("q"), self.q, raw_dtype=np_dtype_to_ggml(self.q.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("k"), self.k, raw_dtype=np_dtype_to_ggml(self.k.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("v"), self.v, raw_dtype=np_dtype_to_ggml(self.v.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("mask"), self.mask, raw_dtype=np_dtype_to_ggml(self.mask.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("grad_out"), self.grad_out, raw_dtype=np_dtype_to_ggml(self.grad_out.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("grad_q"), self.grad_q, raw_dtype=np_dtype_to_ggml(self.grad_q.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("grad_k"), self.grad_k, raw_dtype=np_dtype_to_ggml(self.grad_k.dtype) + ) + test_writer.add_tensor( + test_writer.gguf_key("grad_v"), self.grad_v, raw_dtype=np_dtype_to_ggml(self.grad_v.dtype) + ) + + ans_grad_q, ans_grad_k, ans_grad_v = flash_attention_backward( + self.q.astype(np.float64), + self.k.astype(np.float64), + self.v.astype(np.float64), + self.grad_out.astype(np.float64), + self.mask.astype(np.float64) if self.mask is not None else None, + self.mask_type, + ) + test_writer.add_tensor( + test_writer.gguf_key("ans_grad_q"), ans_grad_q, raw_dtype=gguf.GGMLQuantizationType.F64 + ) + test_writer.add_tensor( + test_writer.gguf_key("ans_grad_k"), ans_grad_k, raw_dtype=gguf.GGMLQuantizationType.F64 + ) + test_writer.add_tensor( + 
test_writer.gguf_key("ans_grad_v"), ans_grad_v, raw_dtype=gguf.GGMLQuantizationType.F64 + ) + + +def gen_gguf(dtype: np.dtype, filename: str): + test_writer = InfiniopTestWriter(filename) + test_cases = [] + + # ============================================================================== + # Configuration + # ============================================================================== + _TEST_CASES = [ + # shape_q, shape_kv, mask_type(0: None, 1: Full, 2: Causal) + # inputLayout -> ((batch_size), seq_len, num_heads, head_dim) + ((4, 2, 2), (4, 2, 2), 0), + ((4, 4, 4), (10, 4, 4), 0), + # ((10, 4, 4), (4, 4, 4), 2), + ((1, 10, 2, 4), (1, 10, 2, 4), 0), + # ((4, 10, 8, 4), (4, 10, 2, 4), 0), + ] + + for shape_qo, shape_kv, mask_type in _TEST_CASES: + q = (np.random.rand(*shape_qo) * 0.1).astype(dtype) + k = (np.random.rand(*shape_kv) * 0.1).astype(dtype) + v = (np.random.rand(*shape_kv) * 0.1).astype(dtype) + grad_out = (np.random.rand(*shape_qo) * 0.1).astype(dtype) + + shape_mask = (q.shape[-3], k.shape[-3]) + mask = np.random.randint(0, 2, size=shape_mask).astype(np.float32) + mask = np.where(mask == 1, -np.inf, mask) + + # grad_q = np.empty(tuple(0 for _ in shape_q), dtype=dtype) + # grad_k = np.empty(tuple(0 for _ in shape_kv), dtype=dtype) + # grad_v = np.empty(tuple(0 for _ in shape_kv), dtype=dtype) + grad_q = np.zeros(shape_qo, dtype=dtype) + grad_k = np.zeros(shape_kv, dtype=dtype) + grad_v = np.zeros(shape_kv, dtype=dtype) + + stride_q = None + stride_kv = None + stride_grad_out = None + stride_grad_q = None + stride_grad_kv = None + stride_grad_v = None + stride_mask = None + + q = process_zero_stride_tensor(q, stride_q) + k = process_zero_stride_tensor(k, stride_kv) + v = process_zero_stride_tensor(v, stride_kv) + grad_out = process_zero_stride_tensor(grad_out, stride_grad_out) + grad_q = process_zero_stride_tensor(grad_q, stride_grad_q) + grad_k = process_zero_stride_tensor(grad_k, stride_grad_kv) + grad_v = process_zero_stride_tensor(grad_v, stride_grad_kv) + mask = process_zero_stride_tensor(mask, stride_mask) + + test_case = FlashAttentionBackwardTest( + q=q, + shape_q=shape_qo, + stride_q=stride_q, + k=k, + shape_k=shape_kv, + stride_k=stride_kv, + v=v, + shape_v=shape_kv, + stride_v=stride_kv, + grad_out=grad_out, + shape_grad_out=shape_qo, + stride_grad_out=stride_grad_out, + grad_q=grad_q, + shape_grad_q=shape_qo, + stride_grad_q=stride_grad_q, + grad_k=grad_k, + shape_grad_k=shape_kv, + stride_grad_k=stride_grad_kv, + grad_v=grad_v, + shape_grad_v=shape_kv, + stride_grad_v=stride_grad_kv, + mask=mask, + shape_mask=shape_mask, + stride_mask=stride_mask, + mask_type=mask_type, + ) + test_cases.append(test_case) + + test_writer.add_tests(test_cases) + test_writer.save() + +if __name__ == "__main__": + _TENSOR_DTYPES = [ + np.float32, + np.float16, + bfloat16, + ] + dtype_filename_map = { + np.float32: "flash_attention_backward_f32.gguf", + np.float16: "flash_attention_backward_f16.gguf", + bfloat16: "flash_attention_backward_bf16.gguf", + } + + for dtype in _TENSOR_DTYPES: + filename = dtype_filename_map[dtype] + gen_gguf(dtype, filename) diff --git a/test/infiniop/flash_attention.py b/test/infiniop/flash_attention.py new file mode 100644 index 000000000..2d9f618dd --- /dev/null +++ b/test/infiniop/flash_attention.py @@ -0,0 +1,221 @@ +import math +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, 
+ TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceEnum, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + +class infiniopAttentionMaskType: + NONE = 0 + FULL = 1 + CAUSAL = 2 + +InfiniopAttentionMaskTypeNames = { + infiniopAttentionMaskType.NONE: "NONE", + infiniopAttentionMaskType.FULL: "FULL", + infiniopAttentionMaskType.CAUSAL: "CAUSAL", +} + +_TEST_CASES = [ + # ( + # (1, 2, 2, 2), # q/out shape + # (1, 2, 2, 2), # k/v shape + # infiniopAttentionMaskType.NONE, # Mask type + # ), + ((4, 2, 2), (4, 2, 2), 0), + ((4, 4, 4), (10, 4, 4), 0), + ((10, 4, 4), (4, 4, 4), 2), + ((1, 10, 2, 4), (1, 10, 2, 4), 0), + ((4, 10, 8, 4), (4, 10, 2, 4), 1), + ((16, 1024, 8, 64), (16, 1024, 2, 64), 0), +] + +_TENSOR_DTYPES = [ + InfiniDtype.F32, + InfiniDtype.F16, + InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-6, "rtol": 1e-6}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def causal_mask(shape): + mask = torch.tril(torch.ones(shape, dtype=torch.float32)) + mask = torch.where(mask == 1, + torch.tensor(0.0, dtype=torch.float32), + torch.tensor(float('-inf'), dtype=torch.float32)) + return mask + + +def attention(query, key, value, attn_mask=None, mask_type=None, dropout_p=0.0, scale=None) -> torch.Tensor: + add_dim = False + if(query.ndim == 3): + query =torch.unsqueeze(query, 0) + key = torch.unsqueeze(key, 0) + value = torch.unsqueeze(value, 0) + add_dim = True + B = query.size(0) + L, S = query.size(-3), key.size(-3) + NH, NKVH = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype).to(query.device) + + if mask_type == 0: + attn_mask = None + elif mask_type == 2: + attn_mask = causal_mask((L, S)).to(query.device) + + if attn_mask is not None: + attn_bias += attn_mask + + attn_weight = query.reshape(B, L, NKVH, NH//NKVH, -1).permute(0, 2, 3, 1 ,4) @ key.reshape(B, S, NKVH, 1, -1).permute(0, 2, 3, 4, 1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + attn_weight = torch.dropout(attn_weight, dropout_p, train=True) + attn_out = (attn_weight @ value.reshape(B, S, NKVH, 1, -1).permute(0, 2, 3, 1, 4)).permute(0, 3, 1, 2, 4).reshape(B, L, NH, -1) + if add_dim: + attn_out = torch.squeeze(attn_out, 0) + return attn_out + + +def test( + handle, + device, + qo_shape, + kv_shape, + mask_type, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing FlashAttention on {InfiniDeviceNames[device]} with qo_shape:{qo_shape} kv_shape:{kv_shape} dtype:{InfiniDtypeNames[dtype]} mask_type:{InfiniopAttentionMaskTypeNames[mask_type]}" + ) + + out = TestTensor(qo_shape, None, dtype, device, mode="zeros") + l = TestTensor(qo_shape[:-1], None, dtype, device, mode="zeros") + + q = TestTensor(qo_shape, None, dtype, device, scale=0.1) + k = TestTensor(kv_shape, None, dtype, device, scale=0.1) + v = TestTensor(kv_shape, None, dtype, device, scale=0.1) + + mask = causal_mask((qo_shape[-3], kv_shape[-3])) + mask = TestTensor.from_torch(mask, InfiniDtype.F32, device) + + def torch_attention(): + return attention( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + mask.torch_tensor(), + mask_type, + ) + + ans = torch_attention() + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateFlashAttentionDescriptor( + handle, + 
ctypes.byref(descriptor), + out.descriptor, + l.descriptor, + q.descriptor, + k.descriptor, + v.descriptor, + mask.descriptor, + mask_type, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [ + out, + l, + q, + k, + v, + mask, + ]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetFlashAttentionWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, out.device) + + def lib_flash_attention(): + check_error( + LIBINFINIOP.infiniopFlashAttention( + descriptor, + workspace.data(), + workspace_size.value, + out.data(), + l.data(), + q.data(), + k.data(), + v.data(), + mask.data(), + None, + ) + ) + + lib_flash_attention() + + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out.actual_tensor(), ans, atol=atol, rtol=rtol) + assert torch.allclose(out.actual_tensor(), ans, atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + profile_operation("PyTorch", lambda: torch_attention(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_flash_attention(), device, NUM_PRERUN, NUM_ITERATIONS) + check_error(LIBINFINIOP.infiniopDestroyFlashAttentionDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") + \ No newline at end of file diff --git a/test/infiniop/flash_attention_backward.py b/test/infiniop/flash_attention_backward.py new file mode 100644 index 000000000..412795177 --- /dev/null +++ b/test/infiniop/flash_attention_backward.py @@ -0,0 +1,283 @@ +import math +import torch +import torch.nn.functional as F +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceEnum, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + +import warnings +warnings.filterwarnings("ignore", message=".*no current CUDA context.*") + + +class infiniopAttentionMaskType: + NONE = 0 + FULL = 1 + CAUSAL = 2 + +InfiniopAttentionMaskTypeNames = { + infiniopAttentionMaskType.NONE: "NONE", + infiniopAttentionMaskType.FULL: "FULL", + infiniopAttentionMaskType.CAUSAL: "CAUSAL", +} + + +_TEST_CASES = [ + # ( + # (1, 2, 2, 2), # q/out shape + # (1, 2, 2, 2), # k/v shape + # infiniopAttentionMaskType.NONE, # Mask type + # ), + ((4, 2, 2), (4, 2, 2), 0), + # ((4, 4, 4), (10, 4, 4), 0), + # ((4, 4, 4), (4, 2, 4), 0), + # ((1, 10, 2, 4), (1, 10, 2, 4), 0), + # ((4, 10, 8, 4), (4, 10, 2, 4), 1), +] + +_TENSOR_DTYPES = [ + # InfiniDtype.F32, + # InfiniDtype.F16, + InfiniDtype.BF16, +] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-1, "rtol": 1e-1}, + InfiniDtype.F32: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.BF16: {"atol": 1e-1, "rtol": 1e-1}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def causal_mask(shape): + mask = torch.tril(torch.ones(shape, dtype=torch.float32)) + mask = torch.where(mask == 1, + torch.tensor(0.0, dtype=torch.float32), + torch.tensor(float('-inf'), 
dtype=torch.float32)) + return mask + + +def attention_backward(q, k, v, grad_out, attn_mask, mask_type): + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + q.grad = torch.zeros_like(q) + k.grad = torch.zeros_like(k) + v.grad = torch.zeros_like(v) + + if mask_type == 0: + attn_mask = None + elif mask_type == 2: + attn_mask = causal_mask((q.shape[-3], k.shape[-3])).to(q.device) + + if q.ndim == 3: + q_shaped = q.permute(1, 0 ,2) + k_shaped = k.permute(1, 0 ,2) + v_shaped = v.permute(1, 0, 2) + grad_out_shaped = grad_out.permute(1, 0, 2) + elif q.ndim == 4: + q_shaped = q.permute(0, 2, 1, 3) + k_shaped = k.permute(0, 2, 1, 3) + v_shaped = v.permute(0, 2, 1, 3) + grad_out_shaped = grad_out.permute(0, 2, 1, 3) + + out = F.scaled_dot_product_attention( + query=q_shaped, + key=k_shaped, + value=v_shaped, + attn_mask=attn_mask, + # enable_gqa=True, + ) + + out.backward(grad_out_shaped) + + return q.grad, k.grad, v.grad + + +def test( + handle, + device, + qo_shape, + kv_shape, + mask_type, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing FlashAttentionBackward on {InfiniDeviceNames[device]} with qo_shape:{qo_shape} kv_shape:{kv_shape} dtype:{InfiniDtypeNames[dtype]} mask_type:{InfiniopAttentionMaskTypeNames[mask_type]}" + ) + + # q = torch.tensor( + # [[[[0.0399, 0.0517], + # [0.0025, 0.0940]], + + # [[0.0946, 0.0797], + # [0.0415, 0.0820]]]] + # ) + # k = torch.tensor( + # [[[[0.0972, 0.0791], + # [0.0469, 0.0330]], + + # [[0.0334, 0.0378], + # [0.0764, 0.0640]]]] + # ) + # v = torch.tensor( + # [[[[0.0019, 0.0773], + # [0.0865, 0.0810]], + + # [[0.0667, 0.0365], + # [0.0364, 0.0568]]]] + # ) + # grad_out = torch.tensor( + # [[[[0.0787, 0.0134], + # [0.0219, 0.0819]], + + # [[0.0697, 0.0730], + # [0.0233, 0.0903]]]] + # ) + # q = TestTensor.from_torch(q, dtype, device) + # k = TestTensor.from_torch(k, dtype, device) + # v = TestTensor.from_torch(v, dtype, device) + # grad_out = TestTensor.from_torch(grad_out, dtype, device) + + grad_q = TestTensor(qo_shape, None, dtype, device, mode="zeros") + grad_k = TestTensor(kv_shape, None, dtype, device, mode="zeros") + grad_v = TestTensor(kv_shape, None, dtype, device, mode="zeros") + + q = TestTensor(qo_shape, None, dtype, device, scale=0.1) + k = TestTensor(kv_shape, None, dtype, device, scale=0.1) + v = TestTensor(kv_shape, None, dtype, device, scale=0.1) + grad_out = TestTensor(qo_shape, None, dtype, device, scale=0.1) + + # print(f"q:\n{q.torch_tensor()}") + # print(f"k:\n{k.torch_tensor()}") + # print(f"v:\n{v.torch_tensor()}") + # print(f"grad_out:\n{grad_out.torch_tensor()}") + + mask = causal_mask((qo_shape[-3], kv_shape[-3])) + mask = TestTensor.from_torch(mask, InfiniDtype.F32, device) + + def torch_attention_backward(): + return attention_backward( + q.torch_tensor(), + k.torch_tensor(), + v.torch_tensor(), + grad_out.torch_tensor(), + mask.torch_tensor(), + mask_type + ) + + ans_grad_q, ans_grad_k, ans_grad_v = torch_attention_backward() + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateFlashAttentionBackwardDescriptor( + handle, + ctypes.byref(descriptor), + grad_q.descriptor, + grad_k.descriptor, + grad_v.descriptor, + q.descriptor, + k.descriptor, + v.descriptor, + grad_out.descriptor, + mask.descriptor, + mask_type, + ) + ) + + for tensor in [ + grad_q, + grad_k, + grad_v, + q, + k, + v, + grad_out, + mask, + ]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + 
LIBINFINIOP.infiniopGetFlashAttentionBackwardWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, grad_q.device) + + def lib_flash_attention_backward(): + check_error( + LIBINFINIOP.infiniopFlashAttentionBackward( + descriptor, + workspace.data(), + workspace_size.value, + grad_q.data(), + grad_k.data(), + grad_v.data(), + q.data(), + k.data(), + v.data(), + grad_out.data(), + mask.data(), + None, + ) + ) + + lib_flash_attention_backward() + + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(grad_q.actual_tensor().cpu(), ans_grad_q.cpu(), atol=atol, rtol=rtol) + debug(grad_k.actual_tensor().cpu(), ans_grad_k.cpu(), atol=atol, rtol=rtol) + debug(grad_v.actual_tensor().cpu(), ans_grad_v.cpu(), atol=atol, rtol=rtol) + + assert torch.allclose(grad_q.actual_tensor().cpu(), ans_grad_q.cpu(), atol=atol, rtol=rtol) + assert torch.allclose(grad_k.actual_tensor().cpu(), ans_grad_k.cpu(), atol=atol, rtol=rtol) + assert torch.allclose(grad_v.actual_tensor().cpu(), ans_grad_v.cpu(), atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + profile_operation("PyTorch", lambda: torch_attention_backward(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_flash_attention_backward(), device, NUM_PRERUN, NUM_ITERATIONS) + check_error(LIBINFINIOP.infiniopDestroyFlashAttentionBackwardDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") + \ No newline at end of file diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 869e4aa86..88c5f82ed 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -167,6 +167,92 @@ def conv_(lib): pass +@OpRegister.operator +def flash_attention_(lib): + lib.infiniopCreateFlashAttentionDescriptor.restype = c_int32 + lib.infiniopCreateFlashAttentionDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int32, + ] + + lib.infiniopGetFlashAttentionWorkspaceSize.restype = c_int32 + lib.infiniopGetFlashAttentionWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopFlashAttention.restype = c_int32 + lib.infiniopFlashAttention.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyFlashAttentionDescriptor.restype = c_int32 + lib.infiniopDestroyFlashAttentionDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def flash_attention_backward_(lib): + lib.infiniopCreateFlashAttentionBackwardDescriptor.restype = c_int32 + lib.infiniopCreateFlashAttentionBackwardDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + 
infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_int32, + ] + + lib.infiniopGetFlashAttentionBackwardWorkspaceSize.restype = c_int32 + lib.infiniopGetFlashAttentionBackwardWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopFlashAttentionBackward.restype = c_int32 + lib.infiniopFlashAttentionBackward.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyFlashAttentionBackwardDescriptor.restype = c_int32 + lib.infiniopDestroyFlashAttentionBackwardDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def gemm_(lib): lib.infiniopCreateGemmDescriptor.restype = c_int32