Skip to content

Commit eb46bc0

Browse files
author
zhouwg
committed
ggml-qnn: AI-assisted ggml_qnn_mul_mat_4d by Grok 3 --- step11
1 parent 166d220 commit eb46bc0

File tree

1 file changed

+48
-33
lines changed

1 file changed

+48
-33
lines changed

ggml/src/ggml-qnn/ggml-qnn-ops.cpp

Lines changed: 48 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -289,6 +289,7 @@ void ggml_qnn_general_node(ggml_backend_qnn_context * ctx, ggml_tensor * op) {
289289
* than ggml_qnn_mul_mat, so it's a standalone function.
290290
* it will be combined with ggml_qnn_mul_mat after bugfix
291291
*/
292+
292293
static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op) {
293294
Qnn_ErrorHandle_t error = QNN_SUCCESS;
294295
bool graph_initialized = false;
@@ -313,6 +314,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
313314
Qnn_GraphHandle_t graph_handle = nullptr;
314315
Qnn_Tensor_t *p_tensor0 = nullptr;
315316
Qnn_Tensor_t *p_reshape0_out = nullptr;
317+
Qnn_Tensor_t *p_tile0_out = nullptr;
316318
Qnn_Tensor_t *p_tensor1 = nullptr;
317319
Qnn_Tensor_t *p_permute1_out = nullptr;
318320
Qnn_Tensor_t *p_reshape1_out = nullptr;
@@ -326,34 +328,34 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
326328
qnn_tensors_t &tensors = std::get<1>(graph_item);
327329
p_tensor0 = tensors[0];
328330
p_reshape0_out = tensors[1];
329-
p_tensor1 = tensors[2];
330-
p_permute1_out = tensors[3];
331-
p_reshape1_out = tensors[4];
332-
p_matmul_out = tensors[5];
333-
p_reshape2_out = tensors[6];
331+
p_tile0_out = tensors[2];
332+
p_tensor1 = tensors[3];
333+
p_permute1_out = tensors[4];
334+
p_reshape1_out = tensors[5];
335+
p_matmul_out = tensors[6];
336+
p_reshape2_out = tensors[7];
334337
} else {
335338
CHECK_QNN_API(error, qnn_raw_interface.graphCreate(instance->get_qnn_context_handle(),
336339
graph_name.c_str(), NULL, &graph_handle));
337340

338341
// Define dimensions
339-
uint32_t B0 = src0->ne[2] * src0->ne[3]; // src0 batch: 3 * 2 = 6
340-
uint32_t B1 = src1->ne[2] * src1->ne[3]; // src1 batch: 6 * 4 = 24
341-
uint32_t M = src0->ne[1]; // 16
342-
uint32_t K = src0->ne[0]; // 256
343-
uint32_t N = src1->ne[1]; // 1 (second case), 16 (first case)
344-
345-
// Validate K matches
346-
GGML_ASSERT(src0->ne[0] == src1->ne[0]); // K must match: 256 == 256
347-
// Output shape should match src1's batch dims
342+
uint32_t K = src0->ne[0]; // Inner dimension
343+
uint32_t M = src0->ne[1]; // Rows of src0
344+
uint32_t N = src1->ne[1]; // Columns of src1
345+
uint32_t B0 = src0->ne[2] * src0->ne[3]; // src0 batch
346+
uint32_t B1 = src1->ne[2] * src1->ne[3]; // src1 batch (drives output)
347+
348+
// Validate
349+
GGML_ASSERT(src0->ne[0] == src1->ne[0]); // K must match
348350
GGML_ASSERT(dst->ne[0] == N && dst->ne[1] == M && dst->ne[2] == src1->ne[2] && dst->ne[3] == src1->ne[3]);
349351

350-
// src0: [256, 16, 3, 2] -> QNN: [2, 3, 16, 256] (B, H, M, K)
352+
// src0: [K, M, H0, B0] -> QNN: [B0, H0, M, K]
351353
uint32_t src0_dims[] = {static_cast<uint32_t>(src0->ne[3]), static_cast<uint32_t>(src0->ne[2]), static_cast<uint32_t>(src0->ne[1]), static_cast<uint32_t>(src0->ne[0])};
352354
p_tensor0 = GQCGT(src0, "input0", QNN_TENSOR_TYPE_APP_WRITE, QNN_DATATYPE_FLOAT_32, 4,
353355
src0_dims, nullptr, 0);
354356
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tensor0));
355357

