Skip to content

Commit 48f164f

Browse files
committed
Added 4 Binary Operators
1 parent 08725d3 commit 48f164f

File tree

6 files changed

+125
-123
lines changed

6 files changed

+125
-123
lines changed

tmva/sofie/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
1616
TMVA/OperatorList.hxx
1717
TMVA/RModel.hxx
1818
TMVA/ROperator.hxx
19-
TMVA/ROperator_Add.hxx
19+
TMVA/RModel_BasicBinaryOp.hxx
2020
TMVA/ROperator_BatchNormalization.hxx
2121
TMVA/ROperator_Conv.hxx
2222
TMVA/ROperator_Gemm.hxx

tmva/sofie/inc/TMVA/OperatorList.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "TMVA/ROperator_LSTM.hxx"
1010
#include "TMVA/ROperator_BatchNormalization.hxx"
1111
#include "TMVA/ROperator_Pool.hxx"
12-
#include "TMVA/ROperator_Add.hxx"
12+
#include "TMVA/RModel_BasicBinaryOp.hxx"
1313
#include "TMVA/ROperator_Reshape.hxx"
1414
#include "TMVA/ROperator_Slice.hxx"
1515
#include "TMVA/ROperator_GRU.hxx"
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#ifndef TMVA_SOFIE_RModel_BasicBinaryOp
2+
#define TMVA_SOFIE_RModel_BasicBinaryOp
3+
4+
#include "TMVA/SOFIE_common.hxx"
5+
#include "TMVA/ROperator.hxx"
6+
#include "TMVA/RModel.hxx"
7+
8+
#include <sstream>
9+
10+
namespace TMVA{
11+
namespace Experimental{
12+
namespace SOFIE{
13+
14+
enum EBasicBinaryOperator { Add, Sub, Mul, Div };
15+
template <typename T, EBasicBinaryOperator Op>
16+
struct BinaryOperatorTrait {
17+
const char * Name() { return"";}
18+
const char * Op() { return "";}
19+
};
20+
21+
class RModel_BasicBinaryOp final : public ROperator{
22+
private:
23+
24+
std::string fNX1;
25+
std::string fNX2;
26+
std::string fNY;
27+
std::vector<size_t> fShape;
28+
template <typename T>
29+
struct BinaryOperatorTrait<T,Add> {
30+
const char * Name() { return "Add";}
31+
const char * Op() { return "+";}
32+
};
33+
34+
template <typename T>
35+
struct BinaryOperatorTrait<T,Sub> {
36+
const char * Name() { return "Sub";}
37+
const char * Op() { return "-";}
38+
};
39+
40+
template <typename T>
41+
struct BinaryOperatorTrait<T,Mul> {
42+
const char * Name() { return "Mul";}
43+
const char * Op() { return "*";}
44+
};
45+
46+
template <typename T>
47+
struct BinaryOperatorTrait<T,Div> {
48+
const char * Name() { return "Div";}
49+
const char * Op() { return "/";}
50+
};
51+
template <typename T, EBasicBinaryOperator Op>
52+
BinaryOperatorTrait<T,Op> *s;
53+
54+
public:
55+
RModel_BasicBinaryOp(){}
56+
RModel_BasicBinaryOp(std::string nameX1, std::string nameX2, std::string nameY):
57+
fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
58+
59+
// type of output given input
60+
std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
61+
return input;
62+
}
63+
64+
// shape of output tensors given input tensors
65+
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
66+
// assume now inputs have same shape (no broadcasting)
67+
auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
68+
return ret;
69+
}
70+
71+
void Initialize(RModel& model){
72+
// input must be a graph input, or already initialized intermediate tensor
73+
if (model.CheckIfTensorAlreadyExist(fNX1) == false){
74+
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
75+
}
76+
if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
77+
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX2 + "is not found in model");
78+
}
79+
auto shapeX1 = model.GetTensorShape(fNX1);
80+
auto shapeX2 = model.GetTensorShape(fNX2);
81+
// assume same shape X1 and X2
82+
if (shapeX1 != shapeX2) {
83+
std::string msg = "TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
84+
ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
85+
throw std::runtime_error(msg);
86+
}
87+
fShape = shapeX1;
88+
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX1), fShape);
89+
}
90+
91+
92+
std::string Generate(std::string OpName){
93+
OpName = "op_" + OpName;
94+
95+
if (fShape.empty()) {
96+
throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first");
97+
}
98+
std::stringstream out;
99+
// int length = 1;
100+
// for(auto& i: fShape){
101+
// length *= i;
102+
// }
103+
size_t length = ConvertShapeToLength(fShape);
104+
out << "\n//------ " + s->Name()+"\n";
105+
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
106+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id]" + s->Op() + "tensor_" << fNX2 << "[id];\n";
107+
out << SP << "}\n";
108+
return out.str();
109+
}
110+
111+
};
112+
113+
}//SOFIE
114+
}//Experimental
115+
}//TMVA
116+
117+
118+
#endif //TMVA_SOFIE_RModel_BasicBinaryOp

