@@ -385,36 +385,26 @@ bool ggml_qnn_matmul_op_config::create_mat_mul_nodes(QNNBackend device, Qnn_Grap
      * [5, 4],
      * ])
      * # Perform matrix multiplication
-     * result = torch.matmul(A, B.T)
-     * print(result.T)
+     * C = torch.matmul(A, B.T)
+     * print(C.T)
      * ```
      * Here, B.T is the transpose of B.
+     * So the printed C.T = (A * B.T).T = B * A.T, i.e. the result can be computed directly as B * A.T.
+     * See: https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md
      *
      * So here we need to create a graph like:
      * ```mermaid
      * graph TD;
-     * i1>ggml_tensor_in0] --src0--> mat_mul0;
-     * i2>ggml_tensor_in1] --src1--> mat_mul0;
-     * mat_mul0 --dst_trans--> transpose_out;
-     * transpose1 --dst0--> o1>ggml_tensor_out];
+     * i1>ggml_tensor_in0] --src1--> mat_mul0;
+     * i2>ggml_tensor_in1] --src0.T--> mat_mul0;
+     * mat_mul0 --dst0--> o1>ggml_tensor_out];
      * ```
      */
 
     // create src0_trans tensor
-    auto src1 = tensor_inputs.back();
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS does not match the expected value");
-
-    qnn_dimension_array_t dimensions = get_transposed_dimensions(src1->get_dimensions(), rank);
-
-    // create dst_trans tensor
-    auto dst = tensor_outputs.front();
-    dimensions = get_transposed_dimensions(dst->get_dimensions(), rank);
-    auto dst_trans = std::make_shared<ggml_qnn_tensor>(ggml_qnn_tensor::INTERMEDIATE, "dst_trans", dimensions,
-                                                       dst->get_data_type(), rank, device, graph_handle, _qnn_instance);
-
-    // create transpose_out
-    auto transpose_out = std::make_shared<ggml_qnn_single_op_config>(_name + "_trans1", QNN_OP_PACKAGE_NAME_QTI_AISW,
-                                                                     QNN_OP_TRANSPOSE, _qnn_instance);
+    GGML_ASSERT(tensor_inputs.size() == 2);
+    GGML_ASSERT(tensor_outputs.size() == 1);
 
     // create mat_mul
     auto mat_mul =
@@ -425,24 +415,12 @@ bool ggml_qnn_matmul_op_config::create_mat_mul_nodes(QNNBackend device, Qnn_Grap
     scalar.bool8Value = 1;
     mat_mul->add_scalar_param(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, scalar);
 
-    // set transpose_out parameters
-    auto *params_data = reinterpret_cast<const uint8_t *>(kTransposeParamData[rank - 1].data());
-    const qnn_dimension_array_t param_dims = {(uint32_t) rank, 1, 1, 1};
-    transpose_out->add_tensor_param(QNN_OP_TRANSPOSE_PARAM_PERM, param_dims, 1, params_data, QNN_DATATYPE_UINT_32,
-                                    device, graph_handle);
-
     // set tensor to mat_mul
+    std::swap(tensor_inputs[0], tensor_inputs[1]);
     mat_mul->set_input_tensors(tensor_inputs);
-    qnn_tensor_array_t tensors = {dst_trans};
-    mat_mul->set_output_tensors(tensors);
-
-    // set tensor to transpose_out
-    tensors = {dst_trans};
-    transpose_out->set_input_tensors(tensors);
-    transpose_out->set_output_tensors(tensor_outputs);
+    mat_mul->set_output_tensors(tensor_outputs);
 
     _operations.push_back(mat_mul);
-    _operations.push_back(transpose_out);
     return true;
 }
 
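For context on why the transpose node can be dropped, here is a minimal PyTorch sketch (not part of the patch; the matrices are made-up example values) of the identity the new graph relies on: `(A @ B.T).T == B @ A.T`, so swapping the two inputs and keeping `QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1` produces the desired result in a single MatMul node.

```python
import torch

# Arbitrary example matrices (hypothetical values, not taken from the patch).
A = torch.tensor([[1., 2.],
                  [3., 4.],
                  [5., 6.]])   # shape (3, 2)
B = torch.tensor([[7., 8.],
                  [9., 10.]])  # shape (2, 2)

# Old graph: MatMul produces A @ B.T, then a separate transpose node flips it.
transposed_path = torch.matmul(A, B.T).T   # shape (2, 3)

# New graph: swap the inputs and keep TRANSPOSE_IN1, computing B @ A.T directly.
direct_path = torch.matmul(B, A.T)         # shape (2, 3)

assert torch.allclose(transposed_path, direct_path)
print(direct_path)
```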