Skip to content

Commit 08725d3

Browse files
authored
Add all the Basic Binary Operators
1 parent fdc99d6 commit 08725d3

File tree

1 file changed

+39
-11
lines changed

1 file changed

+39
-11
lines changed

tmva/sofie/inc/TMVA/ROperator_Add.hxx

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#ifndef TMVA_SOFIE_ROPERATOR_ADD
2-
#define TMVA_SOFIE_ROPERATOR_ADD
1+
#ifndef TMVA_SOFIE_ROPERATOR_BinaryOperator
2+
#define TMVA_SOFIE_ROPERATOR_BinaryOperator
33

44
#include "TMVA/SOFIE_common.hxx"
55
#include "TMVA/ROperator.hxx"
@@ -11,8 +11,9 @@ namespace TMVA{
1111
namespace Experimental{
1212
namespace SOFIE{
1313

14+
enum RModel_BasicBinaryOperator { Add, Sub, Mul, Div };
1415
template <typename T>
15-
class ROperator_Add final : public ROperator
16+
class ROperator_BinaryOperator final : public ROperator
1617
{
1718

1819
private:
@@ -23,8 +24,17 @@ private:
2324
std::vector<size_t> fShape;
2425

2526
public:
26-
ROperator_Add(){}
27-
ROperator_Add(std::string nameX1, std::string nameX2, std::string nameY):
27+
28+
std::string Name() {
29+
if (RModel_BasicBinaryOperator == Add) return "Add";
30+
else if (RModel_BasicBinaryOperator == Sub) return "Sub";
31+
else if (RModel_BasicBinaryOperator == Mul ) return "Mul";
32+
else if (RModel_BasicBinaryOperator == Div ) return "Div";
33+
return "Invalid";
34+
}
35+
36+
ROperator_BinaryOperator(){}
37+
ROperator_BinaryOperator(std::string nameX1, std::string nameX2, std::string nameY):
2838
fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
2939

3040
// type of output given input
@@ -42,16 +52,16 @@ public:
4252
void Initialize(RModel& model){
4353
// input must be a graph input, or already initialized intermediate tensor
4454
if (model.CheckIfTensorAlreadyExist(fNX1) == false){
45-
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
55+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX1 + "is not found in model");
4656
}
4757
if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
48-
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
58+
throw std::runtime_error(std::string("TMVA SOFIE Binary Op Input Tensor ") + fNX2 + "is not found in model");
4959
}
5060
auto shapeX1 = model.GetTensorShape(fNX1);
5161
auto shapeX2 = model.GetTensorShape(fNX2);
5262
// assume same shape X1 and X2
5363
if (shapeX1 != shapeX2) {
54-
std::string msg = "TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
64+
std::string msg = "TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
5565
ConvertShapeToString(shapeX1) + "shape 2 is " + ConvertShapeToString(shapeX2);
5666
throw std::runtime_error(msg);
5767
}
@@ -63,18 +73,36 @@ public:
6373
std::string Generate(std::string OpName){
6474
OpName = "op_" + OpName;
6575
if (fShape.empty()) {
66-
throw std::runtime_error("TMVA SOFIE Add called to Generate without being initialized first");
76+
throw std::runtime_error("TMVA SOFIE binary operator called to Generate without being initialized first");
6777
}
6878
std::stringstream out;
6979
// int length = 1;
7080
// for(auto& i: fShape){
7181
// length *= i;
7282
// }
83+
7384
size_t length = ConvertShapeToLength(fShape);
74-
out << "\n//------ Add\n";
85+
out << "\n//---- operator " << Name() << " " << OpName << "\n";
86+
if(RModel_BasicBinaryOperator == Add){
7587
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
7688
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] + tensor_" << fNX2 << "[id];\n";
7789
out << SP << "}\n";
90+
}
91+
if(RModel_BasicBinaryOperator == Sub){
92+
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
93+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] - tensor_" << fNX2 << "[id];\n";
94+
out << SP << "}\n";
95+
}
96+
if(RModel_BasicBinaryOperator == Mul){
97+
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
98+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] * tensor_" << fNX2 << "[id];\n";
99+
out << SP << "}\n";
100+
}
101+
if(RModel_BasicBinaryOperator == Div){
102+
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
103+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] / tensor_" << fNX2 << "[id];\n";
104+
out << SP << "}\n";
105+
}
78106
return out.str();
79107
}
80108

@@ -85,4 +113,4 @@ public:
85113
}//TMVA
86114

87115

88-
#endif //TMVA_SOFIE_ROPERATOR_Add
116+
#endif //TMVA_SOFIE_ROPERATOR_BinaryOperator

0 commit comments

Comments
 (0)