@@ -369,6 +369,31 @@ bool ggnl_qnn_supports_op_tensor(ggml_backend_qnn_device_context * ctx, const gg
369369 return true ;
370370}
371371
372+ bool ggml_qnn_have_same_tensor_types (ggml_backend_qnn_device_context * ctx, const ggml_tensor * op) {
373+ auto * src0 = op->src [0 ];
374+ auto * src1 = op->src [1 ];
375+ if (src1) {
376+ if (src0->type != op->type || src1->type != op->type ) {
377+ QNN_LOG_DEBUG (" [%s][%s]type src0(%s), src1(%s) and op(%s) are not equal\n " ,
378+ qnn::get_backend_name (ctx->device ), ggml_op_name (op->op ), ggml_type_name (src0->type ),
379+ ggml_type_name (src1->type ), ggml_type_name (op->type ));
380+ return false ;
381+ }
382+ } else {
383+ if (src0->type != op->type ) {
384+ QNN_LOG_DEBUG (" [%s][%s]type src0(%s) and op(%s) are not equal\n " , qnn::get_backend_name (ctx->device ),
385+ ggml_op_name (op->op ), ggml_type_name (src0->type ), ggml_type_name (op->type ));
386+ return false ;
387+ }
388+ }
389+
390+ #ifdef NDEBUG
391+ GGML_UNUSED (ctx);
392+ #endif
393+
394+ return true ;
395+ }
396+
372397bool ggml_qnn_supports_matmul_op (ggml_backend_qnn_device_context * ctx, const ggml_tensor * op) {
373398 constexpr const size_t kMaxNpuTensorSize = 8192L * 2048 + 8192 * 512 + 2048 * 512 ;
374399 constexpr const auto get_tensor_size = [](const ggml_tensor * tensor) -> size_t {
@@ -393,10 +418,8 @@ bool ggml_qnn_supports_matmul_op(ggml_backend_qnn_device_context * ctx, const gg
393418 // fall through, from test here, the convert op is super slow on NPU:
394419 // https://github.com/usefulsensors/qc_npu_benchmark
395420 case QNN_BACKEND_GPU:
396- if (src0-> type != src1-> type || src0-> type != op-> type ) {
421+ if (ggml_qnn_have_same_tensor_types (ctx, op) ) {
397422 // there's no convert op for GPU.
398- QNN_LOG_DEBUG (" [qnn-gpu][MUL_MAT]type src0(%s), src1(%s) and op(%s) are not equal\n " ,
399- ggml_type_name (src0->type ), ggml_type_name (src1->type ), ggml_type_name (op->type ));
400423 return false ;
401424 }
402425 break ;
@@ -472,7 +495,7 @@ bool device_supports_op(ggml_backend_qnn_device_context * ctx, const ggml_tensor
472495 break ;
473496
474497 default :
475- // default to supported
498+ is_op_supported = ggml_qnn_have_same_tensor_types (ctx, op);
476499 break ;
477500 }
478501 }
0 commit comments