Commit 8a112f0

committed
cann: add some comments and update the CANN.md
1 parent b266beb commit 8a112f0

File tree

ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/aclnn_ops.h

2 files changed: +24 -24 lines changed


ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 24 additions & 22 deletions
@@ -65,24 +65,15 @@
 #include <aclnnop/aclnn_eq_tensor.h>
 #include <aclnnop/aclnn_gt_scalar.h>
 #include <aclnnop/aclnn_pow.h>
+#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
 #include <float.h>

 #include <cmath>
 #include <cstring>
 #include <exception>
 #include <vector>

-#include <iostream>
-#include <fstream>
-#include <string>
-#include <cstring>
-
-#include "aclnnop/aclnn_flash_attention_score.h"
-#include "aclnnop/aclnn_logical_not.h"
-
-#include "ggml-cann/acl_tensor.h"
 #include "ggml-impl.h"
-#include "ggml.h"

 #define GGML_COMMON_DECL_C

@@ -2623,7 +2614,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         aclTensor* acl_src3_f16_tensor = nullptr;
         aclTensor* acl_dst_f16_tensor = nullptr;

-        // Step 1: cast the src0 (Query) to fp16
+        // Step 1: cast the src0 (Query) to fp16 if needed
         ggml_cann_pool_alloc src0_f16_allocator(ctx.pool());
         void* src0_f16_buffer = nullptr;

@@ -2649,6 +2640,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             acl_src0_f16_tensor = ggml_cann_create_tensor(src0);
         }

+        // Step 2: create the acl tensors for src1 (Key), src2 (Value),
+        //         and the direct output from FusedInferAttention
+
         acl_src1_f16_tensor = ggml_cann_create_tensor(src1);
         acl_src2_f16_tensor = ggml_cann_create_tensor(src2);

@@ -2668,21 +2662,23 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             out_f16_ne, out_f16_nb, GGML_MAX_DIMS
         );

-        aclTensor* bcast_pse_tensor = nullptr;

+        // Step 3: create the PSEShift tensor if needed
+        //         this tensor is considered as mask (f16) in the llama.cpp
+
+        aclTensor* bcast_pse_tensor = nullptr;
         int64_t bcast_pse_ne[GGML_MAX_DIMS];
         size_t bcast_pse_nb[GGML_MAX_DIMS];
         ggml_cann_pool_alloc bcast_pse_allocator(ctx.pool());
         void* bcast_pse_buffer = nullptr;
-        if(src3)
+
+        if(src3 != nullptr){
             bcast_pse_buffer = bcast_pse_allocator.alloc(
                 ggml_nelements(src3) * src0->ne[2] * sizeof(uint16_t));
-
-        if(src3 != nullptr){
-            // broadcast pse
+
             if(src0->ne[1] > 1){
+                // Case 1: broadcast pse for prefill stage with multiple head
                 aclTensor* acl_mask_f16_tensor = ggml_cann_create_tensor(src3);
-
                 bcast_pse_ne[0] = src3->ne[0];
                 bcast_pse_ne[1] = src3->ne[1];
                 bcast_pse_ne[2] = src0->ne[2];
@@ -2702,6 +2698,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){

                 ggml_cann_release_resources(ctx, acl_mask_f16_tensor);
             }else{
+                // Case 2: trunc the first row and broadcast pse for decode stage with multiple head
                 int64_t trunc_pse_ne[GGML_MAX_DIMS] = {src3->ne[0], src0->ne[1], src3->ne[2], src3->ne[3]};
                 size_t* trunc_pse_nb = src3->nb;

@@ -2729,6 +2726,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
                 ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
             }

+            // Compute the slope if needed. Derived from ggml_cann_softmax().
             if(maxBias != 0.0f){
                 // alibi
                 const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
@@ -2832,6 +2830,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             }
         }

+        // Step 4: set the inputs for FusedInferAttention.
         int kvTensorNum = 1;
         aclTensor* acl_q_tensor = acl_src0_f16_tensor;
         aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
@@ -2853,9 +2852,9 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
         int64_t keyAntiquantMode = 0;
         int64_t valueAntiquantMode = 0;

+        // Step 5: launch the FusedInferAttentionScoreV2 kernel.
         // Refer to https://gitee.com/ascend/cann-ops-adv/blob/master/docs/FusedInferAttentionScoreV2.md
-
-
+
         GGML_CANN_CALL_ACLNN_OP(ctx, FusedInferAttentionScoreV2,
             acl_q_tensor, acl_k_tensor_list, acl_v_tensor_list, // q, k, v
             bcast_pse_tensor, nullptr, // pse, mask
@@ -2880,7 +2879,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             nullptr // softmaxLse
         );

-        // Step 5: post-processing, permute and cast to f32
+        // Step 6: post-processing, permute and cast to f32
+
         int64_t new_dim[] = {0, 2, 1, 3};
         aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);

@@ -2911,9 +2911,11 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
             acl_src2_f16_tensor,
             acl_dst_f16_tensor,
             acl_dst_tensor);
-        if(src3)
+        if(src3 != nullptr){
             ggml_cann_release_resources(ctx, bcast_pse_tensor);
+        }
     }else{
-        throw std::runtime_error("Function not implemented");
+        GGML_ABORT("Function not implemented");
     }
 }
+

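Note on the new "// Compute the slope if needed. Derived from ggml_cann_softmax()." comment: when maxBias != 0.0f (ALiBi), each attention head scales the mask by a per-head slope before the result is fed to FusedInferAttentionScoreV2 as PSEShift. The standalone sketch below shows the per-head slope formula as ggml's softmax path computes it; it is an illustration only, not the backend code, and alibi_slope together with the example head count and max_bias value are made up for the demo.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: per-head ALiBi slope as computed in ggml's softmax,
// which the CANN slope computation above is described as being derived from.
static float alibi_slope(uint32_t head, uint32_t n_head, float max_bias) {
    const uint32_t n_head_log2 = 1u << (uint32_t) std::floor(std::log2((float) n_head));
    const float m0 = std::pow(2.0f, -max_bias          / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);
    return head < n_head_log2 ? std::pow(m0, (float) (head + 1))
                              : std::pow(m1, (float) (2 * (head - n_head_log2) + 1));
}

int main() {
    // example: 8 heads, max_bias = 8.0f; the slope shrinks per head, and the
    // mask is scaled by it before being passed as the pse input above
    for (uint32_t h = 0; h < 8; ++h) {
        std::printf("head %u: slope %.6f\n", h, alibi_slope(h, 8, 8.0f));
    }
    return 0;
}
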
ggml/src/ggml-cann/aclnn_ops.h

Lines changed: 0 additions & 2 deletions
@@ -45,8 +45,6 @@
 #include <aclnnop/aclnn_cos.h>
 #include <aclnnop/aclnn_log.h>
 #include <aclnnop/aclnn_sign.h>
-#include <aclnnop/aclnn_fused_infer_attention_score_v2.h>
-#include <aclnnop/aclnn_isneginf.h>
 #include "acl_tensor.h"
 #include "common.h"

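A note on the Step 3 comments in aclnn_ops.cpp above: the mask (src3) becomes the PSEShift input with different shapes for prefill (src0->ne[1] > 1, Case 1) and for decode (Case 2, where the mask is first truncated to the first src0->ne[1] row(s) and then broadcast across the query heads). The sketch below only illustrates that shape bookkeeping; bcast_pse_dims, the assumption that the fourth dimension follows the mask, and the example shapes are hypothetical rather than backend code.

#include <array>
#include <cstdint>
#include <cstdio>

// ne follows ggml ordering: ne[0] is the innermost dimension.
using dims4 = std::array<int64_t, 4>;

// Hypothetical helper: dims of the broadcast PSEShift tensor, given the mask
// dims (src3->ne) and the query dims (src0->ne).
static dims4 bcast_pse_dims(const dims4 & mask_ne, const dims4 & q_ne) {
    if (q_ne[1] > 1) {
        // Case 1 (prefill): keep the mask rows, broadcast across q_ne[2] heads
        return { mask_ne[0], mask_ne[1], q_ne[2], mask_ne[3] };
    }
    // Case 2 (decode): truncate the mask to q_ne[1] row(s), then broadcast across heads
    return { mask_ne[0], q_ne[1], q_ne[2], mask_ne[3] };
}

int main() {
    // arbitrary example shapes: mask = {n_kv, n_tokens_padded, 1, 1},
    // query = {head_dim, n_tokens, n_heads, n_batch}
    const dims4 mask      = { 256, 64, 1, 1 };
    const dims4 q_prefill = { 128, 64, 32, 1 };
    const dims4 q_decode  = { 128, 1, 32, 1 };

    const dims4 a = bcast_pse_dims(mask, q_prefill);
    const dims4 b = bcast_pse_dims(mask, q_decode);
    std::printf("prefill pse: %lld %lld %lld %lld\n",
                (long long) a[0], (long long) a[1], (long long) a[2], (long long) a[3]);
    std::printf("decode  pse: %lld %lld %lld %lld\n",
                (long long) b[0], (long long) b[1], (long long) b[2], (long long) b[3]);
    return 0;
}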