
Commit 1e74897

CANN: refactor mask handling and improve performance in FA (ggml-org#15561)
* CANN(flash-attn): refactor mask handling and improve performance

  1. Refactored the mask computation in Flash Attention and unified the logic, so prefill and decode are no longer handled separately.
  2. Optimized performance in non-alibi scenarios by removing one repeat operation.
  3. Updated operator management to explicitly mark unsupported cases on 310P devices and when the dim is not divisible by 16.

  Signed-off-by: noemotiovon <[email protected]>

* [CANN]: fix review

  Signed-off-by: noemotiovon <[email protected]>

* [CANN]: Optimize FA BNSD to BSND

  Signed-off-by: noemotiovon <[email protected]>

---------

Signed-off-by: noemotiovon <[email protected]>
1 parent 1cf123a commit 1e74897
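
As context for item 2 of the commit message (and the `bcast_pse_nb[2] = 0` line in the diff below), here is a minimal, self-contained sketch of the zero-stride broadcast idea. It is not CANN code; the `offset_of` helper and the concrete dimension sizes are illustrative assumptions only.

    // Sketch: a 4-D strided view where the head dimension has byte stride 0.
    // Every head index then resolves to the same underlying mask row, so the
    // per-head repeat (an explicit copy) can be dropped when no alibi slope
    // has to be multiplied in.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Byte offset of element (i0, i1, i2, i3) in a strided 4-D view.
    static size_t offset_of(const size_t nb[4],
                            int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
        return i0 * nb[0] + i1 * nb[1] + i2 * nb[2] + i3 * nb[3];
    }

    int main() {
        const int64_t D = 4, S = 2, N = 3, B = 1;   // head dim, Q tokens, heads, batch
        const int64_t ne[4] = {D, S, N, B};         // logical shape of the broadcast view

        // Strides over the *original* [D, S, 1, B] fp16 mask buffer.
        size_t nb[4];
        nb[0] = sizeof(uint16_t);   // contiguous fp16 elements
        nb[1] = nb[0] * D;          // next Q token
        nb[2] = 0;                  // heads share the same data -> no repeat needed
        nb[3] = nb[1] * S;          // next batch of the original mask

        // Element (d=1, s=1, *, b=0) lands on the same byte offset for every head.
        for (int64_t h = 0; h < ne[2]; ++h) {
            std::printf("head %lld -> byte offset %zu\n",
                        (long long) h, offset_of(nb, 1, 1, h, 0));
        }
        return 0;
    }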

File tree

2 files changed: +98 -89 lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 88 additions & 87 deletions
@@ -1427,17 +1427,17 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
 static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
     float m, int64_t size, float start, float stop, float step){
     int64_t ne[] = {size};
-    size_t nb[] = {sizeof(float)};
+    size_t nb[] = {sizeof(uint16_t)};
 
-    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(float));
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
     void* arange_buffer = arange_allocator.get();
 
     aclTensor* arange_tensor = ggml_cann_create_tensor(
-        arange_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
     aclnn_arange(ctx, arange_tensor, start, stop, step, size);
 
     aclTensor* slope_tensor = ggml_cann_create_tensor(
-        slope_buffer, ACL_FLOAT, sizeof(float), ne, nb, 1);
+        slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
 
     aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
 
@@ -3180,11 +3180,38 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
-    ggml_tensor* src0 = dst->src[0]; // q, fp32
-    ggml_tensor* src1 = dst->src[1]; // k, fp16
-    ggml_tensor* src2 = dst->src[2]; // v, fp16
+    ggml_tensor* src0 = dst->src[0]; // q, fp32 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src1 = dst->src[1]; // k, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
+    ggml_tensor* src2 = dst->src[2]; // v, fp16 | B, N, S, D (uncont) -> B, S, N, D (cont)
     ggml_tensor* src3 = dst->src[3]; // mask, fp16
 
+    // B, N, S, D (uncont) -> B, S, N, D (cont)
+    int64_t src0_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_ne, src0->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src0_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src0_bsnd_nb, src0->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src1_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_ne, src1->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src1_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src1_bsnd_nb, src1->nb, GGML_MAX_DIMS * sizeof(size_t));
+    int64_t src2_bsnd_ne[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_ne, src2->ne, GGML_MAX_DIMS * sizeof(int64_t));
+    size_t src2_bsnd_nb[GGML_MAX_DIMS];
+    memcpy(src2_bsnd_nb, src2->nb, GGML_MAX_DIMS * sizeof(size_t));
+
+    auto transpose12 = [](int64_t* ne, size_t* nb) {
+        int64_t ne_tmp = ne[1];
+        size_t nb_tmp = nb[1];
+        ne[1] = ne[2];
+        nb[1] = nb[2];
+        ne[2] = ne_tmp;
+        nb[2] = nb_tmp;
+    };
+
+    transpose12(src0_bsnd_ne, src0_bsnd_nb);
+    transpose12(src1_bsnd_ne, src1_bsnd_nb);
+    transpose12(src2_bsnd_ne, src2_bsnd_nb);
+
     float maxBias = 0.0f;
     float scaleValue = 1.0f;
     float logitSoftcap = 0.0f;
@@ -3206,11 +3233,12 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     void* src0_f16_buffer = nullptr;
 
     if(ggml_cann_type_mapping(src0->type) != faDataType){
-        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0);
+        aclTensor* acl_src0_f32_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                                                 src0_bsnd_nb, GGML_MAX_DIMS);
         src0_f16_buffer = src0_f16_allocator.alloc(
             ggml_nelements(src0) * faElemSize);
 
-        int64_t* src0_f16_ne = src0->ne;
+        int64_t* src0_f16_ne = src0_bsnd_ne;
         size_t src0_f16_nb[GGML_MAX_DIMS];
         src0_f16_nb[0] = sizeof(uint16_t);
         for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3224,20 +3252,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclnn_cast(ctx, acl_src0_f32_tensor, acl_src0_f16_tensor, faDataType);
         ggml_cann_release_resources(ctx, acl_src0_f32_tensor);
     }else{
-        acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
+        acl_src0_f16_tensor = ggml_cann_create_tensor(src0, src0_bsnd_ne,
+                                                      src0_bsnd_nb, GGML_MAX_DIMS);
     }
 
     // Step 2: create the acl tensors for src1 (Key), src2 (Value),
     // and the direct output from FusedInferAttention
 
-    acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
-    acl_src2_f16_tensor = ggml_cann_create_tensor(src2);
+    acl_src1_f16_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne,
+                                                  src1_bsnd_nb, GGML_MAX_DIMS);
+    acl_src2_f16_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne,
+                                                  src2_bsnd_nb, GGML_MAX_DIMS);
 
     ggml_cann_pool_alloc out_f16_allocator(ctx.pool());
     void* out_f16_buffer = out_f16_allocator.alloc(
         ggml_nelements(dst) * faElemSize);
 
-    int64_t* out_f16_ne = src0->ne;
+    int64_t* out_f16_ne = src0_bsnd_ne;
     size_t out_f16_nb[GGML_MAX_DIMS];
     out_f16_nb[0] = faElemSize;
     for(int i = 1; i < GGML_MAX_DIMS; ++i){
@@ -3251,88 +3282,81 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 
     // Step 3: create the PSEShift tensor if needed
     // this tensor is considered as mask (f16) in the llama.cpp
-
     aclTensor* bcast_pse_tensor = nullptr;
-    int64_t bcast_pse_ne[GGML_MAX_DIMS];
-    size_t bcast_pse_nb[GGML_MAX_DIMS];
     ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
-    void* bcast_pse_buffer = nullptr;
-
     if(src3 != nullptr){
-        bcast_pse_buffer = bcast_pse_allocator.alloc(
-            ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-        if(src0->ne[1] > 1){
-            // Case 1: broadcast pse for prefill stage with multiple head
-            aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src3->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+        // Construct the truncated pse tensor (common for prefill/decode)
+        int64_t trunc_pse_ne[GGML_MAX_DIMS] = {
+            src3->ne[0], // D
+            src0->ne[1], // S (number of Q tokens)
+            src3->ne[2], // mask N
+            src3->ne[3]  // B
+        };
+        size_t* trunc_pse_nb = src3->nb;
+
+        aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
+            src3->data, ACL_FLOAT16, sizeof(uint16_t),
+            trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS
+        );
 
+        int64_t bcast_pse_ne[GGML_MAX_DIMS];
+        size_t bcast_pse_nb[GGML_MAX_DIMS];
+        bcast_pse_ne[0] = src3->ne[0]; // D
+        bcast_pse_ne[1] = src0->ne[1]; // S
+        bcast_pse_ne[2] = src0->ne[2]; // N (num_heads)
+        bcast_pse_ne[3] = src3->ne[3]; // B
+        if (maxBias == 0.0f) {
+            // When maxBias == 0.0f, use nb = 0 reduce once repeat (Qwen2)
+            // Construct the bcast tensor (simulate repeat on the head dimension using stride=0)
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for(int i = 1; i < GGML_MAX_DIMS; ++i){
-                bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
-            }
+            bcast_pse_nb[1] = bcast_pse_nb[0] * bcast_pse_ne[0];
+            bcast_pse_nb[2] = 0; // <---- the head dimension shares the same data
+            bcast_pse_nb[3] = src3->nb[3];
 
             bcast_pse_tensor = ggml_cann_create_tensor(
-                bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
-
-            int64_t repeats[] = {1, src0->ne[2], 1, 1};
-            aclnn_repeat(ctx, acl_mask_f16_tensor, bcast_pse_tensor, repeats);
-
-            ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
-        }else{
-            // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
-            int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
-            size_t* trunc_pse_nb = src3->nb;
-
-            aclTensor* acl_mask_f16_trunc_tensor = ggml_cann_create_tensor(
                 src3->data, ACL_FLOAT16, sizeof(uint16_t),
-                trunc_pse_ne, trunc_pse_nb, GGML_MAX_DIMS);
-
-            bcast_pse_ne[0] = src3->ne[0];
-            bcast_pse_ne[1] = src0->ne[1];
-            bcast_pse_ne[2] = src0->ne[2];
-            bcast_pse_ne[3] = src3->ne[3];
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );
 
+            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
+        } else {
             bcast_pse_nb[0] = sizeof(uint16_t);
-            for(int i = 1; i < GGML_MAX_DIMS; ++i){
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 bcast_pse_nb[i] = bcast_pse_nb[i - 1] * bcast_pse_ne[i - 1];
             }
 
+            void* bcast_pse_buffer = bcast_pse_allocator.alloc(
+                ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t)
+            );
+
             bcast_pse_tensor = ggml_cann_create_tensor(
                 bcast_pse_buffer, ACL_FLOAT16, sizeof(uint16_t),
-                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS);
+                bcast_pse_ne, bcast_pse_nb, GGML_MAX_DIMS
+            );
 
             int64_t repeats[] = {1, src0->ne[2], 1, 1};
             aclnn_repeat(ctx, acl_mask_f16_trunc_tensor, bcast_pse_tensor, repeats);
 
-            ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
-        }
-
-        // Compute the slope if needed. Derived from ggml_cann_softmax().
-        if(maxBias != 0.0f){
             // alibi
+            // Compute the slope if needed. Derived from ggml_cann_softmax().
             const int64_t n_heads = src0->ne[2];
-            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(float));
+            ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
            void* slope_buffer = slope_allocator.get();
             aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
 
             int64_t slope_ne[] = {1, 1, n_heads, 1};
             size_t slope_nb[GGML_MAX_DIMS];
-            slope_nb[0] = sizeof(float);
+            slope_nb[0] = sizeof(uint16_t);
             for(int i = 1;i<GGML_MAX_DIMS;i++) {
                 slope_nb[i] = slope_nb[i-1] * slope_ne[0];
             }
 
             aclTensor* slope_tensor = ggml_cann_create_tensor(
-                slope_buffer, ACL_FLOAT, sizeof(float),
+                slope_buffer, ACL_FLOAT16, sizeof(uint16_t),
                 slope_ne, slope_nb, GGML_MAX_DIMS);
             GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, slope_tensor);
 
-            ggml_cann_release_resources(ctx, slope_tensor);
+            ggml_cann_release_resources(ctx, slope_tensor, acl_mask_f16_trunc_tensor);
         }
     }
 
@@ -3349,7 +3373,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     // double scaleValue = 1 / sqrt(src0->ne[0]); // 1/sqrt(d)
     int64_t preTokens = 65535;
     int64_t nextTokens = 65535;
-    char layout[5] = {'B', 'N', 'S', 'D', 0};
+    char layout[5] = {'B', 'S', 'N', 'D', 0};
     int64_t sparseMode = 0;
     int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
     int64_t blockSize = 0;
@@ -3386,32 +3410,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     );
 
     // Step 6: post-processing, permute and cast to f32
-
-    int64_t new_dim[] = {0, 2, 1, 3};
     aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-
-    if(ggml_cann_type_mapping(dst->type) != faDataType){
-        ggml_cann_pool_alloc perm_out_f16_allocator(ctx.pool());
-        perm_out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize);
-        void* perm_out_f16_buffer = perm_out_f16_allocator.get();
-
-        int64_t* perm_out_f16_ne = dst->ne;
-        size_t perm_out_f16_nb[GGML_MAX_DIMS];
-        perm_out_f16_nb[0] = faElemSize;
-        for(int i = 1; i < GGML_MAX_DIMS; ++i){
-            perm_out_f16_nb[i] = perm_out_f16_nb[i - 1] * perm_out_f16_ne[i - 1];
-        }
-        aclTensor* acl_perm_out_f16_tensor = ggml_cann_create_tensor(
-            perm_out_f16_buffer, faDataType, faElemSize,
-            perm_out_f16_ne, perm_out_f16_nb, GGML_MAX_DIMS);
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_perm_out_f16_tensor, new_dim, GGML_MAX_DIMS);
-        aclnn_cast(ctx,
-            acl_perm_out_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
-        ggml_cann_release_resources(ctx, acl_perm_out_f16_tensor);
-    }else{
-        // only need to permute
-        aclnn_permute(ctx, acl_dst_f16_tensor, acl_dst_tensor, new_dim, GGML_MAX_DIMS);
-    }
+    // TODO: when dst is fp16, don't need cast
+    aclnn_cast(ctx, acl_dst_f16_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
                                 acl_src1_f16_tensor,
                                 acl_src2_f16_tensor,
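
Side note on the BNSD-to-BSND change in the hunks above: with ggml-style ne/nb bookkeeping, exchanging the extent and byte stride of dims 1 and 2 re-labels the same buffer from B, N, S, D to B, S, N, D without copying any data. Below is a minimal standalone sketch of that view swap under that assumption; `view4d` is an illustrative struct, not a ggml or CANN type.

    // Sketch: permuting two dimensions of a strided view by swapping their
    // extents and byte strides; the bytes of the buffer never move.
    #include <cstddef>
    #include <cstdint>
    #include <utility>

    struct view4d {
        int64_t ne[4];  // elements per dimension, ne[0] varies fastest
        size_t  nb[4];  // byte stride per dimension
    };

    // Same operation as the transpose12 lambda in the diff: swap dims 1 and 2.
    static void transpose12(view4d & v) {
        std::swap(v.ne[1], v.ne[2]);
        std::swap(v.nb[1], v.nb[2]);
    }

    int main() {
        // Describe a buffer whose logical order (outermost first) is B, N, S, D.
        view4d q;
        q.ne[0] = 64; q.ne[1] = 8; q.ne[2] = 4; q.ne[3] = 1;  // D, S, N, B
        q.nb[0] = sizeof(uint16_t);                           // fp16 elements
        for (int i = 1; i < 4; ++i) {
            q.nb[i] = q.nb[i - 1] * (size_t) q.ne[i - 1];
        }

        // Exchanging extent and stride of dims 1 and 2 re-labels the same bytes
        // as B, S, N, D; only the bookkeeping changes, nothing is copied.
        transpose12(q);
        return 0;
    }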

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 10 additions & 2 deletions
@@ -2336,7 +2336,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q4_0:
 #ifdef ASCEND_310P
-            // Q4 && Q8 per group is not suppor on 310p device
+            // Q4 && Q8 per group is not support on 310p device
             return false;
 #endif
             // only support contiguous for quantized types.
@@ -2354,7 +2354,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q4_0:
 #ifdef ASCEND_310P
-            // Q4 && Q8 per group is not suppor on 310p device
+            // Q4 && Q8 per group is not support on 310p device
             return false;
 #endif
             // only support contiguous for quantized types.
@@ -2505,6 +2505,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             }
             return true;
         case GGML_OP_FLASH_ATTN_EXT:{
+#ifdef ASCEND_310P
+            // FA not support on 310p device
+            return false;
+#endif
             // derived from [ggml-cuda.cu]
             if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
                 return false;
@@ -2530,6 +2534,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 // DeepSeek MLA
                 return false;
             }
+            if (op->src[0]->ne[0] % 16 != 0) {
+                // TODO: padding to support
+                return false;
+            }
             float logitSoftcap = 0.0f;
             memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
             if(logitSoftcap != 0.0f) {
