3636#include " ggml-backend-impl.h"
3737#include " ggml-cann/aclnn_ops.h"
3838#include " ggml-cann/common.h"
39+ #include " ggml.h"
3940
4041#define GGML_COMMON_DECL_C
4142
@@ -2165,7 +2166,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
21652166 case GGML_OP_COUNT_EQUAL:
21662167 return true ;
21672168 case GGML_OP_FLASH_ATTN_EXT:{
2168- // copy from [ggml-cuda.cu]
2169+ // derived from [ggml-cuda.cu]
2170+ if (op->src [1 ]->type != GGML_TYPE_F16 || op->src [2 ]->type != GGML_TYPE_F16){
2171+ return false ;
2172+ }
2173+ if (op->src [1 ]->type != GGML_TYPE_F16 && op->src [1 ]->type != GGML_TYPE_F32 && op->src [1 ]->type != GGML_TYPE_BF16){
2174+ return false ;
2175+ }
2176+ if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
2177+ return false ;
2178+ }
21692179 if (op->src [1 ]->ne [0 ] != op->src [2 ]->ne [0 ]) {
21702180 // different head sizes of K and V are not supported yet
21712181 return false ;
@@ -2180,19 +2190,12 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
21802190 if (op->src [0 ]->ne [3 ] != 1 ) {
21812191 return false ;
21822192 }
2183- if (op->src [1 ]->type == GGML_TYPE_BF16 || op->src [2 ]->type == GGML_TYPE_BF16) {
2193+ float logitSoftcap = 0 .0f ;
2194+ memcpy (&logitSoftcap, (float *)op->op_params + 2 , sizeof (float ));
2195+ if (logitSoftcap != 0 .0f ) {
21842196 return false ;
21852197 }
2186- if (op->src [0 ]->ne [0 ] == 64 && op->src [1 ]->type == GGML_TYPE_F16) {
2187- return true ;
2188- }
2189- if (op->src [0 ]->ne [0 ] == 128 ) {
2190- return true ;
2191- }
2192- if (op->src [0 ]->ne [0 ] == 256 && op->src [1 ]->type == GGML_TYPE_F16 && op->src [2 ]->type == GGML_TYPE_F16) {
2193- return true ;
2194- }
2195- return op->src [1 ]->type == GGML_TYPE_F16 && op->src [2 ]->type == GGML_TYPE_F16;
2198+ return true ;
21962199 }
21972200 default :
21982201 return false ;
0 commit comments