@@ -2690,10 +2690,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
     memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
 
-    // if(logitSoftcap != 0.0f){
-    //     // call the non-fa implementation
-    // }else{
-
     size_t faElemSize = sizeof(uint16_t);
     auto faDataType = ACL_FLOAT16; // ACL_BF16;
 
@@ -2825,6 +2821,123 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 #endif
             ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
         }
+
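+        // a non-zero maxBias means the model uses ALiBi: each head carries a
+        // fixed slope that scales its positional bias, so the slopes are
+        // folded into bcast_pse_tensor before flash attention runs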
+        if (maxBias != 0.0f) {
+            // alibi
+            const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+            const int64_t n_head = src0->ne[2];
+            const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+            float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+            float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
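+            // standard ALiBi slopes (Press et al.): the first n_heads_log2_floor
+            // heads use m0^1 .. m0^n_heads_log2_floor, the remaining heads use
+            // the odd powers m1^1, m1^3, m1^5, ...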
+            // init arange
+            ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                                  ne2_ne3 * faElemSize);
+            void* tmp_arange_buffer = arange_allocator.get();
+
+            // arange1: [1, ..., n_heads_log2_floor+1)
+            float start = 1;
+            float stop = n_heads_log2_floor + 1;
+            float step = 1;
+            int64_t n_elements_arange = n_heads_log2_floor;
+
+            int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+            size_t tmp_arange1_nb[] = {faElemSize};
+            aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+                tmp_arange_buffer, faDataType, faElemSize,
+                tmp_arange1_ne, tmp_arange1_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+            aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
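+            // tmp_arange_buffer now holds the exponents 1..n_heads_log2_floor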
+
+            aclTensor* tmp_arange2_tensor = nullptr;
+            if (n_heads_log2_floor < ne2_ne3) {
+                // arange2: [1, ..., 2 * (ne2_ne3 - n_heads_log2_floor) + 1)
+                start = 1;
+                stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+                step = 2;
+                n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+                int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                size_t tmp_arange2_nb[] = {faElemSize};
+
+                // assign the outer pointer instead of shadowing it with a new
+                // declaration, so the tensor is released at the end of the block
+                tmp_arange2_tensor = ggml_cann_create_tensor(
+                    (char*)tmp_arange_buffer + n_heads_log2_floor * faElemSize,
+                    faDataType, faElemSize,
+                    tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                             n_elements_arange);
+            }
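+            // together the two ranges give one exponent per head:
+            // [1, 2, ..., n_heads_log2_floor, 1, 3, 5, ...]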
+
+            // init mk_base
+            ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                                   ne2_ne3 * faElemSize);
+            void* tmp_mk_base_buffer = mk_base_allocator.get();
+            int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+            size_t tmp_mk_base1_nb[] = {faElemSize};
+            aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+                tmp_mk_base_buffer, faDataType, faElemSize,
+                tmp_mk_base1_ne, tmp_mk_base1_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+            aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+            aclTensor* tmp_mk_base2_tensor = nullptr;
+            if (n_heads_log2_floor < ne2_ne3) {
+                int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                size_t tmp_mk_base2_nb[] = {faElemSize};
+                // same fix as arange2: assign rather than re-declare
+                tmp_mk_base2_tensor = ggml_cann_create_tensor(
+                    (char*)tmp_mk_base_buffer + n_heads_log2_floor * faElemSize,
+                    faDataType, faElemSize,
+                    tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+            }
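+            // tmp_mk_base_buffer now holds one base per head: m0 for the
+            // first n_heads_log2_floor heads and m1 for the remainder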
+
+            // init mk
+            int64_t tmp_mk_base_ne[] = {ne2_ne3};
+            size_t tmp_mk_base_nb[] = {faElemSize};
+            aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+                tmp_mk_base_buffer, faDataType, faElemSize,
+                tmp_mk_base_ne, tmp_mk_base_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+            aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+                tmp_arange_buffer, faDataType, faElemSize,
+                tmp_mk_base_ne, tmp_mk_base_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
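+            // in-place pow: mk[h] = base[h] ^ exponent[h] gives the slope of head h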
+            aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+            // reshape mk
+            int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+            size_t tmp_mk_nb[GGML_MAX_DIMS];
+            tmp_mk_nb[0] = faElemSize;
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+            }
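+            // view the slopes as [1, 1, ne2, ne3] so the in-place multiply
+            // below broadcasts one slope across each head's bias plane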
+            aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+                tmp_mk_base_buffer, faDataType, faElemSize,
+                tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+                ACL_FORMAT_ND);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+            ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+                tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+                tmp_arange_tensor, tmp_mk_tensor);
+        }
     }
 
 #ifdef DEBUG
@@ -2931,4 +3029,4 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor);
     if (src3)
         ggml_cann_release_resources(ctx, bcast_pse_tensor);
-}
+}