[CANN]: add the basic support of Flash Attention kernel #13627
Changes from 8 commits
```diff
@@ -258,6 +258,15 @@ cmake --build build --config release
 ### **GitHub contribution**:
 Please add the **[CANN]** prefix/tag in issues/PRs titles to help the CANN-team check/address them without delay.
 
+## Updates
+### Basic Flash Attention Support
+The basic FA kernel with aclnnops has been added in aclnn_ops.cpp.
+Currently, the FA only supports the cases with FP16 KV tensors and NO logit softcap.
+Since the aclnn interface for flash attention cannot support the logit softcap, we will only update the quantized version in the future.
+
+Authors from Peking University: Bizhao Shi ([email protected]), Yuxin Yang ([email protected]), Ruiyang Ma ([email protected]), and Guojie Luo ([email protected]).
+
+Thanks Tuo Dai and Shanni Li from Huawei Technologies Co., Ltd.
+
 ## TODO
 - Support more models and data types.
```
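The README text above states the restriction only in prose. As a quick illustration, here is a minimal host-side sketch (not part of this PR) of how a `GGML_OP_FLASH_ATTN_EXT` node could be checked against it; it assumes the current ggml convention that this op packs `{ scale, max_bias, logit_softcap }` as floats into `op_params`, and the helper name is made up:

```cpp
#include <cstring>

#include "ggml.h"

// Sketch only: true if a FLASH_ATTN_EXT node meets the documented restrictions
// (FP16 K/V tensors and no logit softcap).
static bool cann_fa_basic_restrictions_ok(const struct ggml_tensor * dst) {
    const struct ggml_tensor * k = dst->src[1]; // key tensor
    const struct ggml_tensor * v = dst->src[2]; // value tensor

    // Assumption: logit_softcap is the third float packed into op_params.
    float logit_softcap = 0.0f;
    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));

    return k->type == GGML_TYPE_F16 &&
           v->type == GGML_TYPE_F16 &&
           logit_softcap == 0.0f;
}
```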
Large diffs are not rendered by default.
```diff
@@ -45,6 +45,8 @@
 #include <aclnnop/aclnn_cos.h>
 #include <aclnnop/aclnn_log.h>
 #include <aclnnop/aclnn_sign.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
+
 #include <aclnnop/aclnn_isneginf.h>
 #include "acl_tensor.h"
 #include "common.h"
@@ -714,6 +716,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ *          for computing scaled dot-product attention with hardware acceleration.
+ *          The result is stored in the destination tensor `dst`.
+ *
+ *          This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
```
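The declaration above only documents the entry point. For orientation, here is a hypothetical caller-side sketch of a flash-attention node that this function is meant to handle, built with the existing `ggml_flash_attn_ext` API; the shapes, the `NULL` mask, and the parameter values are illustrative assumptions, not taken from the PR:

```cpp
#include <cmath>

#include "ggml.h"

// Sketch: build a FLASH_ATTN_EXT node with FP16 K/V and no logit softcap,
// i.e. the case the new CANN kernel is documented to support.
static struct ggml_tensor * build_fa_node_sketch(struct ggml_context * ctx) {
    const int64_t head_dim = 128, n_head = 8, n_tokens = 32, n_kv = 256;

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_tokens, n_head, 1);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, n_kv,     n_head, 1);

    const float scale = 1.0f / sqrtf((float) head_dim);

    // logit_softcap stays 0.0f: the aclnn fused attention interface does not support softcapping.
    return ggml_flash_attn_ext(ctx, q, k, v, /*mask=*/NULL, scale,
                               /*max_bias=*/0.0f, /*logit_softcap=*/0.0f);
}
```

On a device where `ggml_backend_cann_supports_op` accepts such a node (head size 128, F16 K/V), the scheduler would route it to `ggml_cann_flash_attn_ext`; otherwise it would fall back to another backend.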
```diff
@@ -1747,6 +1747,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2161,6 +2164,36 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT: {
+            // copy from [ggml-cuda.cu]
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         default:
             return false;
     }
```
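Read side by side, the checks above amount to roughly the following support matrix (derived only from the diff shown here; in every case K and V must also have equal head sizes and `src[0]->ne[3]` must be 1):

| Q head size (`src[0]->ne[0]`) | K/V type requirement | Supported |
|---|---|---|
| 192 | any | no |
| 576 (DeepSeek MLA) | any | no |
| any | BF16 K or V | no |
| 64 | F16 K | yes |
| 128 | none beyond the rejections above | yes |
| 256 | F16 K and V | yes |
| other sizes | F16 K and V | yes |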
There seem to be some documentation errors here.