356-
// Reshape src0 to [6, 16, 256] for [B0, M, K]
358+
// Reshape src0 to [B0, M, K]
357359
uint32_t reshape0_out_dims[] = {B0, M, K};
358360
p_reshape0_out = GQCGT(nullptr, "reshape0_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
359361
reshape0_out_dims, nullptr, 0);
@@ -365,19 +367,37 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
365367
reshape0_inputs, 1, reshape0_outputs, 1);
366368
CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, reshape0_op));
367369

368-
// src1: [256, 1, 6, 4] -> QNN: [4, 6, 1, 256] (B, H, N, K)
370+
// Tile src0 to match B1: [B0, M, K] -> [B1, M, K]
371+
uint32_t tile0_out_dims[] = {B1, M, K};
372+
p_tile0_out = GQCGT(nullptr, "tile0_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
373+
tile0_out_dims, nullptr, 0);
374+
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tile0_out));
375+
uint32_t tile_multiples[] = {B1 / B0, 1, 1}; // e.g., 24/6 = 4, 6/6 = 1
376+
uint32_t tile_dims[] = {3};
377+
Qnn_Tensor_t *p_tile_multiples = GQCGT(nullptr, "tile_multiples", QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1,
378+
tile_dims, tile_multiples, sizeof(tile_multiples));
379+
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tile_multiples));
380+
Qnn_Param_t tile_params[] = {{QNN_PARAMTYPE_TENSOR, "multiples", .tensorParam = *p_tile_multiples}};
381+
Qnn_Tensor_t tile0_inputs[] = {*p_reshape0_out};
382+
Qnn_Tensor_t tile0_outputs[] = {*p_tile0_out};
383+
Qnn_OpConfig_t tile0_op = ggmlqnn_create_op_config("tile0", QNN_OP_PACKAGE_NAME_QTI_AISW,
384+
QNN_OP_TILE, tile_params, 1,
385+
tile0_inputs, 1, tile0_outputs, 1);
386+
CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, tile0_op));
387+
388+
// src1: [N, K, H1, B1] -> QNN: [B1, H1, N, K]
369389
uint32_t src1_dims[] = {static_cast<uint32_t>(src1->ne[3]), static_cast<uint32_t>(src1->ne[2]), static_cast<uint32_t>(src1->ne[1]), static_cast<uint32_t>(src1->ne[0])};
370390
p_tensor1 = GQCGT(src1, "input1", QNN_TENSOR_TYPE_APP_WRITE, QNN_DATATYPE_FLOAT_32, 4,
371391
src1_dims, nullptr, 0);
372392
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_tensor1));
373393

374-
// Permute src1 to [4, 6, 256, 1] to align K and N
375-
uint32_t perm_data[] = {0, 1, 3, 2}; // [B, H, N, K] -> [B, H, K, N]
394+
// Permute src1 to [B1, H1, K, N]
395+
uint32_t perm_data[] = {0, 1, 3, 2};
376396
uint32_t perm_dims[] = {4};
377-
Qnn_Tensor_t * p_perm = GQCGT(nullptr, "perm", QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1,
378-
perm_dims, perm_data, sizeof(perm_data));
397+
Qnn_Tensor_t *p_perm = GQCGT(nullptr, "perm", QNN_TENSOR_TYPE_STATIC, QNN_DATATYPE_UINT_32, 1,
398+
perm_dims, perm_data, sizeof(perm_data));
379399
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_perm));
380-
uint32_t permute1_out_dims[] = {static_cast<uint32_t>(src1->ne[3]), static_cast<uint32_t>(src1->ne[2]), static_cast<uint32_t>(src1->ne[0]), static_cast<uint32_t>(src1->ne[1])}; // [4, 6, 256, 1]
400+
uint32_t permute1_out_dims[] = {static_cast<uint32_t>(src1->ne[3]), static_cast<uint32_t>(src1->ne[2]), static_cast<uint32_t>(src1->ne[0]), static_cast<uint32_t>(src1->ne[1])};
381401
p_permute1_out = GQCGT(nullptr, "permute1_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 4,
382402
permute1_out_dims, nullptr, 0);
383403
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_permute1_out));
@@ -389,7 +409,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
389409
permute1_inputs, 1, permute1_outputs, 1);
390410
CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, permute1_op));
391411

