Skip to content

Commit 7cb26bc

Browse files
committed
The required changes are made regarding the implementation of all Binary operators
1 parent 4954852 commit 7cb26bc

File tree

5 files changed

+57
-48
lines changed

5 files changed

+57
-48
lines changed

tmva/sofie/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie
1616
TMVA/OperatorList.hxx
1717
TMVA/RModel.hxx
1818
TMVA/ROperator.hxx
19-
TMVA/RModel_BasicBinaryOp.hxx
19+
TMVA/ROperator_BasicBinary.hxx
2020
TMVA/ROperator_BatchNormalization.hxx
2121
TMVA/ROperator_Conv.hxx
2222
TMVA/ROperator_Gemm.hxx

tmva/sofie/inc/TMVA/OperatorList.hxx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "TMVA/ROperator_LSTM.hxx"
1010
#include "TMVA/ROperator_BatchNormalization.hxx"
1111
#include "TMVA/ROperator_Pool.hxx"
12-
#include "TMVA/RModel_BasicBinaryOp.hxx"
12+
#include "TMVA/ROperator_BasicBinary.hxx"
1313
#include "TMVA/ROperator_Reshape.hxx"
1414
#include "TMVA/ROperator_Slice.hxx"
1515
#include "TMVA/ROperator_GRU.hxx"
Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef TMVA_SOFIE_RModel_BasicBinaryOp
2-
#define TMVA_SOFIE_RModel_BasicBinaryOp
1+
#ifndef TMVA_SOFIE_ROperator_BasicBinary
2+
#define TMVA_SOFIE_ROperator_BasicBinary
33

44
#include "TMVA/SOFIE_common.hxx"
55
#include "TMVA/ROperator.hxx"
@@ -13,51 +13,53 @@ namespace SOFIE{
1313

1414
enum EBasicBinaryOperator { Add, Sub, Mul, Div };
1515

16-
class RModel_BasicBinaryOp final : public ROperator{
16+
template <typename T, EBasicBinaryOperator Op1>
17+
struct BinaryOperatorTrait {
18+
const char *Name() { return ""; }
19+
const char *Op() { return ""; }
20+
};
21+
template <typename T>
22+
struct BinaryOperatorTrait<T, Add> {
23+
static const char *Name() { return "Add"; }
24+
static const char *Op() { return "+"; }
25+
};
26+
27+
template <typename T>
28+
struct BinaryOperatorTrait<T, Sub> {
29+
static const char *Name() { return "Sub"; }
30+
static const char *Op() { return "-"; }
31+
};
32+
33+
template <typename T>
34+
struct BinaryOperatorTrait<T, Mul> {
35+
static const char *Name() { return "Mul"; }
36+
static const char *Op() { return "*"; }
37+
};
38+
39+
template <typename T>
40+
struct BinaryOperatorTrait<T, Div> {
41+
static const char *Name() { return "Div"; }
42+
static const char *Op() { return "/"; }
43+
};
44+
45+
template<typename T, EBasicBinaryOperator Op>
46+
class ROperator_BasicBinary final : public ROperator{
1747
private:
1848

1949
std::string fNX1;
2050
std::string fNX2;
2151
std::string fNY;
2252
std::vector<size_t> fShape;
23-
template <typename T, EBasicBinaryOperator Op1>
24-
struct BinaryOperatorTrait {
25-
const char * Name() { return "";}
26-
const char * Op() { return "";}
27-
};
28-
template <typename T>
29-
struct BinaryOperatorTrait<T,Add> {
30-
const char * Name() { return "Add";}
31-
const char * Op() { return "+";}
32-
};
33-
34-
template <typename T>
35-
struct BinaryOperatorTrait<T,Sub> {
36-
const char * Name() { return "Sub";}
37-
const char * Op() { return "-";}
38-
};
39-
40-
template <typename T>
41-
struct BinaryOperatorTrait<T,Mul> {
42-
const char * Name() { return "Mul";}
43-
const char * Op() { return "*";}
44-
};
45-
46-
template <typename T>
47-
struct BinaryOperatorTrait<T,Div> {
48-
const char * Name() { return "Div";}
49-
const char * Op() { return "/";}
50-
};
51-
52-
template <typename T, EBasicBinaryOperator Op1>
53-
BinaryOperatorTrait<T,Op1> *s;
53+
54+
// template <typename T, EBasicBinaryOperator Op1>
55+
// BinaryOperatorTrait<T,Op1> *s;
5456

5557
public:
56-
RModel_BasicBinaryOp(){}
57-
RModel_BasicBinaryOp(std::string nameX1, std::string nameX2, std::string nameY):
58+
ROperator_BasicBinary(){}
59+
ROperator_BasicBinary(std::string nameX1, std::string nameX2, std::string nameY):
5860
fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
5961

60-
// type of output given input
62+
// type of output given input
6163
std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
6264
return input;
6365
}
@@ -79,7 +81,7 @@ public:
7981
}
8082
auto shapeX1 = model.GetTensorShape(fNX1);
8183
auto shapeX2 = model.GetTensorShape(fNX2);
82-
// assume same shape X1 and X2
84+
// assume same shape X1 and X2
8385
if (shapeX1 != shapeX2) {
8486
std::string msg = "TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
8587
ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
@@ -102,9 +104,10 @@ public:
102104
// length *= i;
103105
// }
104106
size_t length = ConvertShapeToLength(fShape);
105-
out << "\n//------ " + s->Name()+"\n";
107+
out << "\n//------ " + std::string(BinaryOperatorTrait<T,Op>::Name())+"\n";
106108
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
107-
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id]" + s->Op() + "tensor_" << fNX2 << "[id];\n";
109+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id]" +
110+
std::string(BinaryOperatorTrait<T,Op>::Op()) + "tensor_" << fNX2 << "[id];\n";
108111
out << SP << "}\n";
109112
return out.str();
110113
}
@@ -116,4 +119,4 @@ public:
116119
}//TMVA
117120

