@@ -289,7 +289,6 @@ void ggml_qnn_general_node(ggml_backend_qnn_context * ctx, ggml_tensor * op) {
289289 * than ggml_qnn_mul_mat, so it's a standalone function.
290290 * It will be merged into ggml_qnn_mul_mat once the bug is fixed.
291291 */
292-
293292static void ggml_qnn_mul_mat_4d (ggml_backend_qnn_context *ctx, ggml_tensor *op) {
294293 Qnn_ErrorHandle_t error = QNN_SUCCESS;
295294 bool graph_initialized = false ;
@@ -347,7 +346,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
347346
348347 // Validate
349348 GGML_ASSERT (src0->ne [0 ] == src1->ne [0 ]); // K must match
350- GGML_ASSERT (dst->ne [0 ] == N && dst->ne [1 ] == M && dst->ne [2 ] == src1->ne [2 ] && dst->ne [3 ] == src1->ne [3 ]);
349+ // GGML_ASSERT(dst->ne[0] == N && dst->ne[1] == M && dst->ne[2] == src1->ne[2] && dst->ne[3] == src1->ne[3]);
351350
352351 // src0: [K, M, H0, B0] -> QNN: [B0, H0, M, K]
353352 uint32_t src0_dims[] = {static_cast <uint32_t >(src0->ne [3 ]), static_cast <uint32_t >(src0->ne [2 ]), static_cast <uint32_t >(src0->ne [1 ]), static_cast <uint32_t >(src0->ne [0 ])};
@@ -372,7 +371,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
372371 p_tile0_out = GQCGT (nullptr , " tile0_out" , QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3 ,
373372 tile0_out_dims, nullptr , 0 );
374373 CHECK_QNN_API (error, qnn_raw_interface.tensorCreateGraphTensor (graph_handle, p_tile0_out));
375- uint32_t tile_multiples[] = {B1 / B0, 1 , 1 }; // e.g., 24/6 = 4, 6/6 = 1
374+ uint32_t tile_multiples[] = {B1 / B0, 1 , 1 };
376375 uint32_t tile_dims[] = {3 };
377376 Qnn_Tensor_t *p_tile_multiples = GQCGT (nullptr , " tile_multiples" , QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1 ,
378377 tile_dims, tile_multiples, sizeof (tile_multiples));
@@ -465,6 +464,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
465464
466465 // Log dst for debugging
467466 float *dst_data = (float *)dst->data ;
467+ GGMLQNN_LOG_DEBUG (" dst shape: [%d, %d, %d, %d]\n " , dst->ne [0 ], dst->ne [1 ], dst->ne [2 ], dst->ne [3 ]);
468468 for (int i = 0 ; i < dst->ne [0 ] * dst->ne [1 ] * dst->ne [2 ] * dst->ne [3 ]; i++) {
469469 GGMLQNN_LOG_DEBUG (" dst[%d] = %f\n " , i, dst_data[i]);
470470 }
0 commit comments