Commit e57dc2a

[QNN EP] Lower Gemm with 2d bias to FC + ElementwiseAdd when targeting HTP. (microsoft#25605)
### Description

Lower Gemm with a 2D bias to FullyConnected + ElementwiseAdd when targeting HTP.

### Motivation and Context

This change allows a Gemm with a 2D bias to stay on HTP instead of falling back to CPU.

---------

Signed-off-by: Mu-Chein Hsu <[email protected]>
1 parent 7b2f667 commit e57dc2a
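The equivalence behind this lowering is elementary and can be checked with a few lines of standalone C++ (illustration only, not code from this commit; `MatMul` and the sample data are invented for the example, matching the {2, 3} x {3, 4} + {2, 4} case used in the updated tests):

```cpp
// Illustration only: a Gemm with a 2D bias C of shape [N, M] equals a
// bias-free matmul (QNN FullyConnected) followed by an elementwise add,
// which is exactly the pair of QNN ops this change emits.
#include <cassert>
#include <cstdio>
#include <vector>

// Naive row-major matmul: a is [n, k], b is [k, m], result is [n, m].
static std::vector<float> MatMul(const std::vector<float>& a,
                                 const std::vector<float>& b,
                                 int n, int k, int m) {
  std::vector<float> y(n * m, 0.0f);
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < m; ++j)
      for (int p = 0; p < k; ++p)
        y[i * m + j] += a[i * k + p] * b[p * m + j];
  return y;
}

int main() {
  const int n = 2, k = 3, m = 4;
  const std::vector<float> a = {1, 2, 3, 4, 5, 6};        // A: [2, 3]
  const std::vector<float> b(k * m, 0.5f);                // B: [3, 4]
  const std::vector<float> c = {0, 1, 2, 3, 4, 5, 6, 7};  // C: [2, 4], a 2D bias

  // Fused Gemm: Y = A*B + C, with the bias added in the matmul epilogue.
  std::vector<float> gemm = MatMul(a, b, n, k, m);
  for (int i = 0; i < n * m; ++i) gemm[i] += c[i];

  // Lowered form: FullyConnected (no bias), then a separate ElementwiseAdd pass.
  std::vector<float> lowered = MatMul(a, b, n, k, m);  // FullyConnected output
  for (int i = 0; i < n * m; ++i) lowered[i] += c[i];  // ElementwiseAdd

  for (int i = 0; i < n * m; ++i) assert(gemm[i] == lowered[i]);
  std::printf("Gemm(A, B, C) == ElementwiseAdd(FullyConnected(A, B), C)\n");
  return 0;
}
```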

File tree

2 files changed (+84, -14 lines)

onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc

Lines changed: 72 additions & 5 deletions
```diff
@@ -55,8 +55,12 @@ Status GemmOpBuilder::ExplictOpCheck(const NodeUnit& node_unit) const {
   auto transB = node_helper.Get("transB", static_cast<int64_t>(0));
   auto M = (transB == 0) ? inputB_shape.at(1) : inputB_shape.at(0);
   if (inputC_shape.size() == 0 || (inputC_shape.size() == 1 && inputC_shape.at(0) != M) ||
-      (inputC_shape.size() == 2 && (inputC_shape.at(0) != 1 || inputC_shape.at(1) != M))) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [M].");
+      (inputC_shape.size() == 2 && inputC_shape.at(1) != M)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [N, M].");
+  }
+
+  if (inputC_shape.size() == 2 && node_unit.Inputs()[2].quant_param.has_value() && inputC_shape.at(0) != 1) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support quantized C with shape [1, M].");
   }
 }
```
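The relaxed rule can be paraphrased as a standalone predicate: a float bias C may now be [M], [1, M], or a general [N, M], while a quantized C is still limited to [M] or [1, M]. `IsGemmBiasSupported` is a hypothetical helper written for illustration, not code from the commit:

```cpp
// Hypothetical paraphrase of the relaxed ExplictOpCheck bias rule (not in the commit).
#include <cstdint>
#include <vector>

bool IsGemmBiasSupported(const std::vector<uint32_t>& c_shape, uint32_t M,
                         bool c_is_quantized) {
  if (c_shape.size() == 1) return c_shape[0] == M;  // [M]
  if (c_shape.size() == 2) {
    if (c_shape[1] != M) return false;              // trailing dim must be M
    return !c_is_quantized || c_shape[0] == 1;      // quantized C: only [1, M]
  }
  return false;                                     // scalar or higher-rank C rejected
}

// IsGemmBiasSupported({4}, 4, false)    -> true   ([M], as before)
// IsGemmBiasSupported({1, 4}, 4, false) -> true   ([1, M], squeezed to [M])
// IsGemmBiasSupported({2, 4}, 4, false) -> true   ([N, M], newly lowered to FC + Add)
// IsGemmBiasSupported({2, 4}, 4, true)  -> false  (quantized 2D bias still rejected)
```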

```diff
@@ -133,7 +137,8 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                                      qnn_model_wrapper.IsGraphInput(node_input_name)));
   }
 
-  if (2 == input_i && 2 == input_shape.size()) {
+  // Reshape [1, M] shape Bias.
+  if (2 == input_i && 2 == input_shape.size() && input_shape[0] == 1) {
     input_shape[0] = input_shape[1];
     input_shape.resize(1);
   }
```
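The guarded squeeze applies only to a [1, M] bias; a general [N, M] bias now keeps its 2D shape so it can feed the ElementwiseAdd path added below. A minimal trace of the reshape:

```cpp
// Sketch of the bias squeeze in ProcessInputs: only a [1, M] bias is
// flattened to rank-1 [M]; a general [N, M] bias is left untouched.
std::vector<uint32_t> input_shape = {1, 4};  // ONNX bias C with shape [1, M]
if (input_shape.size() == 2 && input_shape[0] == 1) {
  input_shape[0] = input_shape[1];  // {4, 4}
  input_shape.resize(1);            // {4}, i.e. [M]
}
```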
```diff
@@ -199,8 +204,70 @@ Status GemmOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
                                                       std::vector<std::string>&& input_names,
                                                       const logging::Logger& logger,
                                                       bool do_op_validation) const {
-  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {},
-                                     logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
+  // FullyConnected doesn't support a 2D bias with shape [N, M]. In this case, decompose Gemm into FullyConnected + Add for compatibility.
+  bool split_gemm = false;
+  if (node_unit.Inputs().size() == 3) {
+    auto& input_c = node_unit.Inputs()[2];
+    std::vector<uint32_t> input_c_shape;
+    QnnModelWrapper::GetOnnxShape(input_c.node_arg, input_c_shape);
+
+    // Split when input_c has a 2D shape that is not [1, M].
+    split_gemm = (input_c_shape.size() == 2 && input_c_shape.at(0) != 1);
+  }
+
+  if (split_gemm) {
+    // If split_gemm, the input and output of Gemm must be at least 2D.
+    const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name();
+    TensorInfo input_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));
+    TensorInfo output_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
+    std::vector<uint32_t> output_shape = output_info.shape;
+    QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy();
+
+    const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name);
+
+    // Create FullyConnected Node
+    std::vector<std::string> gemm_input_0_1;
+    gemm_input_0_1.push_back(input_names[0]);
+    gemm_input_0_1.push_back(input_names[1]);
+    std::string split_fully_connected_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected";
+    std::string split_fully_connected_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected_output";
+    QnnTensorWrapper fully_connected_output(split_fully_connected_output_name, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type,
+                                            QnnQuantParamsWrapper(), std::vector<uint32_t>(output_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(fully_connected_output)),
+                      "Failed to add FullyConnected output tensor.");
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_fully_connected_name,
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      QNN_OP_FULLY_CONNECTED,
+                                                      std::move(gemm_input_0_1),
+                                                      {split_fully_connected_output_name},
+                                                      {},
+                                                      do_op_validation),
+                      "Failed to add FullyConnected node.");
+
+    // Create Add Node
+    Qnn_TensorType_t op_output_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
+    std::string split_add_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_add";
+    QnnTensorWrapper op_output_tensor_wrapper(org_output_name, op_output_tensor_type, output_info.qnn_data_type,
+                                              op_output_quant_param.Copy(), std::vector<uint32_t>(output_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)),
+                      "Failed to add ElementWiseAdd output tensor.");
+    std::string bias_name = input_names[2];
+
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_add_name,
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      QNN_OP_ELEMENT_WISE_ADD,
+                                                      {split_fully_connected_output_name, bias_name},  // FullyConnected output as input
+                                                      {org_output_name},  // Original output as output
+                                                      {},
+                                                      do_op_validation),
+                      "Failed to add ElementWiseAdd node.");
+  } else {
+    ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {},
+                                       logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
+  }
 
   return Status::OK();
 }
```

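Putting the hunk together: for a Gemm node whose bias C has shape [N, M] with N != 1, the builder now emits two QNN ops instead of one. A sketch of the resulting subgraph (`gemm0` is a hypothetical node name for illustration; the `_split_*` names follow the scheme in the code above):

```
Gemm(A, B, C[N, M]) -> Y

becomes

FullyConnected(A, B)                                  -> gemm0_split_FullyConnected_output
ElementwiseAdd(gemm0_split_FullyConnected_output, C)  -> Y
```

The intermediate tensor keeps the Gemm output shape but carries no quantization parameters, while the Add output reuses the original name, data type, and quantization parameters, so graph outputs and downstream consumers are unaffected.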
onnxruntime/test/providers/qnn/gemm_op_test.cc

Lines changed: 12 additions & 9 deletions
```diff
@@ -54,18 +54,17 @@ TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) {
                      ExpectedEPNodeAssignment::None);  // Should not be assigned to QNN EP.
 }
 
-// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1).
-// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector`
-TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) {
+// Test Gemm with 2D bias is supported.
+TEST_F(QnnCPUBackendTests, Gemm_2D_Bias) {
   std::vector<float> input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6);
   std::vector<float> input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12);
 
-  // 2D matrix mul with bias not supported.
+  // 2D matrix mul with bias is supported.
   RunGemmTest<float>({TestInputDef<float>({2, 3}, false, input_a_data),
                       TestInputDef<float>({3, 4}, false, input_b_data),
                       TestInputDef<float>({2, 4}, false, -1.0f, 1.0f)},
                      {},
-                     ExpectedEPNodeAssignment::None);  // Should not be assigned to QNN EP.
+                     ExpectedEPNodeAssignment::All);  // Assigned to QNN EP.
 
   // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`.
   RunGemmTest<float>({TestInputDef<float>({2, 3}, false, input_a_data),
```
```diff
@@ -525,15 +524,18 @@ TEST_F(QnnGPUBackendTests, Gemm_AlphaBetaUnsupported) {
                      "gpu");
 }
 
-// Gemm with matrix bias ie 2D (M, N) is NOT supported. (Note: vector bias is supported ie when M == 1).
+// Gemm with matrix bias, i.e. 2D (M, N), is supported.
+// When the bias is a vector, i.e. M == 1,
 // QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector`
-TEST_F(QnnGPUBackendTests, Gemm_2DBiasUnsupported) {
-  // 2D matrix mul with 2D bias not supported.
+// When the bias is 2D, i.e. M != 1 and N != 1,
+// QNN's Gemm will be split into FullyConnected and ElementwiseAdd.
+TEST_F(QnnGPUBackendTests, Gemm_2D_Bias) {
+  // 2D matrix mul with 2D bias is supported when Gemm is not a QDQ node.
   RunGemmTest<float>({TestInputDef<float>({2, 3}, false, -10.0f, 10.0f),
                       TestInputDef<float>({3, 4}, false, -10.0f, 10.0f),
                       TestInputDef<float>({2, 4}, false, -1.0f, 1.0f)},
                      {},
-                     ExpectedEPNodeAssignment::None,  // Should not be assigned to QNN EP.
+                     ExpectedEPNodeAssignment::All,  // Should be assigned to QNN EP.
                      "gpu");
 }
```
