@@ -2587,3 +2587,163 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     ggml_cann_release_resources(ctx, acl_src, acl_dst, alpha);
 }
+
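+// Flash-attention for the CANN backend (GGML_OP_FLASH_ATTN_EXT): cast the query to
+// fp16, convert the additive fp16 mask into a boolean mask, run the fused
+// FusedInferAttentionScoreV2 kernel, then permute and cast the result back into dst.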
+void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+
+    ggml_tensor* src0 = dst->src[0]; // q, fp32
+    ggml_tensor* src1 = dst->src[1]; // k, fp16
+    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src3 = dst->src[3]; // mask, fp16
+
+    size_t faElemSize = sizeof(uint16_t);
+
+    // Step 1: cast the src0 (Query) to fp16
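+    // The fused kernel is run in fp16 here: k and v already arrive as fp16 (see the
+    // comments above), so only a non-fp16 query needs an explicit cast into a pooled
+    // fp16 buffer.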
+    aclTensor* acl_src0_f16_tensor = nullptr;
+
+    ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
+    void* src0_f16_buffer = nullptr;
+
+    if (src0->type != GGML_TYPE_F16){
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+
+        src0_f16_allocator.alloc(ggml_nelements(src0) * faElemSize);
+        src0_f16_buffer = src0_f16_allocator.get();
+
+        int64_t* src0_f16_ne = src0->ne;
+        size_t src0_f16_nb[GGML_MAX_DIMS];
+        src0_f16_nb[0] = sizeof(uint16_t);
+        for (int i = 1; i < GGML_MAX_DIMS; ++i){
+            src0_f16_nb[i] = src0_f16_nb[i - 1] * src0_f16_ne[i - 1];
+        }
+
+        acl_src0_f16_tensor = ggml_cann_create_tensor(
+            src0_f16_buffer, ACL_FLOAT16, faElemSize,
+            src0_f16_ne, src0_f16_nb, GGML_MAX_DIMS
+        );
+        aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, ACL_FLOAT16);
+        ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
+    } else {
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+    }
+
+    // Step 2: generate the mask with ACL_BOOL
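+    // The ggml mask is an additive fp16 tensor whose masked positions are -INFINITY.
+    // IsNegInf turns it into a boolean tensor (true where the entry is -inf), which is
+    // what the boolean attenMask argument of the fused kernel consumes below.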
+    size_t maskElemSize = sizeof(char);
+    ggml_cann_pool_alloc src3_bool_allocator(ctx.pool());
+    src3_bool_allocator.alloc(ggml_nelements(src3) * maskElemSize);
+    void* src3_bool_buffer = src3_bool_allocator.get();
+
+    int64_t* src3_bool_ne = src3->ne;
+    size_t src3_bool_nb[GGML_MAX_DIMS];
+    src3_bool_nb[0] = maskElemSize;
+    for (int i = 1; i < GGML_MAX_DIMS; ++i){
+        src3_bool_nb[i] = src3_bool_nb[i - 1] * src3_bool_ne[i - 1];
+    }
+
+    aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
+    aclTensor* acl_mask_bool_tensor = ggml_cann_create_tensor(
+        src3_bool_buffer, ACL_BOOL, maskElemSize,
+        src3_bool_ne, src3_bool_nb, GGML_MAX_DIMS);
+
+    GGML_CANN_CALL_ACLNN_OP(ctx, IsNegInf, acl_mask_f16_tensor, acl_mask_bool_tensor);
+    ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
+
+    // Step 3: allocate the output tensor that the FA kernel writes into directly
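+    // The output buffer reuses src0->ne, i.e. the fused kernel's result is laid out
+    // like the query; it is produced as fp16 here and only converted to dst's layout
+    // and type in step 5.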
+    ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
+    out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+    void* out_f16_buffer = out_f16_allocator.get();
+
+    int64_t* out_f16_ne = src0->ne;
+    size_t out_f16_nb[GGML_MAX_DIMS];
+    out_f16_nb[0] = faElemSize;
+    for (int i = 1; i < GGML_MAX_DIMS; ++i){
+        out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1];
+    }
+
+    aclTensor* acl_out_f16_tensor = ggml_cann_create_tensor(
+        out_f16_buffer, ACL_FLOAT16, faElemSize,
+        out_f16_ne, out_f16_nb, GGML_MAX_DIMS
+    );
+
+    // Step 4: run the fp16 Flash Attention kernel
+
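+    // FusedInferAttentionScoreV2 takes key/value as tensor lists; single-entry lists
+    // are used here since there is exactly one k and one v tensor.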
+    int kvTensorNum = 1;
+    aclTensor* acl_q_tensor = acl_src0_f16_tensor;
+    aclTensor* acl_k_tensors[] = {ggml_cann_create_tensor(src1)};
+    aclTensor* acl_v_tensors[] = {ggml_cann_create_tensor(src2)};
+    auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+    auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+    aclTensor* acl_out_tensor = acl_out_f16_tensor;
+
+
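+    // ggml stores ne[] fastest-dimension first (D, S, N, B); after the dimension
+    // reversal applied when building the ACL tensors this presumably corresponds to
+    // the "BNSD" layout passed below, which is why numHeads comes from ne[2] and the
+    // head size used for scaling comes from ne[0].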
+    int64_t numHeads = src0->ne[2]; // N
+    int64_t numKeyValueHeads = src1->ne[2];
+    double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
+    int64_t preTokens = 65535;
+    int64_t nextTokens = 65535;
+    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    int64_t sparseMode = 0;
+    int64_t innerPrecise = 1;
+    int64_t blockSize = 0;
+    int64_t antiquantMode = 0;
+    bool softmaxLseFlag = false;
+    int64_t keyAntiquantMode = 0;
+    int64_t valueAntiquantMode = 0;
+
+    // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
+
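+    // All optional quantization, paged-attention (blockTable) and shared-prefix inputs
+    // are passed as nullptr; only q, k, v, the boolean mask and the scalar attributes
+    // above are used.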
+    GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
+        acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
+        nullptr, acl_mask_bool_tensor,                      // pse, mask
+        nullptr, nullptr,                                   // actSeqLen, actSeqLenkv
+        nullptr, nullptr,                                   // deqScale1, quantScale1
+        nullptr, nullptr, nullptr,                          // deqScale2, quantScale2, quantOffset2
+        nullptr, nullptr,                                   // antiquantScale, antiquantOffset
+        nullptr,                                            // blockTable
+        nullptr, nullptr,                                   // qPadSize, kvPadSize
+        nullptr, nullptr,                                   // kAntiquantScale, kAntiQuantOffset
+        nullptr, nullptr,                                   // vAntiquantScale, vAntiQuantOffset
+        nullptr, nullptr, nullptr,                          // kSharedPrefix, vSharedPrefix, actSharedLen
+        numHeads, scaleValue,                               // heads, scaleValue
+        preTokens, nextTokens,                              // preTokens, nextTokens
+        layout,                                             // inputLayout
+        numKeyValueHeads,                                   // numKVHeads
+        sparseMode, innerPrecise,                           // sparseMode, innerPrecise
+        blockSize, antiquantMode,                           // blockSize, antiquantMode
+        softmaxLseFlag,                                     // softmaxLseFlag
+        keyAntiquantMode, valueAntiquantMode,               // keyAntiqMode, valueAntiqMode
+        acl_out_tensor,                                     // attentionOut
+        nullptr                                             // softmaxLse
+    );
+
+    // Step 5: post-processing, permute and cast to f32
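+    // The FA kernel wrote its result with the query's BNSD layout; dst stores the head
+    // and sequence dimensions in the opposite order, so the output is permuted with
+    // {0, 2, 1, 3} and, when dst is not fp16, cast back to dst's type.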
+    int64_t new_dim[] = {0, 2, 1, 3};
+    aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+
+    if (dst->type != GGML_TYPE_F16){
+        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
+        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
+        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
+
+        int64_t* perm_out_f16_ne = dst->ne;
+        size_t perm_out_f16_nb[GGML_MAX_DIMS];
+        perm_out_f16_nb[0] = faElemSize;
+        for (int i = 1; i < GGML_MAX_DIMS; ++i){
+            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
+        }
+        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
+            perm_out_f16_buffer, ACL_FLOAT16, faElemSize,
+            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
+        aclnn_permute(ctx, acl_out_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
+        aclnn_cast(ctx,
+            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
+        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
+    } else {
+        // only need to permute
+        aclnn_permute(ctx, acl_out_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
+    }
+
+    ggml_cann_release_resources(ctx, acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list,
+                                acl_mask_bool_tensor, acl_out_f16_tensor,
+                                acl_dst_tensor);
+
+}