@@ -55,8 +55,12 @@ Status GemmOpBuilder::ExplictOpCheck(const NodeUnit& node_unit) const {
     auto transB = node_helper.Get("transB", static_cast<int64_t>(0));
     auto M = (transB == 0) ? inputB_shape.at(1) : inputB_shape.at(0);
     if (inputC_shape.size() == 0 || (inputC_shape.size() == 1 && inputC_shape.at(0) != M) ||
-        (inputC_shape.size() == 2 && (inputC_shape.at(0) != 1 || inputC_shape.at(1) != M))) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [M].");
+        (inputC_shape.size() == 2 && inputC_shape.at(1) != M)) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only supports C with shape [N, M].");
+    }
+
+    if (inputC_shape.size() == 2 && node_unit.Inputs()[2].quant_param.has_value() && inputC_shape.at(0) != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only supports quantized C with shape [1, M].");
     }
   }

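The relaxed check above now admits a Gemm bias C of shape [M], [1, M], or [N, M] (any rank-2 shape whose last dimension is M), while a quantized rank-2 C must still be [1, M]. Below is a minimal standalone sketch of the same rules; the IsSupportedGemmBiasShape helper is invented for illustration and is not part of the change above.

#include <cstdint>
#include <vector>

// Hypothetical mirror of the bias-shape rules checked in ExplictOpCheck:
// C must be [M], [1, M], or [N, M]; a quantized rank-2 C must be [1, M].
bool IsSupportedGemmBiasShape(const std::vector<int64_t>& c_shape, int64_t M, bool c_is_quantized) {
  if (c_shape.size() == 1) {
    return c_shape[0] == M;                     // [M]
  }
  if (c_shape.size() == 2) {
    if (c_shape[1] != M) return false;          // last dimension must be M
    return !c_is_quantized || c_shape[0] == 1;  // quantized bias must be [1, M]
  }
  return false;                                 // rank 0 or rank > 2 is rejected
}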
@@ -133,7 +137,8 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
                                                             qnn_model_wrapper.IsGraphInput(node_input_name)));
     }
 
-    if (2 == input_i && 2 == input_shape.size()) {
+    // Reshape a [1, M] bias to [M].
+    if (2 == input_i && 2 == input_shape.size() && input_shape[0] == 1) {
       input_shape[0] = input_shape[1];
       input_shape.resize(1);
     }
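With rank-2 biases now reaching this point, the reshape to QNN FullyConnected's 1-D bias is applied only when the leading dimension is 1; an [N, M] bias is left as-is for the split path added below. A small illustrative version of that guard follows; the function name is invented for this sketch.

#include <cstdint>
#include <vector>

// Illustrative only: collapse a [1, M] bias shape to [M]; leave [N, M] untouched.
void CollapseUnitLeadingDim(std::vector<uint32_t>& shape) {
  if (shape.size() == 2 && shape[0] == 1) {
    shape[0] = shape[1];  // keep M
    shape.resize(1);      // drop the leading 1
  }
}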
@@ -199,8 +204,70 @@ Status GemmOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
                                                   std::vector<std::string>&& input_names,
                                                   const logging::Logger& logger,
                                                   bool do_op_validation) const {
-  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {},
-                                     logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
+  // FullyConnected doesn't support a 2D bias with shape [N, M]. In this case, decompose Gemm into FullyConnected + Add for compatibility.
+  bool split_gemm = false;
+  if (node_unit.Inputs().size() == 3) {
+    auto& input_c = node_unit.Inputs()[2];
+    std::vector<uint32_t> input_c_shape;
+    QnnModelWrapper::GetOnnxShape(input_c.node_arg, input_c_shape);
+
+    // Split when input_c has a 2D shape that is not [1, M].
+    split_gemm = (input_c_shape.size() == 2 && input_c_shape.at(0) != 1);
+  }
+
+  if (split_gemm) {
+    // If split_gemm, the input and output of Gemm must be at least 2D.
+    const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name();
+    TensorInfo input_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));
+    TensorInfo output_info = {};
+    ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));
+    std::vector<uint32_t> output_shape = output_info.shape;
+    QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy();
+
+    const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name);
+
+    // Create the FullyConnected node.
+    std::vector<std::string> gemm_input_0_1;
+    gemm_input_0_1.push_back(input_names[0]);
+    gemm_input_0_1.push_back(input_names[1]);
+    std::string split_fully_connected_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected";
+    std::string split_fully_connected_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected_output";
+    QnnTensorWrapper fully_connected_output(split_fully_connected_output_name, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type,
+                                            QnnQuantParamsWrapper(), std::vector<uint32_t>(output_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(fully_connected_output)),
+                      "Failed to add FullyConnected output tensor.");
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_fully_connected_name,
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      QNN_OP_FULLY_CONNECTED,
+                                                      std::move(gemm_input_0_1),
+                                                      {split_fully_connected_output_name},
+                                                      {},
+                                                      do_op_validation),
+                      "Failed to add FullyConnected node.");
+
+    // Create the Add node.
+    Qnn_TensorType_t op_output_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
+    std::string split_add_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_add";
+    QnnTensorWrapper op_output_tensor_wrapper(org_output_name, op_output_tensor_type, output_info.qnn_data_type,
+                                              op_output_quant_param.Copy(), std::vector<uint32_t>(output_shape));
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)),
+                      "Failed to add ElementWiseAdd output tensor.");
+    std::string bias_name = input_names[2];
+
+    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_add_name,
+                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
+                                                      QNN_OP_ELEMENT_WISE_ADD,
+                                                      {split_fully_connected_output_name, bias_name},  // FullyConnected output as input
+                                                      {org_output_name},                               // Original output as output
+                                                      {},
+                                                      do_op_validation),
+                      "Failed to add ElementWiseAdd node.");
+  } else {
+    ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {},
+                                       logger, do_op_validation, GetQnnOpType(node_unit.OpType())));
+  }
+
   return Status::OK();
 }
 
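The split relies on the identity Gemm(A, B, C) = ElementWiseAdd(FullyConnected(A, B), C), assuming alpha = beta = 1; transA/transB handling and quantization are omitted here. Below is a minimal numeric sketch of that identity in plain C++ (no QNN types; the MatMul helper, shapes, and values are made up for illustration).

#include <cstddef>
#include <vector>

// Y[N, M] = A[N, K] * B[K, M] -- the bias-free matrix product the FullyConnected node computes.
std::vector<float> MatMul(const std::vector<float>& A, const std::vector<float>& B,
                          std::size_t N, std::size_t K, std::size_t M) {
  std::vector<float> Y(N * M, 0.0f);
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t k = 0; k < K; ++k)
      for (std::size_t m = 0; m < M; ++m)
        Y[n * M + m] += A[n * K + k] * B[k * M + m];
  return Y;
}

int main() {
  const std::size_t N = 2, K = 3, M = 2;
  std::vector<float> A = {1, 2, 3, 4, 5, 6};  // [2, 3]
  std::vector<float> B = {1, 0, 0, 1, 1, 1};  // [3, 2]
  std::vector<float> C = {10, 20, 30, 40};    // [2, 2] bias -- the [N, M] case that triggers the split

  // Matrix product first, then an element-wise add of the 2-D bias,
  // mirroring the FullyConnected + ElementWiseAdd decomposition.
  std::vector<float> Y = MatMul(A, B, N, K, M);
  for (std::size_t i = 0; i < Y.size(); ++i) Y[i] += C[i];
  // Y == {14, 25, 40, 51}, identical to Gemm(A, B, C) with alpha = beta = 1.
  return 0;
}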