Skip to content

Commit 092ccf6

Browse files
committed
cann: update the inner precise for fusedInferAttention
1 parent 1779e00 commit 092ccf6

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#include <vector>
7575

7676
#include "ggml-impl.h"
77+
#include "ggml.h"
7778

7879
#define GGML_COMMON_DECL_C
7980

@@ -2611,7 +2612,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
26112612
aclTensor* acl_src0_f16_tensor = nullptr;
26122613
aclTensor* acl_src1_f16_tensor = nullptr;
26132614
aclTensor* acl_src2_f16_tensor = nullptr;
2614-
aclTensor* acl_src3_f16_tensor = nullptr;
26152615
aclTensor* acl_dst_f16_tensor = nullptr;
26162616

26172617
// Step 1: cast the src0 (Query) to fp16 if needed
@@ -2845,7 +2845,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
28452845
int64_t nextTokens = 65535;
28462846
char layout[5] = {'B', 'N', 'S', 'D', 0};
28472847
int64_t sparseMode = 0;
2848-
int64_t innerPrecise = 2;
2848+
int64_t innerPrecise = (src0->ne[1] == 1) ? 0 : 2;
28492849
int64_t blockSize = 0;
28502850
int64_t antiquantMode = 0;
28512851
bool softmaxLseFlag = false;
@@ -2915,7 +2915,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
29152915
ggml_cann_release_resources(ctx, bcast_pse_tensor);
29162916
}
29172917
}else{
2918-
GGML_ABORT("Function not implemented");
2918+
GGML_ABORT("Function is not implemented.");
29192919
}
2920-
}
2921-
2920+
}

0 commit comments

Comments
 (0)