@@ -1765,35 +1765,31 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
17651765 ggml_tensor* src0 = dst->src [0 ]; // src
17661766 ggml_tensor* src1 = dst->src [1 ]; // index
17671767
1768- switch (src0->type ) {
1769- case GGML_TYPE_F32: {
1768+ if (src0-> type == dst ->type ) {
1769+ GGML_ASSERT (src0-> type == GGML_TYPE_F32 || src0-> type == GGML_TYPE_F16);
17701770 aclnn_index_select_4d (ctx, src0->data , src0->ne , src0->nb ,
17711771 dst->data , dst->ne , dst->nb ,
17721772 src1, dst->type );
1773- break ;
1774- }
1775- case GGML_TYPE_F16: {
1773+ } else if (src0->type == GGML_TYPE_F16) {
17761774 aclTensor* acl_src0 = ggml_cann_create_tensor (src0);
17771775 ggml_cann_pool_alloc src_buffer_allocator (
1778- ctx.pool (), ggml_nelements (src0) * sizeof ( float ));
1776+ ctx.pool (), ggml_nelements (src0) * ggml_element_size (dst ));
17791777 void * src_trans_buffer = src_buffer_allocator.get ();
17801778 size_t src_trans_nb[GGML_MAX_DIMS];
1781- src_trans_nb[0 ] = sizeof ( float ) ;
1779+ src_trans_nb[0 ] = dst-> nb [ 0 ] ;
17821780 for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
17831781 src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
17841782 }
17851783 aclTensor* src_trans_tensor = ggml_cann_create_tensor (
1786- src_trans_buffer, ACL_FLOAT , ggml_type_size (dst->type ),
1784+ src_trans_buffer, ggml_cann_type_mapping (dst-> type ) , ggml_type_size (dst->type ),
17871785 src0->ne , src_trans_nb, GGML_MAX_DIMS);
17881786 aclnn_cast (ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping (dst->type ));
17891787 aclnn_index_select_4d (ctx, src_trans_buffer, src0->ne , src_trans_nb,
17901788 dst->data , dst->ne , dst->nb ,
17911789 src1, dst->type );
17921790 ggml_cann_release_resources (ctx, acl_src0, src_trans_tensor);
1793- break ;
1794- }
1795- case GGML_TYPE_Q8_0: {
1796- // add 1 dim for bcast mul.
1791+ } else if (src0->type == GGML_TYPE_Q8_0){
1792+ // add 1 dim for bcast mul.
17971793 size_t weight_nb[GGML_MAX_DIMS + 1 ], scale_nb[GGML_MAX_DIMS + 1 ],
17981794 dequant_nb[GGML_MAX_DIMS + 1 ];
17991795 int64_t weight_ne[GGML_MAX_DIMS + 1 ], scale_ne[GGML_MAX_DIMS + 1 ],
@@ -1854,11 +1850,8 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
18541850 src1, dst->type );
18551851
18561852 ggml_cann_release_resources (ctx, dequant_tensor);
1857- break ;
1858- }
1859- default :
1853+ } else {
18601854 GGML_ABORT (" Unsupported tensor type for GGML_OP_GET_ROWS" );
1861- break ;
18621855 }
18631856}
18641857
@@ -3178,7 +3171,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
31783171 aclTensor* acl_src0_f16_tensor = nullptr ;
31793172 aclTensor* acl_src1_f16_tensor = nullptr ;
31803173 aclTensor* acl_src2_f16_tensor = nullptr ;
3181- aclTensor* acl_dst_f16_tensor = nullptr ;
31823174
31833175 // Step 1: cast the src0 (Query) to fp16 if needed
31843176 ggml_cann_pool_alloc src0_f16_allocator (ctx.pool ());
@@ -3216,22 +3208,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
32163208 acl_src2_f16_tensor = ggml_cann_create_tensor (src2, src2_bsnd_ne,
32173209 src2_bsnd_nb, GGML_MAX_DIMS);
32183210
3219- ggml_cann_pool_alloc out_f16_allocator (ctx.pool ());
3220- void * out_f16_buffer = out_f16_allocator.alloc (
3221- ggml_nelements (dst) * faElemSize);
3222-
3223- int64_t * out_f16_ne = src0_bsnd_ne;
3224- size_t out_f16_nb[GGML_MAX_DIMS];
3225- out_f16_nb[0 ] = faElemSize;
3226- for (int i = 1 ; i < GGML_MAX_DIMS; ++i){
3227- out_f16_nb[i] = out_f16_nb[i - 1 ] * out_f16_ne[i - 1 ];
3228- }
3229-
3230- acl_dst_f16_tensor = ggml_cann_create_tensor (
3231- out_f16_buffer, faDataType, faElemSize,
3232- out_f16_ne, out_f16_nb, GGML_MAX_DIMS
3233- );
3234-
32353211 // Step 3: create the PSEShift tensor if needed
32363212 // this tensor is considered as mask (f16) in the llama.cpp
32373213 aclTensor* bcast_pse_tensor = nullptr ;
@@ -3336,40 +3312,88 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
33363312
33373313 // Step 5: launch the FusedInferAttentionScoreV2 kernel.
33383314 // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
3315+ if (dst->type == GGML_TYPE_F16) {
3316+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
3317+
3318+ GGML_CANN_CALL_ACLNN_OP (ctx, FusedInferAttentionScoreV2,
3319+ acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
3320+ bcast_pse_tensor, nullptr , // pse, mask
3321+ nullptr , nullptr , // actSeqLen, actSeqLenkv
3322+ nullptr , nullptr , // deqScale1, quantScale1
3323+ nullptr , nullptr , nullptr , // deqScale2, quantScale2, quantOffset2
3324+ nullptr , nullptr , // antiquantScale, antiquantOffset
3325+ nullptr , // blockTable
3326+ nullptr , nullptr , // qPadSize, kvPadSize
3327+ nullptr , nullptr , // kAntiquantScale, kAntiQuantOffset
3328+ nullptr , nullptr , // vAntiquantScale, vAntiQuantOffset
3329+ nullptr , nullptr , nullptr , // kSharedPrefix, vSharedPrefix, actSharedLen
3330+ numHeads, scaleValue, // heads, scaleValue
3331+ preTokens, nextTokens, // preTokens, nextTokens
3332+ layout, // inputLayout
3333+ numKeyValueHeads, // numKVHeads
3334+ sparseMode, innerPrecise, // sparseMode, innerPrecise
3335+ blockSize, antiquantMode, // blockSize, antiquantMode
3336+ softmaxLseFlag, // softmaxLseFlag
3337+ keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
3338+ acl_dst_tensor, // attentionOut
3339+ nullptr // softmaxLse
3340+ );
3341+
3342+ ggml_cann_release_resources (ctx, acl_src0_f16_tensor,
3343+ acl_src1_f16_tensor,
3344+ acl_src2_f16_tensor,
3345+ acl_dst_tensor);
3346+ } else {
3347+ aclTensor* acl_dst_f16_tensor = nullptr ;
3348+ ggml_cann_pool_alloc out_f16_allocator (ctx.pool ());
3349+ void * out_f16_buffer = out_f16_allocator.alloc (
3350+ ggml_nelements (dst) * faElemSize);
3351+
3352+ int64_t * out_f16_ne = src0_bsnd_ne;
3353+ size_t out_f16_nb[GGML_MAX_DIMS];
3354+ out_f16_nb[0 ] = faElemSize;
3355+ for (int i = 1 ; i < GGML_MAX_DIMS; ++i){
3356+ out_f16_nb[i] = out_f16_nb[i - 1 ] * out_f16_ne[i - 1 ];
3357+ }
3358+
3359+ acl_dst_f16_tensor = ggml_cann_create_tensor (
3360+ out_f16_buffer, faDataType, faElemSize,
3361+ out_f16_ne, out_f16_nb, GGML_MAX_DIMS
3362+ );
3363+ GGML_CANN_CALL_ACLNN_OP (ctx, FusedInferAttentionScoreV2,
3364+ acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
3365+ bcast_pse_tensor, nullptr , // pse, mask
3366+ nullptr , nullptr , // actSeqLen, actSeqLenkv
3367+ nullptr , nullptr , // deqScale1, quantScale1
3368+ nullptr , nullptr , nullptr , // deqScale2, quantScale2, quantOffset2
3369+ nullptr , nullptr , // antiquantScale, antiquantOffset
3370+ nullptr , // blockTable
3371+ nullptr , nullptr , // qPadSize, kvPadSize
3372+ nullptr , nullptr , // kAntiquantScale, kAntiQuantOffset
3373+ nullptr , nullptr , // vAntiquantScale, vAntiQuantOffset
3374+ nullptr , nullptr , nullptr , // kSharedPrefix, vSharedPrefix, actSharedLen
3375+ numHeads, scaleValue, // heads, scaleValue
3376+ preTokens, nextTokens, // preTokens, nextTokens
3377+ layout, // inputLayout
3378+ numKeyValueHeads, // numKVHeads
3379+ sparseMode, innerPrecise, // sparseMode, innerPrecise
3380+ blockSize, antiquantMode, // blockSize, antiquantMode
3381+ softmaxLseFlag, // softmaxLseFlag
3382+ keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
3383+ acl_dst_f16_tensor, // attentionOut
3384+ nullptr // softmaxLse
3385+ );
3386+ // Step 6: post-processing, permute and cast to f32
3387+ aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
3388+ // TODO: when dst is fp16, don't need cast
3389+ aclnn_cast (ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping (dst->type ));
3390+ ggml_cann_release_resources (ctx, acl_src0_f16_tensor,
3391+ acl_src1_f16_tensor,
3392+ acl_src2_f16_tensor,
3393+ acl_dst_f16_tensor,
3394+ acl_dst_tensor);
3395+ }
33393396
3340- GGML_CANN_CALL_ACLNN_OP (ctx, FusedInferAttentionScoreV2,
3341- acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
3342- bcast_pse_tensor, nullptr , // pse, mask
3343- nullptr , nullptr , // actSeqLen, actSeqLenkv
3344- nullptr , nullptr , // deqScale1, quantScale1
3345- nullptr , nullptr , nullptr , // deqScale2, quantScale2, quantOffset2
3346- nullptr , nullptr , // antiquantScale, antiquantOffset
3347- nullptr , // blockTable
3348- nullptr , nullptr , // qPadSize, kvPadSize
3349- nullptr , nullptr , // kAntiquantScale, kAntiQuantOffset
3350- nullptr , nullptr , // vAntiquantScale, vAntiQuantOffset
3351- nullptr , nullptr , nullptr , // kSharedPrefix, vSharedPrefix, actSharedLen
3352- numHeads, scaleValue, // heads, scaleValue
3353- preTokens, nextTokens, // preTokens, nextTokens
3354- layout, // inputLayout
3355- numKeyValueHeads, // numKVHeads
3356- sparseMode, innerPrecise, // sparseMode, innerPrecise
3357- blockSize, antiquantMode, // blockSize, antiquantMode
3358- softmaxLseFlag, // softmaxLseFlag
3359- keyAntiquantMode, valueAntiquantMode, // keyAntiqMode, valueAntiqMode
3360- acl_dst_f16_tensor, // attentionOut
3361- nullptr // softmaxLse
3362- );
3363-
3364- // Step 6: post-processing, permute and cast to f32
3365- aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
3366- // TODO: when dst is fp16, don't need cast
3367- aclnn_cast (ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping (dst->type ));
3368- ggml_cann_release_resources (ctx, acl_src0_f16_tensor,
3369- acl_src1_f16_tensor,
3370- acl_src2_f16_tensor,
3371- acl_dst_f16_tensor,
3372- acl_dst_tensor);
33733397 if (src3 != nullptr ){
33743398 ggml_cann_release_resources (ctx, bcast_pse_tensor);
33753399 }
0 commit comments