Skip to content

Commit ff033e1

Browse files
authored
1 parent 84328ff commit ff033e1

File tree

1 file changed

+11
-33
lines changed

1 file changed

+11
-33
lines changed

ggml/src/ggml-qnn/op-config-impl.cpp

Lines changed: 11 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -385,36 +385,26 @@ bool ggml_qnn_matmul_op_config::create_mat_mul_nodes(QNNBackend device, Qnn_Grap
385385
* [5, 4],
386386
* ])
387387
* # Perform matrix multiplication
388-
* result = torch.matmul(A, B.T)
389-
* print(result.T)
388+
* C = torch.matmul(A, B.T)
389+
* print(C.T)
390390
* ```
391391
* Here, B.T denotes the transpose of B.
392+
* So C.T = A * B.T which is equivalent to C = B * A.T.
393+
* See: https://github.com/ggml-org/llama.cpp/blob/master/CONTRIBUTING.md
392394
*
393395
* So here we need to create graph like:
394396
* ```mermaid
395397
* graph TD;
396-
* i1>ggml_tensor_in0] --src0--> mat_mul0;
397-
* i2>ggml_tensor_in1] --src1--> mat_mul0;
398-
* mat_mul0 --dst_trans--> transpose_out;
399-
* transpose1 --dst0--> o1>ggml_tensor_out];
398+
* i1>ggml_tensor_in0] --src1--> mat_mul0;
399+
* i2>ggml_tensor_in1] --src0.T--> mat_mul0;
400+
* mat_mul0 --dst0--> o1>ggml_tensor_out];
400401
* ```
401402
*/
402403

403404
// create src0_trans tensor
404-
auto src1 = tensor_inputs.back();
405405
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS does not match the expected value");
406-
407-
qnn_dimension_array_t dimensions = get_transposed_dimensions(src1->get_dimensions(), rank);
408-
409-
// create dst_trans tensor
410-
auto dst = tensor_outputs.front();
411-
dimensions = get_transposed_dimensions(dst->get_dimensions(), rank);
412-
auto dst_trans = std::make_shared<ggml_qnn_tensor>(ggml_qnn_tensor::INTERMEDIATE, "dst_trans", dimensions,
413-
dst->get_data_type(), rank, device, graph_handle, _qnn_instance);
414-
415-
// create transpose_out
416-
auto transpose_out = std::make_shared<ggml_qnn_single_op_config>(_name + "_trans1", QNN_OP_PACKAGE_NAME_QTI_AISW,
417-
QNN_OP_TRANSPOSE, _qnn_instance);
406+
GGML_ASSERT(tensor_inputs.size() == 2);
407+
GGML_ASSERT(tensor_outputs.size() == 1);
418408

419409
// create mat_mul
420410
auto mat_mul =
@@ -425,24 +415,12 @@ bool ggml_qnn_matmul_op_config::create_mat_mul_nodes(QNNBackend device, Qnn_Grap
425415
scalar.bool8Value = 1;
426416
mat_mul->add_scalar_param(QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, scalar);
427417

428-
// set transpose_out parameters
429-
auto *params_data = reinterpret_cast<const uint8_t *>(kTransposeParamData[rank - 1].data());
430-
const qnn_dimension_array_t param_dims = {(uint32_t)rank, 1, 1, 1};
431-
transpose_out->add_tensor_param(QNN_OP_TRANSPOSE_PARAM_PERM, param_dims, 1, params_data, QNN_DATATYPE_UINT_32,
432-
device, graph_handle);
433-
434418
// set tensor to mat_mul
419+
std::swap(tensor_inputs[0], tensor_inputs[1]);
435420
mat_mul->set_input_tensors(tensor_inputs);
436-
qnn_tensor_array_t tensors = {dst_trans};
437-
mat_mul->set_output_tensors(tensors);
438-
439-
// set tensor to transpose_out
440-
tensors = {dst_trans};
441-
transpose_out->set_input_tensors(tensors);
442-
transpose_out->set_output_tensors(tensor_outputs);
421+
mat_mul->set_output_tensors(tensor_outputs);
443422

444423
_operations.push_back(mat_mul);
445-
_operations.push_back(transpose_out);
446424
return true;
447425
}
448426

0 commit comments

Comments (0)