
Commit e35fecc

shibizhao authored and ggerganov committed
CANN: Add the basic supports of Flash Attention kernel (llama/13627)
* cann: add the basic FA support
* cann: update the readme
* cann: update the FlashAttention with PSEShift
* cann: update the input parameters in FA
* cann: update the alibi with max_bias
* cann: add the constrints of softcap
* cann: update the docs CANN.md
* cann: update the docs CANN.md
* cann: fix typo of CANN.md
* cann: add some comments and update the CANN.md
* cann: update the CANN.md
* cann: update the inner precise for fusedInferAttention
* cann: update the constraints of flash_attn_ext on ggml-cann.cpp
* cann: clean the whitespace
* cann: clean the whitespace
* cann: add a new endline
1 parent 1cd7028 commit e35fecc
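
For orientation, the node that ends up in this new kernel is the standard ggml flash-attention op. A minimal sketch, assuming the current ggml_flash_attn_ext API; the helper name, shapes, and parameter values below are illustrative and not part of this commit:

    #include <math.h>
    #include "ggml.h"

    // build_fa is a hypothetical helper: it creates the GGML_OP_FLASH_ATTN_EXT
    // node that a CANN backend dispatches to ggml_cann_flash_attn_ext().
    static struct ggml_tensor * build_fa(struct ggml_context * ctx,
                                         struct ggml_tensor * q,      // queries
                                         struct ggml_tensor * k,      // keys   (F16 for this backend)
                                         struct ggml_tensor * v,      // values (F16 for this backend)
                                         struct ggml_tensor * mask) { // optional mask / PSE shift
        const float scale         = 1.0f / sqrtf((float) q->ne[0]);   // 1/sqrt(head_dim)
        const float max_bias      = 0.0f;                             // ALiBi bias, 0 = disabled
        const float logit_softcap = 0.0f;                             // must stay 0: softcap is rejected by supports_op
        return ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);
    }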

File tree

8 files changed: +383, -0 lines changed


ggml/src/ggml-cann/CMakeLists.txt

100644 → 100755
File mode changed.

ggml/src/ggml-cann/Doxyfile

100644 → 100755
File mode changed.

ggml/src/ggml-cann/acl_tensor.cpp

100644 → 100755
Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ aclDataType ggml_cann_type_mapping(ggml_type type) {
             return ACL_FLOAT;
         case GGML_TYPE_F16:
             return ACL_FLOAT16;
+        case GGML_TYPE_BF16:
+            return ACL_BF16;
         case GGML_TYPE_I8:
             return ACL_INT8;
         case GGML_TYPE_I16:

ggml/src/ggml-cann/acl_tensor.h

100644 → 100755
File mode changed.

ggml/src/ggml-cann/aclnn_ops.cpp

100644 → 100755
Lines changed: 330 additions & 0 deletions
Large diffs are not rendered by default.

ggml/src/ggml-cann/aclnn_ops.h

100644 → 100755
Lines changed: 15 additions & 0 deletions
@@ -714,6 +714,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ * for computing scaled dot-product attention with hardware acceleration.
+ * The result is stored in the destination tensor `dst`.
+ *
+ * This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ *            dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
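
The kernel body itself lives in aclnn_ops.cpp (diff not rendered above). As a rough sketch of the ggml-side conventions such a kernel relies on, assuming the usual op_params layout (scale, max_bias, logit_softcap); this is not the actual CANN implementation and the helper name is hypothetical:

    #include <string.h>
    #include "ggml.h"

    // unpack_flash_attn_ext shows how a GGML_OP_FLASH_ATTN_EXT node exposes
    // its operands and scalar parameters to a backend kernel.
    static void unpack_flash_attn_ext(const struct ggml_tensor * dst) {
        const struct ggml_tensor * q    = dst->src[0]; // queries
        const struct ggml_tensor * k    = dst->src[1]; // keys
        const struct ggml_tensor * v    = dst->src[2]; // values
        const struct ggml_tensor * mask = dst->src[3]; // optional mask (PSE shift / ALiBi base)

        float scale, max_bias, logit_softcap;
        memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
        memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float)); // drives the ALiBi slopes
        memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); // must be 0.0f (see supports_op below)

        (void) q; (void) k; (void) v; (void) mask;
        (void) scale; (void) max_bias; (void) logit_softcap;
    }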

ggml/src/ggml-cann/common.h

100644 → 100755
File mode changed.

ggml/src/ggml-cann/ggml-cann.cpp

100644 → 100755
Lines changed: 36 additions & 0 deletions
@@ -36,6 +36,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-cann/aclnn_ops.h"
 #include "ggml-cann/common.h"
+#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
@@ -1748,6 +1749,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -2177,6 +2181,38 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // derived from [ggml-cuda.cu]
+            if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
+                return false;
+            }
+            if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
+                return false;
+            }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            float logitSoftcap = 0.0f;
+            memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
+            if(logitSoftcap != 0.0f) {
+                return false;
+            }
+            return true;
+        }
         default:
            return false;
     }
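
These checks are what a scheduler consults before offloading the node to the CANN device. A minimal probe, assuming the public ggml-backend API; the helper name is made up for illustration:

    #include "ggml-backend.h"

    // Returns whether the given device accepts a flash-attention node.
    // Per the checks above, the CANN backend rejects, for example, non-F16 K/V,
    // head sizes 192 and 576 (DeepSeek MLA), ne[3] != 1, and any non-zero logit softcap.
    static bool cann_accepts_fa(ggml_backend_dev_t dev, const struct ggml_tensor * fa_node) {
        return ggml_backend_dev_supports_op(dev, fa_node);
    }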
