@@ -3093,104 +3093,24 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
30933093 // Compute the slope if needed. Derived from ggml_cann_softmax().
30943094 if (maxBias != 0 .0f ){
30953095 // alibi
3096- const int64_t ne2_ne3 = src0->ne [2 ] * src0->ne [3 ];
3097- const int64_t n_head = src0->ne [2 ];
3098- const int n_head_log2 = 1u << (uint32_t )floor (log2 (n_head));
3099- float m0 = powf (2 .0f , -(maxBias) / n_head_log2);
3100- float m1 = powf (2 .0f , -(maxBias / 2 .0f ) / n_head_log2);
3101- // init arange
3102- ggml_cann_pool_alloc arange_allocator (ctx.pool (),
3103- ne2_ne3 * faElemSize);
3104- void * tmp_arange_buffer = arange_allocator.get ();
3105-
3106- // arange1: [1, ..., n_head_log2+1)
3107- float start = 1 ;
3108- float stop = n_head_log2 + 1 ;
3109- float step = 1 ;
3110- int64_t n_elements_arange = n_head_log2;
3111-
3112- int64_t tmp_arange1_ne[] = {n_head_log2};
3113- size_t tmp_arange1_nb[] = {faElemSize};
3114- aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor (
3115- tmp_arange_buffer, faDataType, faElemSize,
3116- tmp_arange1_ne, tmp_arange1_nb,
3117- GGML_MAX_DIMS - 3 , ACL_FORMAT_ND);
3118-
3119- aclnn_arange (ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
3120-
3121- aclTensor* tmp_arange2_tensor = nullptr ;
3122- if (n_head_log2 < ne2_ne3) {
3123- // arange2: [1, ..., 2 * (k - n_head_log2) + 1)
3124- start = 1 ;
3125- stop = 2 * (ne2_ne3 - n_head_log2) + 1 ;
3126- step = 2 ;
3127- n_elements_arange = ne2_ne3 - n_head_log2;
3128- int64_t tmp_arange2_ne[] = {ne2_ne3 - n_head_log2};
3129- size_t tmp_arange2_nb[] = {faElemSize};
3130-
3131- aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor (
3132- (char *)tmp_arange_buffer +
3133- n_head_log2 * faElemSize,
3134- faDataType, faElemSize,
3135- tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3 , ACL_FORMAT_ND);
3136- aclnn_arange (ctx, tmp_arange2_tensor, start, stop, step,
3137- n_elements_arange);
3096+ const int64_t n_heads = src0->ne [2 ];
3097+ ggml_cann_pool_alloc slope_allocator (ctx.pool (), n_heads * sizeof (float ));
3098+ void * slope_buffer = slope_allocator.get ();
3099+ aclnn_get_slope (ctx, n_heads, slope_buffer, maxBias);
3100+
3101+ int64_t slope_ne[] = {1 , 1 , n_heads, 1 };
3102+ size_t slope_nb[GGML_MAX_DIMS];
3103+ slope_nb[0 ] = sizeof (float );
3104+ for (int i = 1 ;i<GGML_MAX_DIMS;i++) {
3105+ slope_nb[i] = slope_nb[i-1 ] * slope_ne[0 ];
31383106 }
31393107
3140- // init mk_base
3141- ggml_cann_pool_alloc mk_base_allocator (ctx.pool (),
3142- ne2_ne3 * faElemSize);
3143- void * tmp_mk_base_buffer = mk_base_allocator.get ();
3144- int64_t tmp_mk_base1_ne[] = {n_head_log2};
3145- size_t tmp_mk_base1_nb[] = {faElemSize};
3146- aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor (
3147- tmp_mk_base_buffer, faDataType, faElemSize,
3148- tmp_mk_base1_ne, tmp_mk_base1_nb,
3149- GGML_MAX_DIMS - 3 , ACL_FORMAT_ND);
3150-
3151- aclnn_fill_scalar (ctx, m0, tmp_mk_base1_tensor);
3152-
3153- aclTensor* tmp_mk_base2_tensor = nullptr ;
3154- if (n_head_log2 < ne2_ne3) {
3155- int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_head_log2};
3156- size_t tmp_mk_base2_nb[] = {faElemSize};
3157- aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor (
3158- (char *)tmp_mk_base_buffer +
3159- n_head_log2 * faElemSize,
3160- faDataType, faElemSize,
3161- tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3 , ACL_FORMAT_ND);
3162- aclnn_fill_scalar (ctx, m1, tmp_mk_base2_tensor);
3163- }
3108+ aclTensor* slope_tensor = ggml_cann_create_tensor (
3109+ slope_buffer, ACL_FLOAT, sizeof (float ),
3110+ slope_ne, slope_nb, GGML_MAX_DIMS);
3111+ GGML_CANN_CALL_ACLNN_OP (ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
31643112
3165- // init mk
3166- int64_t tmp_mk_base_ne[] = {ne2_ne3};
3167- size_t tmp_mk_base_nb[] = {faElemSize};
3168- aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor (
3169- tmp_mk_base_buffer, faDataType, faElemSize,
3170- tmp_mk_base_ne, tmp_mk_base_nb,
3171- GGML_MAX_DIMS - 3 , ACL_FORMAT_ND);
3172- aclTensor* tmp_arange_tensor = ggml_cann_create_tensor (
3173- tmp_arange_buffer, faDataType, faElemSize,
3174- tmp_mk_base_ne, tmp_mk_base_nb,
3175- GGML_MAX_DIMS - 3 , ACL_FORMAT_ND);
3176- aclnn_pow_tensor_tensor (ctx, tmp_mk_base_tensor, tmp_arange_tensor);
3177-
3178- // reshape mk
3179- int64_t tmp_mk_ne[] = {1 , 1 , src0->ne [2 ], src0->ne [3 ]};
3180- size_t tmp_mk_nb[GGML_MAX_DIMS];
3181- tmp_mk_nb[0 ] = faElemSize;
3182- for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
3183- tmp_mk_nb[i] = tmp_mk_nb[i - 1 ] * tmp_mk_ne[i - 1 ];
3184- }
3185- aclTensor* tmp_mk_tensor = ggml_cann_create_tensor (
3186- tmp_mk_base_buffer, faDataType, faElemSize,
3187- tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
3188- ACL_FORMAT_ND);
3189- GGML_CANN_CALL_ACLNN_OP (ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
3190-
3191- ggml_cann_release_resources (ctx, tmp_arange1_tensor, tmp_arange2_tensor,
3192- tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
3193- tmp_arange_tensor, tmp_mk_tensor);
3113+ ggml_cann_release_resources (ctx, slope_tensor);
31943114 }
31953115 }
31963116
0 commit comments