1- #ifndef TMVA_SOFIE_ROPERATOR_ADD
2- #define TMVA_SOFIE_ROPERATOR_ADD
1+ #ifndef TMVA_SOFIE_ROPERATOR_BinaryOperator
2+ #define TMVA_SOFIE_ROPERATOR_BinaryOperator
33
44#include " TMVA/SOFIE_common.hxx"
55#include " TMVA/ROperator.hxx"
@@ -11,8 +11,9 @@ namespace TMVA{
1111namespace Experimental {
1212namespace SOFIE {
1313
14+ enum RModel_BasicBinaryOperator { Add, Sub, Mul, Div };
1415template <typename T>
15- class ROperator_Add final : public ROperator
16+ class ROperator_BinaryOperator final : public ROperator
1617{
1718
1819private:
@@ -23,8 +24,17 @@ private:
2324 std::vector<size_t > fShape ;
2425
2526public:
26- ROperator_Add (){}
27- ROperator_Add (std::string nameX1, std::string nameX2, std::string nameY):
27+
28+ std::string Name () {
29+ if (RModel_BasicBinaryOperator == Add) return " Add" ;
30+ else if (RModel_BasicBinaryOperator == Sub) return " Sub" ;
31+ else if (RModel_BasicBinaryOperator == Mul ) return " Mul" ;
32+ else if (RModel_BasicBinaryOperator == Div ) return " Div" ;
33+ return " Invalid" ;
34+ }
35+
36+ ROperator_BinaryOperator (){}
37+ ROperator_BinaryOperator (std::string nameX1, std::string nameX2, std::string nameY):
2838 fNX1 (UTILITY::Clean_name(nameX1)), fNX2 (UTILITY::Clean_name(nameX2)), fNY (UTILITY::Clean_name(nameY)){}
2939
3040 // type of output given input
@@ -42,16 +52,16 @@ public:
4252 void Initialize (RModel& model){
4353 // input must be a graph input, or already initialized intermediate tensor
4454 if (model.CheckIfTensorAlreadyExist (fNX1 ) == false ){
45- throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX1 + " is not found in model" );
55+ throw std::runtime_error (std::string (" TMVA SOFIE Binary Op Input Tensor " ) + fNX1 + " is not found in model" );
4656 }
4757 if (model.CheckIfTensorAlreadyExist (fNX2 ) == false ) {
48- throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX1 + " is not found in model" );
58+ throw std::runtime_error (std::string (" TMVA SOFIE Binary Op Input Tensor " ) + fNX2 + " is not found in model" );
4959 }
5060 auto shapeX1 = model.GetTensorShape (fNX1 );
5161 auto shapeX2 = model.GetTensorShape (fNX2 );
5262 // assume same shape X1 and X2
5363 if (shapeX1 != shapeX2) {
54- std::string msg = " TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
64+ std::string msg = " TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
5565 ConvertShapeToString (shapeX1) + " shape 2 is " + ConvertShapeToString (shapeX2);
5666 throw std::runtime_error (msg);
5767 }
@@ -63,18 +73,36 @@ public:
6373 std::string Generate (std::string OpName){
6474 OpName = " op_" + OpName;
6575 if (fShape .empty ()) {
66- throw std::runtime_error (" TMVA SOFIE Add called to Generate without being initialized first" );
76+ throw std::runtime_error (" TMVA SOFIE binary operator called to Generate without being initialized first" );
6777 }
6878 std::stringstream out;
6979 // int length = 1;
7080 // for(auto& i: fShape){
7181 // length *= i;
7282 // }
83+
7384 size_t length = ConvertShapeToLength (fShape );
74- out << " \n //------ Add\n " ;
85+ out << " \n //---- operator " << Name () << " " << OpName << " \n " ;
86+ if (RModel_BasicBinaryOperator == Add){
7587 out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
7688 out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id] + tensor_" << fNX2 << " [id];\n " ;
7789 out << SP << " }\n " ;
90+ }
91+ if (RModel_BasicBinaryOperator == Sub){
92+ out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
93+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id] - tensor_" << fNX2 << " [id];\n " ;
94+ out << SP << " }\n " ;
95+ }
96+ if (RModel_BasicBinaryOperator == Mul){
97+ out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
98+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id] * tensor_" << fNX2 << " [id];\n " ;
99+ out << SP << " }\n " ;
100+ }
101+ if (RModel_BasicBinaryOperator == Div){
102+ out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
103+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id] / tensor_" << fNX2 << " [id];\n " ;
104+ out << SP << " }\n " ;
105+ }
78106 return out.str ();
79107 }
80108
@@ -85,4 +113,4 @@ public:
85113}// TMVA
86114
87115
88- #endif // TMVA_SOFIE_ROPERATOR_Add
116+ #endif // TMVA_SOFIE_ROPERATOR_BinaryOperator
0 commit comments