Skip to content

Commit 5772991

Browse files
lmonetaNeel-Shah-29
andcommitted
Neel task4 (#15)
* Fusing MatMul and Add operater into Gemm operater * Update RModelParser_ONNX.cxx * Update RModelParser_ONNX.cxx * Update RModelParser_ONNX.hxx * Update RModelParser_ONNX.cxx * Update RModelParser_ONNX.cxx * Update RModelParser_ONNX.cxx Co-authored-by: Neel Shah <neelshah29042002.com> Co-authored-by: Neel-Shah-29 <[email protected]>
1 parent 2a00482 commit 5772991

File tree

2 files changed

+73
-0
lines changed

2 files changed

+73
-0
lines changed

tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ std::unique_ptr<ROperator> make_ROperator_LSTM(const onnx::NodeProto& nodeproto,
3535
std::unique_ptr<ROperator> make_ROperator_BatchNormalization(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
3636
std::unique_ptr<ROperator> make_ROperator_Pool(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
3737
std::unique_ptr<ROperator> make_ROperator_Add(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
38+
std::unique_ptr<ROperator> make_ROperator_GemmFromMatMulandAdd(const onnx::NodeProto& nodeproto1,const onnx::NodeProto& nodeproto2, const onnx::GraphProto& graphproto , std::unordered_map<std::string, ETensorType>& tensor_type);
3839
std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
3940
std::unique_ptr<ROperator> make_ROperator_Slice(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
4041
std::unique_ptr<ROperator> make_ROperator_GRU(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
@@ -64,6 +65,10 @@ const factoryMethodMap mapOptypeOperator = {
6465
{"Flatten", &make_ROperator_Reshape}
6566
};
6667

68+
using factoryMethodMap1 = std::unordered_map<std::string, std::unique_ptr<ROperator> (*)(const onnx::NodeProto&,const onnx::NodeProto&, const onnx::GraphProto&, std::unordered_map<std::string, ETensorType>&)>;
69+
const factoryMethodMap1 mapOptypeOperator1 = {
70+
{"MatMul", &make_ROperator_GemmFromMatMulandAdd}
71+
};
6772
std::unique_ptr<ROperator> make_ROperator(size_t idx, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
6873
}//INTERNAL
6974

tmva/sofie_parsers/src/RModelParser_ONNX.cxx

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ namespace INTERNAL{
1414
std::unique_ptr<ROperator> make_ROperator(size_t idx, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type){
1515
const auto& nodeproto = graphproto.node(idx);
1616
auto find = mapOptypeOperator.find(nodeproto.op_type());
17+
// operator_type = nodeproto.op_type();
18+
if(nodeproto.op_type()=="MatMul"){
19+
if(graphproto.node(idx+1).op_type()=="Add"){
20+
return make_ROperator_GemmFromMatMulandAdd(graphproto.node(idx),graphproto.node(idx+1),graphproto,tensor_type);
21+
}
22+
}
23+
1724
if (find == mapOptypeOperator.end()){
1825
throw std::runtime_error("TMVA::SOFIE - Operator type " + nodeproto.op_type() + " is not yet supported");
1926
// std::unique_ptr<ROperator> op;
@@ -195,6 +202,65 @@ std::unique_ptr<ROperator> make_ROperator_Sigmoid(const onnx::NodeProto& nodepro
195202
return op;
196203
}
197204

205+
206+
std::unique_ptr<ROperator> make_ROperator_GemmFromMatMulandAdd(const onnx::NodeProto& nodeproto1,const onnx::NodeProto& nodeproto2, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
207+
208+
ETensorType input_type = ETensorType::UNDEFINED;
209+
210+
for (int i = 0; i < 2; ++i) {
211+
auto input_name = nodeproto1.input(i);
212+
auto it = tensor_type.find(input_name);
213+
if (it != tensor_type.end()){
214+
// according to ONNX both inputs have same time
215+
if (i == 0) input_type = it->second;
216+
else
217+
assert(it->second == input_type);
218+
} else {
219+
throw std::runtime_error("TMVA::SOFIE ONNX Parser MatMul op has input tensor" + input_name + " but its type is not yet registered");
220+
}
221+
}
222+
223+
for (int i = 0; i < 2; ++i) {
224+
auto input_name = nodeproto2.input(i);
225+
auto it = tensor_type.find(input_name);
226+
if (it != tensor_type.end()){
227+
// according to ONNX both inputs have same time
228+
if (i == 0) input_type = it->second;
229+
else
230+
assert(it->second == input_type);
231+
} else {
232+
throw std::runtime_error("TMVA::SOFIE ONNX Parser Add op has input tensor" + input_name + " but its type is not yet registered");
233+
}
234+
}
235+
std::unique_ptr<ROperator> op;
236+
237+
238+
float attr_alpha =1.0;
239+
float attr_beta =1.0;
240+
int_t attr_transA =0;
241+
int_t attr_transB =0;
242+
243+
switch(input_type){
244+
case ETensorType::FLOAT:
245+
if (nodeproto1.input_size() == 2){
246+
op.reset(new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto1.input(0), nodeproto1.input(1), nodeproto2.output(0)));
247+
}else{
248+
op.reset(new ROperator_Gemm<float>(attr_alpha, attr_beta, attr_transA, attr_transB, nodeproto1.input(0), nodeproto1.input(1), nodeproto2.input(1), nodeproto2.output(0)));
249+
}
250+
break;
251+
default:
252+
throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator for fusing MatMul and Add to Gemm does not yet support input type " + std::to_string(static_cast<int>(input_type)));
253+
}
254+
255+
ETensorType output_type = (op->TypeInference({input_type}))[0];
256+
auto it2 = tensor_type.find(nodeproto2.output(0));
257+
if (it2 == tensor_type.end()){
258+
tensor_type[nodeproto2.output(0)] = output_type;
259+
}
260+
261+
return op;
262+
}
263+
198264
std::unique_ptr<ROperator> make_ROperator_Gemm(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /* graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
199265

200266
ETensorType input_type;
@@ -1017,6 +1083,8 @@ RModel RModelParser_ONNX::Parse(std::string filename){
10171083
rmodel.AddBlasRoutines({"Copy", "Axpy"});
10181084
} else if (op_type == "GRU") {
10191085
rmodel.AddBlasRoutines({"Gemm", "Axpy"});
1086+
} else if (op_type == "Add" && graph.node(i-1).op_type() == "MatMul" ) {
1087+
rmodel.AddBlasRoutines({"Gemm", "Gemv"});
10201088
}
10211089
}
10221090

0 commit comments

Comments
 (0)