118121

119-
#endif //TMVA_SOFIE_RModel_BasicBinaryOp
122+
#endif //TMVA_SOFIE_ROperator_BasicBinary

tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace Experimental{
2323
namespace SOFIE{
2424

2525
namespace INTERNAL{
26+
// enum EBasicBinaryOperator { Add, Sub, Mul, Div };
2627

2728
std::unique_ptr<ROperator> make_ROperator_Transpose(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
2829
std::unique_ptr<ROperator> make_ROperator_Relu(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
@@ -35,7 +36,7 @@ std::unique_ptr<ROperator> make_ROperator_RNN(const onnx::NodeProto& nodeproto,
3536
std::unique_ptr<ROperator> make_ROperator_LSTM(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
3637
std::unique_ptr<ROperator> make_ROperator_BatchNormalization(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
3738
std::unique_ptr<ROperator> make_ROperator_Pool(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
38-
std::unique_ptr<ROperator> make_RModel_BasicBinaryOp(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
39+
template <EBasicBinaryOperator Op1>std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
3940
std::unique_ptr<ROperator> make_ROperator_Reshape(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
4041
std::unique_ptr<ROperator> make_ROperator_Slice(const onnx::NodeProto &nodeproto, const onnx::GraphProto &graphproto, std::unordered_map<std::string, ETensorType> &tensor_type);
4142
std::unique_ptr<ROperator> make_ROperator_GRU(const onnx::NodeProto& nodeproto, const onnx::GraphProto& graphproto, std::unordered_map<std::string, ETensorType>& tensor_type);
@@ -57,7 +58,10 @@ const factoryMethodMap mapOptypeOperator = {
5758
{"AveragePool", &make_ROperator_Pool},
5859
{"GlobalAveragePool", &make_ROperator_Pool},
5960
{"MaxPool", &make_ROperator_Pool},
60-
{"RModel_BasicBinaryOp", &make_RModel_BasicBinaryOp},
61+
{"Add", &make_ROperator_BasicBinary<Add>},
62+
{"Sub", &make_ROperator_BasicBinary<Sub>},
63+
{"Mul", &make_ROperator_BasicBinary<Mul>},
64+
{"Div", &make_ROperator_BasicBinary<Div>},
6165
{"Reshape", &make_ROperator_Reshape},
6266
{"Flatten", &make_ROperator_Reshape},
6367
{"Slice", &make_ROperator_Slice},

tmva/sofie_parsers/src/RModelParser_ONNX.cxx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ std::unique_ptr<ROperator> make_ROperator(size_t idx, const onnx::GraphProto& gr
2323
return (find->second)(nodeproto, graphproto, tensor_type);
2424
}
2525
}
26-
27-
std::unique_ptr<ROperator> make_RModel_BasicBinaryOp(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
26+
// enum EBasicBinaryOperator { Add, Sub, Mul, Div };
27+
template<EBasicBinaryOperator Op1>
28+
std::unique_ptr<ROperator> make_ROperator_BasicBinary(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto */, std::unordered_map<std::string, ETensorType>& tensor_type){
2829

2930
ETensorType input_type = ETensorType::UNDEFINED;
3031

@@ -45,7 +46,7 @@ std::unique_ptr<ROperator> make_RModel_BasicBinaryOp(const onnx::NodeProto& node
4546

4647
switch(input_type){
4748
case ETensorType::FLOAT:
48-
op.reset(new ROperator_Add<float>(nodeproto.input(0), nodeproto.input(1), nodeproto.output(0)));
49+
op.reset(new ROperator_BasicBinary<float,Op1>(nodeproto.input(0), nodeproto.input(1), nodeproto.output(0)));
4950
break;
5051
default:
5152
throw std::runtime_error("TMVA::SOFIE - Unsupported - Binary Operator does not yet support input type " + std::to_string(static_cast<int>(input_type)));
@@ -59,6 +60,7 @@ std::unique_ptr<ROperator> make_RModel_BasicBinaryOp(const onnx::NodeProto& node
5960

6061
return op;
6162
}
63+
6264
std::unique_ptr<ROperator> make_ROperator_Transpose(const onnx::NodeProto& nodeproto, const onnx::GraphProto& /*graphproto*/, std::unordered_map<std::string, ETensorType>& tensor_type){
6365

6466
ETensorType input_type;

0 commit comments

Comments
 (0)