tmva/sofie/inc/TMVA/ROperator_Add.hxx

Lines changed: 0 additions & 116 deletions
This file was deleted.

tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ std::unique_ptr<ROperator> make_ROperator_RNN(const onnx::NodeProto& nodeproto,
3535
std::unique_ptr<ROperator> make_ROperator_LSTM(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
3636
std::unique_ptr<ROperator> make_ROperator_BatchNormalization(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
3737
std::unique_ptr<ROperator> make_ROperator_Pool(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
38-
std::unique_ptr<ROperator> make_ROperator_Add(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
38+
std::unique_ptr<ROperator> make_RModel_BasicBinaryOp(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
3939
std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
4040
std::unique_ptr<ROperator> make_ROperator_Slice(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
4141
std::unique_ptr<ROperator> make_ROperator_GRU(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
@@ -57,7 +57,7 @@ const factoryMethodMap mapOptypeOperator = {
5757
{"AveragePool", &make_ROperator_Pool},
5858
{"GlobalAveragePool", &make_ROperator_Pool},
5959
{"MaxPool", &make_ROperator_Pool},
60-
{"Add", &make_ROperator_Add},
60+
{"RModel_BasicBinaryOp", &make_RModel_BasicBinaryOp},
6161
{"Reshape", &make_ROperator_Reshape},
6262
{"Flatten", &make_ROperator_Reshape},
6363
{"Slice", &make_ROperator_Slice},

tmva/sofie_parsers/src/RModelParser_ONNX.cxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ std::unique_ptr<ROperator> make_ROperator(size_t idx, const onnx::GraphProto& gr
2424
}
2525
}
2626

27-
std::unique_ptr<ROperator> make_ROperator_Add(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
27+
std::unique_ptr<ROperator> make_RModel_BasicBinaryOp(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
2828

2929
ETensorType input_type = ETensorType::UNDEFINED;
3030

@@ -37,7 +37,7 @@ std::unique_ptr<ROperator> make_ROperator_Add(const onnx::NodeProto& nodeproto,
3737
else
3838
assert(it->second == input_type);
3939
} else {
40-
throw std::runtime_error("TMVA::SOFIE ONNX Parser Add op has input tensor" + input_name + " but its type is not yet registered");
40+
throw std::runtime_error("TMVA::SOFIE ONNX Parser Binary op has input tensor" + input_name + " but its type is not yet registered");
4141
}
4242
}
4343

@@ -48,7 +48,7 @@ std::unique_ptr<ROperator> make_ROperator_Add(const onnx::NodeProto& nodeproto,
4848
op.reset(new ROperator_Add<float>(nodeproto.input(0), nodeproto.input(1), nodeproto.output(0)));
4949
break;
5050
default:
51-
throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Add does not yet support input type " + std::to_string(static_cast<int>(input_type)));
51+
throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " + std::to_string(static_cast<int>(input_type)));
5252
}
5353

5454
ETensorType output_type = (op->TypeInference({input_type}))[0];

0 commit comments

Comments
 (0)