392-
// Reshape src1 to [24, 256, 1] for [B1, K, N]
412+
// Reshape src1 to [B1, K, N]
393413
uint32_t reshape1_out_dims[] = {B1, K, N};
394414
p_reshape1_out = GQCGT(nullptr, "reshape1_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
395415
reshape1_out_dims, nullptr, 0);
@@ -401,23 +421,19 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
401421
reshape1_inputs, 1, reshape1_outputs, 1);
402422
CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, reshape1_op));
403423

404-
// MatMul: [6, 16, 256] x [24, 256, 1] -> Needs adjustment for broadcasting
405-
// Adjust src0 to match B1 by repeating or reshaping
406-
uint32_t matmul_out_dims[] = {B1, M, N}; // [24, 16, 1]
424+
// MatMul: [B1, M, K] x [B1, K, N] -> [B1, M, N]
425+
uint32_t matmul_out_dims[] = {B1, M, N};
407426
p_matmul_out = GQCGT(nullptr, "matmul_out", QNN_TENSOR_TYPE_NATIVE, QNN_DATATYPE_FLOAT_32, 3,
408427
matmul_out_dims, nullptr, 0);
409428
CHECK_QNN_API(error, qnn_raw_interface.tensorCreateGraphTensor(graph_handle, p_matmul_out));
410-
411-
// Note: QNN MatMul doesn't broadcast; we need to tile src0
412-
// For simplicity, assume dst shape drives execution; adjust src0 later if needed
413-
Qnn_Tensor_t matmul_inputs[] = {*p_reshape0_out, *p_reshape1_out};
429+
Qnn_Tensor_t matmul_inputs[] = {*p_tile0_out, *p_reshape1_out};
414430
Qnn_Tensor_t matmul_outputs[] = {*p_matmul_out};
415431
Qnn_OpConfig_t matmul_op = ggmlqnn_create_op_config("matmul", QNN_OP_PACKAGE_NAME_QTI_AISW,
416432
QNN_OP_MAT_MUL, nullptr, 0,
417433
matmul_inputs, 2, matmul_outputs, 1);
418434
CHECK_QNN_API(error, qnn_raw_interface.graphAddNode(graph_handle, matmul_op));
419435

420-
// Output: [1, 16, 6, 4] -> QNN: [4, 6, 16, 1]
436+
// Output: [N, M, H1, B1] -> QNN: [B1, H1, M, N]
421437
uint32_t reshape2_out_dims[] = {static_cast<uint32_t>(dst->ne[3]), static_cast<uint32_t>(dst->ne[2]), static_cast<uint32_t>(dst->ne[1]), static_cast<uint32_t>(dst->ne[0])};
422438
p_reshape2_out = GQCGT(dst, "output", QNN_TENSOR_TYPE_APP_READ, QNN_DATATYPE_FLOAT_32, 4,
423439
reshape2_out_dims, nullptr, 0);
@@ -433,7 +449,7 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
433449
CHECK_QNN_API(error, qnn_raw_interface.graphFinalize(graph_handle, NULL, NULL));
434450

435451
// Cache
436-
qnn_tensors_t ggml_op_mulmat_tensors = {p_tensor0, p_reshape0_out, p_tensor1, p_permute1_out, p_reshape1_out, p_matmul_out, p_reshape2_out};
452+
qnn_tensors_t ggml_op_mulmat_tensors = {p_tensor0, p_reshape0_out, p_tile0_out, p_tensor1, p_permute1_out, p_reshape1_out, p_matmul_out, p_reshape2_out};
437453
instance->_qnn_graph_map[graph_name] = std::make_tuple(graph_handle, ggml_op_mulmat_tensors);
438454
}
439455

@@ -455,7 +471,6 @@ static void ggml_qnn_mul_mat_4d(ggml_backend_qnn_context *ctx, ggml_tensor *op)
455471

456472
op_perf.info();
457473
}
458-
459474
/*
460475
* @brief performs matrix multiplication with FP32 & quantized weights and floating-point inputs
461476
* using the QNN backend. this function performs matrix multiplication of the input tensor

0 commit comments

Comments
 (0)