Skip to content

Commit c380305

Browse files
committed
cann: update the constraints of flash_attn_ext on ggml-cann.cpp
1 parent 092ccf6 commit c380305

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

ggml/src/ggml-cann/ggml-cann.cpp

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "ggml-backend-impl.h"
3737
#include "ggml-cann/aclnn_ops.h"
3838
#include "ggml-cann/common.h"
39+
#include "ggml.h"
3940

4041
#define GGML_COMMON_DECL_C
4142

@@ -2165,7 +2166,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
21652166
case GGML_OP_COUNT_EQUAL:
21662167
return true;
21672168
case GGML_OP_FLASH_ATTN_EXT:{
2168-
// copy from [ggml-cuda.cu]
2169+
// derived from [ggml-cuda.cu]
2170+
if(op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16){
2171+
return false;
2172+
}
2173+
if(op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 && op->src[1]->type != GGML_TYPE_BF16){
2174+
return false;
2175+
}
2176+
if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
2177+
return false;
2178+
}
21692179
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
21702180
// different head sizes of K and V are not supported yet
21712181
return false;
@@ -2180,19 +2190,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
21802190
if (op->src[0]->ne[3] != 1) {
21812191
return false;
21822192
}
2183-
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
2193+
float logitSoftcap = 0.0f;
2194+
memcpy(&logitSoftcap, (float*)op->op_params + 2, sizeof(float));
2195+
if(logitSoftcap != 0.0f) {
21842196
return false;
21852197
}
2186-
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
2187-
return true;
2188-
}
2189-
if (op->src[0]->ne[0] == 128) {
2190-
return true;
2191-
}
2192-
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
2193-
return true;
2194-
}
2195-
return op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
2198+
return true;
21962199
}
21972200
default:
21982201
return false;

0 commit comments

Comments
 (0)