
Commit 72df31d

cann: add the basic FA support
1 parent be1d4a1 commit 72df31d

File tree

3 files changed: +210 -0 lines changed


ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 160 additions & 0 deletions
@@ -2587,3 +2587,163 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
 }
+
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+    ggml_tensor* src0 = dst->src[0]; // q, fp32
+    ggml_tensor* src1 = dst->src[1]; // k, fp16
+    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+    size_t faElemSize = sizeof(uint16_t);
+
+    // Step 1: cast the src0 (Query) to fp16
+    aclTensor* acl_src0_f16_tensor = nullptr;
+
+    ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+    void* src0_f16_buffer = nullptr;
+
+    if(src0->type != GGML_TYPE_F16){
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+
+        src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize);
+        src0_f16_buffer = src0_f16_allocator.get();
+
+        int64_t* src0_f16_ne = src0->ne;
+        size_t src0_f16_nb[GGML_MAX_DIMS];
+        src0_f16_nb[0] = sizeof(uint16_t);
+        for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+        }
+
+        acl_src0_f16_tensor = ggml_cann_create_tensor(
+            src0_f16_buffer, ACL_FLOAT16, faElemSize,
+            src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+        );
+        aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, ACL_FLOAT16);
+        ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+    }else{
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+    }
+
+    // Step 2: generates mask with ACL_BOOL
+    size_t maskElemSize = sizeof(char);
+    ggml_cann_pool_alloc src3_bool_allocator(ctx.pool());
+    src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize);
+    void* src3_bool_buffer = src3_bool_allocator.get();
+
+    int64_t* src3_bool_ne = src3->ne;
+    size_t src3_bool_nb[GGML_MAX_DIMS];
+    src3_bool_nb[0] = maskElemSize;
+    for(int i = 1; i < GGML_MAX_DIMS; ++i){
+        src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1];
+    }
+
+    aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+    aclTensor* acl_mask_bool_tensor = ggml_cann_create_tensor(
+        src3_bool_buffer, ACL_BOOL, maskElemSize,
+        src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor);
+    ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+
+    // Step 3: generates the output tensor directly from FA kernel
+    ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+    out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+    void* out_f16_buffer = out_f16_allocator.get();
+
+    int64_t* out_f16_ne = src0->ne;
+    size_t out_f16_nb[GGML_MAX_DIMS];
+    out_f16_nb[0] = faElemSize;
+    for(int i = 1; i < GGML_MAX_DIMS; ++i){
+        out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+    }
+
+    aclTensor* acl_out_f16_tensor = ggml_cann_create_tensor(
+        out_f16_buffer, ACL_FLOAT16, faElemSize,
+        out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+    );
+
+    // Step 4: Performs the f16 Flash Attention kernel
+
+    int kvTensorNum = 1;
+    aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+    aclTensor* acl_k_tensors[] = {ggml_cann_create_tensor(src1)};
+    aclTensor* acl_v_tensors[] = {ggml_cann_create_tensor(src2)};
+    auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+    auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+    aclTensor* acl_out_tensor = acl_out_f16_tensor;
+
+
+    int64_t numHeads = src0->ne[2]; // N
+    int64_t numKeyValueHeads = src1->ne[2];
+    double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+    int64_t preTokens = 65535;
+    int64_t nextTokens = 65535;
+    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    int64_t sparseMode = 0;
+    int64_t innerPrecise = 1;
+    int64_t blockSize = 0;
+    int64_t antiquantMode = 0;
+    bool softmaxLseFlag = false;
+    int64_t keyAntiquantMode = 0;
+    int64_t valueAntiquantMode = 0;
+
+    // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+        acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+        nullptr, acl_mask_bool_tensor,                      // pse, mask
+        nullptr, nullptr,                                   // actSeqLen, actSeqLenkv
+        nullptr, nullptr,                                   // deqScale1, quantScale1
+        nullptr, nullptr, nullptr,                          // deqScale2, quantScale2, quantOffset2
+        nullptr, nullptr,                                   // antiquantScale, antiquantOffset
+        nullptr,                                            // blockTable
+        nullptr, nullptr,                                   // qPadSize, kvPadSize
+        nullptr, nullptr,                                   // kAntiquantScale, kAntiQuantOffset
+        nullptr, nullptr,                                   // vAntiquantScale, vAntiQuantOffset
+        nullptr, nullptr, nullptr,                          // kSharedPrefix, vSharedPrefix, actSharedLen
+        numHeads, scaleValue,                               // heads, scaleValue
+        preTokens, nextTokens,                              // preTokens, nextTokens
+        layout,                                             // inputLayout
+        numKeyValueHeads,                                   // numKVHeads
+        sparseMode, innerPrecise,                           // sparseMode, innerPrecise
+        blockSize, antiquantMode,                           // blockSize, antiquantMode
+        softmaxLseFlag,                                     // softmaxLseFlag
+        keyAntiquantMode, valueAntiquantMode,               // keyAntiqMode, valueAntiqMode
+        acl_out_tensor,                                     // attentionOut
+        nullptr                                             // softmaxLse
+    );
+
+    // Step 5: post-processing, permute and cast to f32
+    int64_t new_dim[] = {0, 2, 1, 3};
+    aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+    if(dst->type != GGML_TYPE_F16){
+        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+        int64_t* perm_out_f16_ne = dst->ne;
+        size_t perm_out_f16_nb[GGML_MAX_DIMS];
+        perm_out_f16_nb[0] = faElemSize;
+        for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+        }
+        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+            perm_out_f16_buffer, ACL_FLOAT16, faElemSize,
+            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+        aclnn_permute(ctx, acl_out_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+        aclnn_cast(ctx,
+            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+    }else{
+        // only need to permute
+        aclnn_permute(ctx, acl_out_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+    }
+
+    ggml_cann_release_resources(ctx, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list,
+                                acl_mask_bool_tensor, acl_out_f16_tensor,
+                                acl_dst_tensor);
+
+}
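
Note on the mask handling above: ggml passes an additive fp16 mask in src[3], where attended positions hold 0 and masked positions hold -INFINITY, while the fused kernel is given a boolean mask, so the IsNegInf step marks exactly the masked positions as true; the scale handed to the kernel is the usual 1/sqrt(d) with d = src0->ne[0]. The following standalone sketch (plain host-side C++, no ACL calls; the mask contents and head size are made-up for illustration) mirrors that conversion:

#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
    // Hypothetical additive attention mask for one query row over 4 key positions:
    // 0.0f means "attend", -inf means "masked" (how ggml encodes src[3]).
    std::vector<float> additive_mask = {
        0.0f, 0.0f,
        -std::numeric_limits<float>::infinity(),
        -std::numeric_limits<float>::infinity(),
    };

    // Host-side equivalent of the IsNegInf step: masked positions become true.
    std::vector<bool> bool_mask(additive_mask.size());
    for (size_t i = 0; i < additive_mask.size(); ++i) {
        bool_mask[i] = std::isinf(additive_mask[i]) && additive_mask[i] < 0.0f;
    }

    // Scale passed to the kernel: 1/sqrt(d), with d = src0->ne[0] (head size).
    const double head_dim = 128.0; // example value only
    const double scale    = 1.0 / std::sqrt(head_dim);

    for (size_t i = 0; i < bool_mask.size(); ++i) {
        std::printf("mask[%zu] = %s\n", i, bool_mask[i] ? "masked" : "attend");
    }
    std::printf("scaleValue = %.6f\n", scale);
    return 0;
}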

ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 17 additions & 0 deletions
@@ -45,6 +45,8 @@
 #include <aclnnop/aclnn_cos.h>
 #include <aclnnop/aclnn_log.h>
 #include <aclnnop/aclnn_sign.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
+#include <aclnnop/aclnn_isneginf.h>
 #include "acl_tensor.h"
 #include "common.h"
 

@@ -714,6 +716,21 @@ void ggml_cann_count_equal(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief Performs the Flash Attention extended operator using the CANN backend.
+ *
+ * @details This function implements the memory-efficient Flash Attention algorithm
+ * for computing scaled dot-product attention with hardware acceleration.
+ * The result is stored in the destination tensor `dst`.
+ *
+ * This operation is accelerated using the CANN backend to improve runtime performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored.
+ * dst->op is expected to be `GGML_OP_FLASH_ATTN_EXT`.
+ */
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 /*
  * @brief A generic wrapper for ACL resources with custom deleter support.
  */
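
For context, the operator declared here computes standard scaled dot-product attention with an additive mask $M$ (which the CANN path converts to a boolean mask before invoking the fused kernel):

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V

where $d_k$ is the head size (src0->ne[0] in the implementation above).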

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 33 additions & 0 deletions
@@ -1747,6 +1747,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_COUNT_EQUAL:
             ggml_cann_count_equal(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cann_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }

@@ -2161,6 +2164,36 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
             return true;
+        case GGML_OP_FLASH_ATTN_EXT:{
+            // copy from [ggml-cuda.cu]
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
+            if (op->src[0]->ne[0] == 192) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 576) {
+                // DeepSeek MLA
+                return false;
+            }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
+            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 128) {
+                return true;
+            }
+            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
+                return true;
+            }
+            return op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+        }
         default:
             return false;
     }
