 #include <aclnnop/aclnn_eq_tensor.h>
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>
 
 #include <cmath>
 #include <cstring>
 #include <exception>
 #include <vector>
 
-#include <iostream>
-#include <fstream>
-#include <string>
-#include <cstring>
-
-#include "aclnnop/aclnn_flash_attention_score.h"
-#include "aclnnop/aclnn_logical_not.h"
-
-#include "ggml-cann/acl_tensor.h"
 #include "ggml-impl.h"
-#include "ggml.h"
 
 #define GGML_COMMON_DECL_C
 
@@ -2623,7 +2614,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_src3_f16_tensor = nullptr;
         aclTensor* acl_dst_f16_tensor = nullptr;
 
-        // Step 1: cast the src0 (Query) to fp16
+        // Step 1: cast the src0 (Query) to fp16 if needed
         ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
         void* src0_f16_buffer = nullptr;
 
@@ -2649,6 +2640,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
         }
 
+        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+        // and the direct output from FusedInferAttention
+
         acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
         acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
 
@@ -2668,21 +2662,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             out_f16_ne, out_f16_nb, GGML_MAX_DIMS
         );
 
-        aclTensor* bcast_pse_tensor = nullptr;
 
+        // Step 3: create the PSEShift tensor if needed
+        // this tensor is treated as the mask (f16) in llama.cpp
+
+        aclTensor* bcast_pse_tensor = nullptr;
         int64_t bcast_pse_ne[GGML_MAX_DIMS];
         size_t bcast_pse_nb[GGML_MAX_DIMS];
         ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
         void* bcast_pse_buffer = nullptr;
-        if (src3)
+
+        if (src3 != nullptr) {
             bcast_pse_buffer = bcast_pse_allocator.alloc(
                 ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-        if (src3 != nullptr) {
-            // broadcast pse
+
             if (src0->ne[1] > 1) {
+                // Case 1: broadcast the pse for the prefill stage with multiple heads
                 aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-
                 bcast_pse_ne[0] = src3->ne[0];
                 bcast_pse_ne[1] = src3->ne[1];
                 bcast_pse_ne[2] = src0->ne[2];
@@ -2702,6 +2698,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
                 ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
             } else {
+                // Case 2: truncate the first row and broadcast the pse for the decode stage with multiple heads
                 int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
                 size_t* trunc_pse_nb = src3->nb;
 
@@ -2729,6 +2726,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
                 ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
             }
 
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
             if (maxBias != 0.0f) {
                 // alibi
                 const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
@@ -2832,6 +2830,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             }
         }
 
+        // Step 4: set the inputs for FusedInferAttention.
         int kvTensorNum = 1;
         aclTensor* acl_q_tensor = acl_src0_f16_tensor;
         aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
@@ -2853,9 +2852,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         int64_t keyAntiquantMode = 0;
         int64_t valueAntiquantMode = 0;
 
+        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
         // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
-
-
+
         GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
             acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
             bcast_pse_tensor, nullptr, // pse, mask
@@ -2880,7 +2879,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             nullptr // softmaxLse
         );
 
-        // Step 5: post-processing, permute and cast to f32
+        // Step 6: post-processing, permute and cast to f32
+
         int64_t new_dim[] = {0, 2, 1, 3};
         aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
 
@@ -2911,9 +2911,11 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             acl_src2_f16_tensor,
             acl_dst_f16_tensor,
             acl_dst_tensor);
-        if (src3)
+        if (src3 != nullptr) {
             ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
     } else {
-        throw std::runtime_error("Function not implemented");
+        GGML_ABORT("Function not implemented");
     }
 }
+
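For reference, the per-head slope that the "Compute the slope if needed. Derived from ggml_cann_softmax()." step applies to the broadcast PSE follows ggml's usual ALiBi convention. A minimal sketch of that computation, assuming the same n_head / max_bias parameters that ggml's softmax path uses (the helper name alibi_slope is only illustrative and not part of the patch):

#include <cmath>
#include <cstdint>

// ALiBi slope for head h of n_head heads, matching ggml's scheme:
// heads below the nearest power of two use base m0, the remaining
// heads use m1 with odd exponents; max_bias == 0 means no bias.
static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return h < n_head_log2 ? powf(m0, h + 1)
                           : powf(m1, 2 * (h - n_head_log2) + 1);
}

The patch appears to scale each head's slice of the broadcast PSEShift tensor by this slope before the tensor is handed to FusedInferAttentionScoreV2, which is why the alibi branch above only runs when maxBias != 0.0f.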