@@ -1427,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
     float m, int64_t size, float start, float stop, float step){
     int64_t ne[] = {size};
-    size_t nb[] = {sizeof(float)};
+    size_t nb[] = {sizeof(uint16_t)};

-    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
     void* arange_buffer = arange_allocator.get();

     aclTensor* arange_tensor = ggml_cann_create_tensor(
-        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
     aclnn_arange(ctx, arange_tensor, start, stop, step, size);

     aclTensor* slope_tensor = ggml_cann_create_tensor(
-        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);

     aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
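
The slope helper above now writes fp16 instead of fp32, so the arange/slope buffers and their strides shrink to sizeof(uint16_t). For orientation, here is a minimal host-side sketch of the ALiBi slope values such a buffer holds, assuming the formula used by ggml's CPU soft_max path; the function below is illustrative only and not part of this patch.

```cpp
#include <cmath>
#include <vector>

// Illustrative reference: ALiBi slope per head, following the formula ggml's CPU
// soft_max uses. The CANN helper above builds comparable values on-device
// (arange + pow) and stores them as fp16 in slope_buffer.
std::vector<float> alibi_slopes(int n_heads, float max_bias) {
    const int   n_head_log2 = 1 << (int) std::floor(std::log2((float) n_heads));
    const float m0 = std::pow(2.0f, -max_bias / (float) n_head_log2);
    const float m1 = std::pow(2.0f, -max_bias / (2.0f * (float) n_head_log2));

    std::vector<float> slopes(n_heads);
    for (int h = 0; h < n_heads; ++h) {
        slopes[h] = (max_bias > 0.0f)
            ? (h < n_head_log2 ? std::pow(m0, (float) (h + 1))
                               : std::pow(m1, (float) (2 * (h - n_head_log2) + 1)))
            : 1.0f;
    }
    return slopes; // would be cast to fp16 before use on the device
}
```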
@@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){

-    ggml_tensor* src0 = dst->src[0]; // q, fp32
-    ggml_tensor* src1 = dst->src[1]; // k, fp16
-    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
     ggml_tensor* src3 = dst->src[3]; // mask, fp16

+    // B, N, S, D (uncont) -> B, S, N, D (cont)
+    int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src0_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src1_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src2_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+    auto transpose12 = [](int64_t* ne, size_t* nb) {
+        int64_t ne_tmp = ne[1];
+        size_t  nb_tmp = nb[1];
+        ne[1] = ne[2];
+        nb[1] = nb[2];
+        ne[2] = ne_tmp;
+        nb[2] = nb_tmp;
+    };
+
+    transpose12(src0_bsnd_ne, src0_bsnd_nb);
+    transpose12(src1_bsnd_ne, src1_bsnd_nb);
+    transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
     float logitSoftcap = 0.0f;
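
The transpose12 lambda only swaps the extent/stride pairs of dims 1 and 2, which turns the BNSD description of q/k/v into a BSND view of the same memory without copying anything. A self-contained sketch of why that works; the extents and strides below are made up for illustration and are not taken from the patch.

```cpp
#include <cstdint>
#include <cstdio>

// Byte offset of a 4-D element in a strided view, ggml-style: ne[] are extents with
// dim 0 innermost, nb[] are strides in bytes.
static size_t offset(const int64_t idx[4], const size_t nb[4]) {
    return idx[0] * nb[0] + idx[1] * nb[1] + idx[2] * nb[2] + idx[3] * nb[3];
}

int main() {
    // A hypothetical fp16 tensor: extents 4, 3, 2, 1 for dims 0..3,
    // with contiguous strides in that order.
    int64_t ne[4] = {4, 3, 2, 1};
    size_t  nb[4] = {2, 2 * 4, 2 * 4 * 3, 2 * 4 * 3 * 2};

    // Same swap as the patch's transpose12 lambda: exchange dims 1 and 2 of the view.
    int64_t ne_v[4] = {ne[0], ne[2], ne[1], ne[3]};
    size_t  nb_v[4] = {nb[0], nb[2], nb[1], nb[3]};

    // Element (1, 2, 1, 0) in the original view is (1, 1, 2, 0) in the swapped view,
    // and both resolve to the same byte offset in the same buffer.
    int64_t i [4] = {1, 2, 1, 0};
    int64_t iv[4] = {1, 1, 2, 0};
    printf("%zu == %zu\n", offset(i, nb), offset(iv, nb_v));
    return 0;
}
```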
@@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     void* src0_f16_buffer = nullptr;

     if (ggml_cann_type_mapping(src0->type) != faDataType) {
-        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                                                 src0_bsnd_nb, GGML_MAX_DIMS);
         src0_f16_buffer = src0_f16_allocator.alloc(
             ggml_nelements(src0) * faElemSize);

-        int64_t* src0_f16_ne = src0->ne;
+        int64_t* src0_f16_ne = src0_bsnd_ne;
         size_t src0_f16_nb[GGML_MAX_DIMS];
         src0_f16_nb[0] = sizeof(uint16_t);
         for (int i = 1; i < GGML_MAX_DIMS; ++i) {
@@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
         ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
     } else {
-        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                                      src0_bsnd_nb, GGML_MAX_DIMS);
     }

     // Step 2: create the acl tensors for src1 (Key), src2 (Value),
     //         and the direct output from FusedInferAttention

-    acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-    acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+    acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+                                                  src1_bsnd_nb, GGML_MAX_DIMS);
+    acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+                                                  src2_bsnd_nb, GGML_MAX_DIMS);

     ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
     void* out_f16_buffer = out_f16_allocator.alloc(
         ggml_nelements(dst) * faElemSize);

-    int64_t* out_f16_ne = src0->ne;
+    int64_t* out_f16_ne = src0_bsnd_ne;
     size_t out_f16_nb[GGML_MAX_DIMS];
     out_f16_nb[0] = faElemSize;
     for (int i = 1; i < GGML_MAX_DIMS; ++i) {
@@ -3251,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){

     // Step 3: create the PSEShift tensor if needed
     //         this tensor is considered as mask (f16) in the llama.cpp
-
     aclTensor* bcast_pse_tensor = nullptr;
-    int64_t bcast_pse_ne[GGML_MAX_DIMS];
-    size_t bcast_pse_nb[GGML_MAX_DIMS];
     ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
-    void* bcast_pse_buffer = nullptr;
-
     if (src3 != nullptr) {
-        bcast_pse_buffer = bcast_pse_allocator.alloc(
-            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-        if (src0->ne[1] > 1) {
-            // Case 1: broadcast pse for prefill stage with multiple head
-            aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src3->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+        // Construct the truncated pse tensor (common for prefill/decode)
+        int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+            src3->ne[0], // D
+            src0->ne[1], // S (number of Q tokens)
+            src3->ne[2], // mask N
+            src3->ne[3]  // B
+        };
+        size_t* trunc_pse_nb = src3->nb;
+
+        aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+            src3->data, ACL_FLOAT16, sizeof(uint16_t),
+            trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+        );

+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        bcast_pse_ne[0] = src3->ne[0]; // D
+        bcast_pse_ne[1] = src0->ne[1]; // S
+        bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
+        bcast_pse_ne[3] = src3->ne[3]; // B
+        if (maxBias == 0.0f) {
+            // When maxBias == 0.0f, use nb[2] = 0 to skip one repeat (Qwen2):
+            // construct the bcast tensor directly, simulating the repeat on the head
+            // dimension with stride = 0
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
-                bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
-            }
+            bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+            bcast_pse_nb[2] = 0;  // <---- the head dimension shares the same data
+            bcast_pse_nb[3] = src3->nb[3];

             bcast_pse_tensor = ggml_cann_create_tensor(
-                bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
-            int64_t repeats[] = {1, src0->ne[2], 1, 1};
-            aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
-            ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
-        } else {
-            // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
-            int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
-            size_t* trunc_pse_nb = src3->nb;
-
-            aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
                 src3->data, ACL_FLOAT16, sizeof(uint16_t),
-                trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src0->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );

+            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+        } else {
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
             }

+            void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+            );
+
             bcast_pse_tensor = ggml_cann_create_tensor(
                 bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );

             int64_t repeats[] = {1, src0->ne[2], 1, 1};
             aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);

-            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
-        }
-
-        // Compute the slope if needed. Derived from ggml_cann_softmax().
-        if (maxBias != 0.0f) {
             // alibi
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
             const int64_t n_heads = src0->ne[2];
-            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
             void* slope_buffer = slope_allocator.get();
             aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);

             int64_t slope_ne[] = {1, 1, n_heads, 1};
             size_t slope_nb[GGML_MAX_DIMS];
-            slope_nb[0] = sizeof(float);
+            slope_nb[0] = sizeof(uint16_t);
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 slope_nb[i] = slope_nb[i-1] * slope_ne[0];
             }

             aclTensor* slope_tensor = ggml_cann_create_tensor(
-                slope_buffer, ACL_FLOAT, sizeof(float),
+                slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
                 slope_ne, slope_nb, GGML_MAX_DIMS);
             GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);

-            ggml_cann_release_resources(ctx, slope_tensor);
+            ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
         }
     }
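
When maxBias is zero, the mask no longer needs a per-head copy: the bcast view sets nb[2] = 0, so every head index resolves to the same bytes and the aclnn_repeat call (and its temporary buffer) is skipped. A small sketch of the zero-stride idea; the extents and strides below are made up for illustration.

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical fp16 mask view: dim 0 extent 64, dim 1 extent 16, heads on dim 2,
    // batch on dim 3. Dim 2 gets stride 0, so all heads alias the same underlying rows.
    const size_t nb[4] = {2, 2 * 64, 0, 2 * 64 * 16};

    for (int64_t head = 0; head < 4; ++head) {
        // byte offset of element (3, 5, head, 0)
        const size_t off = 3 * nb[0] + 5 * nb[1] + (size_t) head * nb[2] + 0 * nb[3];
        printf("head %lld -> offset %zu\n", (long long) head, off); // identical for every head
    }
    return 0;
}
```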
@@ -3349,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
     int64_t preTokens = 65535;
     int64_t nextTokens = 65535;
-    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    char layout[5] = {'B', 'S', 'N', 'D', 0};
     int64_t sparseMode = 0;
     int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
     int64_t blockSize = 0;
@@ -3386,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     );

     // Step 6: post-processing, permute and cast to f32
-
-    int64_t new_dim[] = {0, 2, 1, 3};
     aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
-    if (ggml_cann_type_mapping(dst->type) != faDataType) {
-        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
-        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
-        int64_t* perm_out_f16_ne = dst->ne;
-        size_t perm_out_f16_nb[GGML_MAX_DIMS];
-        perm_out_f16_nb[0] = faElemSize;
-        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
-            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
-        }
-        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
-            perm_out_f16_buffer, faDataType, faElemSize,
-            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
-        aclnn_cast(ctx,
-            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
-    } else {
-        // only need to permute
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
-    }
+    // TODO: when dst is fp16, no cast is needed
+    aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                 acl_src1_f16_tensor,
                                 acl_src2_f16_tensor,
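
Since q/k/v are now handed to the fused attention op as BSND views and the layout string says so, the diff drops the old permute (and its scratch buffer) in post-processing and keeps only the cast to dst's type. The fp16 output strides are built contiguously over the BSND extents in the usual way; a tiny sketch of that stride construction, where MAX_DIMS stands in for GGML_MAX_DIMS and the function name is illustrative only.

```cpp
#include <cstddef>
#include <cstdint>

constexpr int MAX_DIMS = 4; // stands in for GGML_MAX_DIMS

// Contiguous byte strides for the given extents: the stride of dim i is the product of
// all lower-dim extents times the element size (same pattern as out_f16_nb above).
void contiguous_nb(const int64_t ne[MAX_DIMS], size_t elem_size, size_t nb[MAX_DIMS]) {
    nb[0] = elem_size;
    for (int i = 1; i < MAX_DIMS; ++i) {
        nb[i] = nb[i - 1] * (size_t) ne[i - 1];
    }
}
```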