@@ -289,6 +289,7 @@ void ggml_qnn_general_node(ggml_backend_qnn_context * ctx, ggml_tensor * op) {
  * than ggml_qnn_mul_mat, so it's a standalone function.
  * it will be combined with ggml_qnn_mul_mat after bugfix
  */
+
 static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op) {
     Qnn_ErrorHandle_t error = QNN_SUCCESS;
     bool graph_initialized = false;
@@ -313,6 +314,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
     Qnn_GraphHandle_t graph_handle = nullptr;
     Qnn_Tensor_t *p_tensor0 = nullptr;
     Qnn_Tensor_t *p_reshape0_out = nullptr;
+    Qnn_Tensor_t *p_tile0_out = nullptr;
     Qnn_Tensor_t *p_tensor1 = nullptr;
     Qnn_Tensor_t *p_permute1_out = nullptr;
     Qnn_Tensor_t *p_reshape1_out = nullptr;
@@ -326,34 +328,34 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
         qnn_tensors_t &tensors = std::get<1>(graph_item);
         p_tensor0 = tensors[0];
         p_reshape0_out = tensors[1];
-        p_tensor1 = tensors[2];
-        p_permute1_out = tensors[3];
-        p_reshape1_out = tensors[4];
-        p_matmul_out = tensors[5];
-        p_reshape2_out = tensors[6];
+        p_tile0_out = tensors[2];
+        p_tensor1 = tensors[3];
+        p_permute1_out = tensors[4];
+        p_reshape1_out = tensors[5];
+        p_matmul_out = tensors[6];
+        p_reshape2_out = tensors[7];
     } else {
         CHECK_QNN_API(error, qnn_raw_interface.graphCreate(instance->get_qnn_context_handle(),
                                                            graph_name.c_str(), NULL, &graph_handle));

         // Define dimensions
-        uint32_t B0 = src0->ne[2] * src0->ne[3]; // src0 batch: 3 * 2 = 6
-        uint32_t B1 = src1->ne[2] * src1->ne[3]; // src1 batch: 6 * 4 = 24
-        uint32_t M = src0->ne[1];                // 16
-        uint32_t K = src0->ne[0];                // 256
-        uint32_t N = src1->ne[1];                // 1 (second case), 16 (first case)
-
-        // Validate K matches
-        GGML_ASSERT(src0->ne[0] == src1->ne[0]); // K must match: 256 == 256
-        // Output shape should match src1's batch dims
+        uint32_t K = src0->ne[0];                // Inner dimension
+        uint32_t M = src0->ne[1];                // Rows of src0
+        uint32_t N = src1->ne[1];                // Columns of src1
+        uint32_t B0 = src0->ne[2] * src0->ne[3]; // src0 batch
+        uint32_t B1 = src1->ne[2] * src1->ne[3]; // src1 batch (drives output)
+
+        // Validate
+        GGML_ASSERT(src0->ne[0] == src1->ne[0]); // K must match
         GGML_ASSERT(dst->ne[0] == N && dst->ne[1] == M && dst->ne[2] == src1->ne[2] && dst->ne[3] == src1->ne[3]);

-        // src0: [256, 16, 3, 2] -> QNN: [2, 3, 16, 256] (B, H, M, K)
+        // src0: [K, M, H0, B0] -> QNN: [B0, H0, M, K]
         uint32_t src0_dims[] = {static_cast<uint32_t>(src0->ne[3]), static_cast<uint32_t>(src0->ne[2]), static_cast<uint32_t>(src0->ne[1]), static_cast<uint32_t>(src0->ne[0])};
         p_tensor0 = GQCGT(src0, "input0", QNN_TENSOR_TYPE_APP_WRITE, QNN_DATATYPE_FLOAT_32, 4,
                           src0_dims, nullptr, 0);
         CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tensor0));

-        // Reshape src0 to [6, 16, 256] for [B0, M, K]
+        // Reshape src0 to [B0, M, K]
         uint32_t reshape0_out_dims[] = {B0, M, K};
         p_reshape0_out = GQCGT(nullptr, "reshape0_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
                                reshape0_out_dims, nullptr, 0);
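A note on the two layout conventions in play above: ggml's `ne[]` lists dimensions innermost-first (`ne[0]` is the contiguous axis), while the QNN dims arrays list them outermost-first, so every QNN shape here is simply `ne[]` reversed. A minimal sketch of the shape bookkeeping, reusing the example sizes from the removed comments (the concrete numbers are illustrative only):

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the K/M/N/B0/B1 derivation above, using the sample shapes
// from the old comments: src0 = [256, 16, 3, 2], src1 = [256, 1, 6, 4]
// in ggml (innermost-first) order.
struct Shape4D { int64_t ne[4]; };

int main() {
    Shape4D src0 = {{256, 16, 3, 2}}; // [K, M, H0, B0-outer]
    Shape4D src1 = {{256, 1, 6, 4}};  // [K, N, H1, B1-outer]

    uint32_t K  = (uint32_t) src0.ne[0];               // 256
    uint32_t M  = (uint32_t) src0.ne[1];               // 16
    uint32_t N  = (uint32_t) src1.ne[1];               // 1
    uint32_t B0 = (uint32_t)(src0.ne[2] * src0.ne[3]); // 3 * 2 = 6
    uint32_t B1 = (uint32_t)(src1.ne[2] * src1.ne[3]); // 6 * 4 = 24

    assert(src0.ne[0] == src1.ne[0]); // shared inner dimension K
    assert(B1 % B0 == 0);             // needed for the Tile step below

    // QNN order is the reverse of ne[]: [B0-outer, H0, M, K]
    uint32_t qnn_src0_dims[4] = {(uint32_t) src0.ne[3], (uint32_t) src0.ne[2],
                                 (uint32_t) src0.ne[1], (uint32_t) src0.ne[0]};
    (void) qnn_src0_dims; // {2, 3, 16, 256}
    return 0;
}
```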
@@ -365,19 +367,37 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
                                                              reshape0_inputs, 1, reshape0_outputs, 1);
         CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, reshape0_op));

-        // src1: [256, 1, 6, 4] -> QNN: [4, 6, 1, 256] (B, H, N, K)
+        // Tile src0 to match B1: [B0, M, K] -> [B1, M, K]
+        uint32_t tile0_out_dims[] = {B1, M, K};
+        p_tile0_out = GQCGT(nullptr, "tile0_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
+                            tile0_out_dims, nullptr, 0);
+        CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tile0_out));
+        uint32_t tile_multiples[] = {B1 / B0, 1, 1}; // e.g., 24/6 = 4, 6/6 = 1
+        uint32_t tile_dims[] = {3};
+        Qnn_Tensor_t *p_tile_multiples = GQCGT(nullptr, "tile_multiples", QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1,
+                                               tile_dims, tile_multiples, sizeof(tile_multiples));
+        CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tile_multiples));
+        Qnn_Param_t tile_params[] = {{QNN_PARAMTYPE_TENSOR, "multiples", .tensorParam = *p_tile_multiples}};
+        Qnn_Tensor_t tile0_inputs[] = {*p_reshape0_out};
+        Qnn_Tensor_t tile0_outputs[] = {*p_tile0_out};
+        Qnn_OpConfig_t tile0_op = ggmlqnn_create_op_config("tile0", QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                           QNN_OP_TILE, tile_params, 1,
+                                                           tile0_inputs, 1, tile0_outputs, 1);
+        CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, tile0_op));
+
+        // src1: [N, K, H1, B1] -> QNN: [B1, H1, N, K]
         uint32_t src1_dims[] = {static_cast<uint32_t>(src1->ne[3]), static_cast<uint32_t>(src1->ne[2]), static_cast<uint32_t>(src1->ne[1]), static_cast<uint32_t>(src1->ne[0])};
         p_tensor1 = GQCGT(src1, "input1", QNN_TENSOR_TYPE_APP_WRITE, QNN_DATATYPE_FLOAT_32, 4,
                           src1_dims, nullptr, 0);
         CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tensor1));

-        // Permute src1 to [4, 6, 256, 1] to align K and N
-        uint32_t perm_data[] = {0, 1, 3, 2}; // [B, H, N, K] -> [B, H, K, N]
+        // Permute src1 to [B1, H1, K, N]
+        uint32_t perm_data[] = {0, 1, 3, 2};
         uint32_t perm_dims[] = {4};
-        Qnn_Tensor_t * p_perm = GQCGT(nullptr, "perm", QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1,
-                                      perm_dims, perm_data, sizeof(perm_data));
+        Qnn_Tensor_t *p_perm = GQCGT(nullptr, "perm", QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1,
+                                     perm_dims, perm_data, sizeof(perm_data));
         CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_perm));
-        uint32_t permute1_out_dims[] = {static_cast<uint32_t>(src1->ne[3]), static_cast<uint32_t>(src1->ne[2]), static_cast<uint32_t>(src1->ne[0]), static_cast<uint32_t>(src1->ne[1])}; // [4, 6, 256, 1]
+        uint32_t permute1_out_dims[] = {static_cast<uint32_t>(src1->ne[3]), static_cast<uint32_t>(src1->ne[2]), static_cast<uint32_t>(src1->ne[0]), static_cast<uint32_t>(src1->ne[1])};
         p_permute1_out = GQCGT(nullptr, "permute1_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 4,
                                permute1_out_dims, nullptr, 0);
         CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_permute1_out));
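The substantive fix in this hunk: QNN's MatMul does not broadcast batch dimensions (the old code noted this and fed the mismatched `[B0, M, K]` operand in anyway), so src0 is now materialized up to src1's batch with a Tile node whose `multiples` are `{B1 / B0, 1, 1}`. A plain C++ sketch of the semantics of that Tile, assuming B1 is a multiple of B0 (a reference loop, not the QNN kernel):

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Reference semantics of the tile0 node: repeat the whole [B0, M, K]
// block (B1 / B0) times along the batch axis, yielding [B1, M, K].
static std::vector<float> tile_batch(const std::vector<float> &src, // B0 * M * K floats
                                     uint32_t B0, uint32_t B1,
                                     uint32_t M, uint32_t K) {
    const size_t block = (size_t) B0 * M * K;
    std::vector<float> dst((size_t) B1 * M * K);
    for (uint32_t r = 0; r < B1 / B0; r++) {
        std::memcpy(dst.data() + r * block, src.data(), block * sizeof(float));
    }
    return dst; // output batch b holds a copy of input batch (b % B0)
}
```

With the example shapes, `multiples = {4, 1, 1}` turns `[6, 16, 256]` into `[24, 16, 256]`.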
@@ -389,7 +409,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
                                                               permute1_inputs, 1, permute1_outputs, 1);
         CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, permute1_op));

-        // Reshape src1 to [24, 256, 1] for [B1, K, N]
+        // Reshape src1 to [B1, K, N]
         uint32_t reshape1_out_dims[] = {B1, K, N};
         p_reshape1_out = GQCGT(nullptr, "reshape1_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
                                reshape1_out_dims, nullptr, 0);
@@ -401,23 +421,19 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
                                                               reshape1_inputs, 1, reshape1_outputs, 1);
         CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, reshape1_op));

-        // MatMul: [6, 16, 256] x [24, 256, 1] -> Needs adjustment for broadcasting
-        // Adjust src0 to match B1 by repeating or reshaping
-        uint32_t matmul_out_dims[] = {B1, M, N}; // [24, 16, 1]
+        // MatMul: [B1, M, K] x [B1, K, N] -> [B1, M, N]
+        uint32_t matmul_out_dims[] = {B1, M, N};
         p_matmul_out = GQCGT(nullptr, "matmul_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
                              matmul_out_dims, nullptr, 0);
         CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_matmul_out));
-
-        // Note: QNN MatMul doesn't broadcast; we need to tile src0
-        // For simplicity, assume dst shape drives execution; adjust src0 later if needed
-        Qnn_Tensor_t matmul_inputs[] = {*p_reshape0_out, *p_reshape1_out};
+        Qnn_Tensor_t matmul_inputs[] = {*p_tile0_out, *p_reshape1_out};
         Qnn_Tensor_t matmul_outputs[] = {*p_matmul_out};
         Qnn_OpConfig_t matmul_op = ggmlqnn_create_op_config("matmul", QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                             QNN_OP_MAT_MUL, nullptr, 0,
                                                             matmul_inputs, 2, matmul_outputs, 1);
         CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, matmul_op));

-        // Output: [1, 16, 6, 4] -> QNN: [4, 6, 16, 1]
+        // Output: [N, M, H1, B1] -> QNN: [B1, H1, M, N]
         uint32_t reshape2_out_dims[] = {static_cast<uint32_t>(dst->ne[3]), static_cast<uint32_t>(dst->ne[2]), static_cast<uint32_t>(dst->ne[1]), static_cast<uint32_t>(dst->ne[0])};
         p_reshape2_out = GQCGT(dst, "output", QNN_TENSOR_TYPE_APP_READ, QNN_DATATYPE_FLOAT_32, 4,
                                reshape2_out_dims, nullptr, 0);
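With both operands at the same batch size, the MatMul node is just B1 independent GEMMs. For reference, a minimal CPU equivalent of `[B1, M, K] x [B1, K, N] -> [B1, M, N]` on row-major data (a sketch for review, not the QNN implementation):

```cpp
#include <cstdint>
#include <vector>

// Reference for the matmul node: one row-major GEMM per batch.
static std::vector<float> batched_matmul(const std::vector<float> &a, // B1 * M * K
                                         const std::vector<float> &b, // B1 * K * N
                                         uint32_t B1, uint32_t M,
                                         uint32_t K, uint32_t N) {
    std::vector<float> c((size_t) B1 * M * N, 0.0f);
    for (uint32_t bt = 0; bt < B1; bt++) {
        const float *A = a.data() + (size_t) bt * M * K;
        const float *B = b.data() + (size_t) bt * K * N;
        float *C       = c.data() + (size_t) bt * M * N;
        for (uint32_t m = 0; m < M; m++)
            for (uint32_t k = 0; k < K; k++)   // k in the middle for locality
                for (uint32_t n = 0; n < N; n++)
                    C[m * N + n] += A[m * K + k] * B[k * N + n];
    }
    return c;
}
```

The final reshape to `[B1, H1, M, N]` then only re-labels this buffer: splitting the flat batch back into its two outer axes preserves row-major element order, so the math implies no data movement.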
@@ -433,7 +449,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
         CHECK_QNN_API(error, qnn_raw_interface.graphFinalize(graph_handle, NULL, NULL));

         // Cache
-        qnn_tensors_t ggml_op_mulmat_tensors = {p_tensor0, p_reshape0_out, p_tensor1, p_permute1_out, p_reshape1_out, p_matmul_out, p_reshape2_out};
+        qnn_tensors_t ggml_op_mulmat_tensors = {p_tensor0, p_reshape0_out, p_tile0_out, p_tensor1, p_permute1_out, p_reshape1_out, p_matmul_out, p_reshape2_out};
         instance->_qnn_graph_map[graph_name] = std::make_tuple(graph_handle, ggml_op_mulmat_tensors);
     }
@@ -455,7 +471,6 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)

     op_perf.info();
 }
-
 /*
  * @brief performs matrix multiplication with FP32 & quantized weights and floating-point inputs
  * using the QNN backend. this function performs matrix multiplication of the input tensor