@@ -14,6 +14,13 @@ namespace INTERNAL{
1414std::unique_ptr<ROperator> make_ROperator (size_t idx, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type){
1515 const auto & nodeproto = graphproto.node (idx);
1616 auto find = mapOptypeOperator.find (nodeproto.op_type ());
17+ // operator_type = nodeproto.op_type();
18+ if (nodeproto.op_type ()==" MatMul" ){
19+ if (graphproto.node (idx+1 ).op_type ()==" Add" ){
20+ return make_ROperator_GemmFromMatMulandAdd (graphproto.node (idx),graphproto.node (idx+1 ),graphproto,tensor_type);
21+ }
22+ }
23+
1724 if (find == mapOptypeOperator.end ()){
1825 throw std::runtime_error (" TMVA::SOFIE - Operator type " + nodeproto.op_type () + " is not yet supported" );
1926 // std::unique_ptr<ROperator> op;
@@ -195,6 +202,65 @@ std::unique_ptr<ROperator> make_ROperator_Sigmoid(const onnx::NodeProto& nodepro
195202 return op;
196203}
197204
205+
206+ std::unique_ptr<ROperator> make_ROperator_GemmFromMatMulandAdd (const onnx::NodeProto& nodeproto1,const onnx::NodeProto& nodeproto2, const onnx::GraphProto& /* graphproto */ , std::unordered_map<std::string, ETensorType>& tensor_type){
207+
208+ ETensorType input_type = ETensorType::UNDEFINED;
209+
210+ for (int i = 0 ; i < 2 ; ++i) {
211+ auto input_name = nodeproto1.input (i);
212+ auto it = tensor_type.find (input_name);
213+ if (it != tensor_type.end ()){
214+ // according to ONNX both inputs have same time
215+ if (i == 0 ) input_type = it->second ;
216+ else
217+ assert (it->second == input_type);
218+ } else {
219+ throw std::runtime_error (" TMVA::SOFIE ONNX Parser MatMul op has input tensor" + input_name + " but its type is not yet registered" );
220+ }
221+ }
222+
223+ for (int i = 0 ; i < 2 ; ++i) {
224+ auto input_name = nodeproto2.input (i);
225+ auto it = tensor_type.find (input_name);
226+ if (it != tensor_type.end ()){
227+ // according to ONNX both inputs have same time
228+ if (i == 0 ) input_type = it->second ;
229+ else
230+ assert (it->second == input_type);
231+ } else {
232+ throw std::runtime_error (" TMVA::SOFIE ONNX Parser Add op has input tensor" + input_name + " but its type is not yet registered" );
233+ }
234+ }
235+ std::unique_ptr<ROperator> op;
236+
237+
238+ float attr_alpha =1.0 ;
239+ float attr_beta =1.0 ;
240+ int_t attr_transA =0 ;
241+ int_t attr_transB =0 ;
242+
243+ switch (input_type){
244+ case ETensorType::FLOAT:
245+ if (nodeproto1.input_size () == 2 ){
246+ op.reset (new ROperator_Gemm<float >(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto1.input (0 ), nodeproto1.input (1 ), nodeproto2.output (0 )));
247+ }else {
248+ op.reset (new ROperator_Gemm<float >(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto1.input (0 ), nodeproto1.input (1 ), nodeproto2.input (1 ), nodeproto2.output (0 )));
249+ }
250+ break ;
251+ default :
252+ throw std::runtime_error (" TMVA::SOFIE - Unsupported - Operator for fusing MatMul and Add to Gemm does not yet support input type " + std::to_string (static_cast <int >(input_type)));
253+ }
254+
255+ ETensorType output_type = (op->TypeInference ({input_type}))[0 ];
256+ auto it2 = tensor_type.find (nodeproto2.output (0 ));
257+ if (it2 == tensor_type.end ()){
258+ tensor_type[nodeproto2.output (0 )] = output_type;
259+ }
260+
261+ return op;
262+ }
263+
198264std::unique_ptr<ROperator> make_ROperator_Gemm (const onnx::NodeProto& nodeproto, const onnx::GraphProto& /* graphproto */ , std::unordered_map<std::string, ETensorType>& tensor_type){
199265
200266 ETensorType input_type;
@@ -1017,6 +1083,8 @@ RModel RModelParser_ONNX::Parse(std::string filename){
10171083 rmodel.AddBlasRoutines ({" Copy" , " Axpy" });
10181084 } else if (op_type == " GRU" ) {
10191085 rmodel.AddBlasRoutines ({" Gemm" , " Axpy" });
1086+ } else if (op_type == " Add" && graph.node (i-1 ).op_type () == " MatMul" ) {
1087+ rmodel.AddBlasRoutines ({" Gemm" , " Gemv" });
10201088 }
10211089 }
10221090
0 commit comments