1- #ifndef TMVA_SOFIE_RModel_BasicBinaryOp
2- #define TMVA_SOFIE_RModel_BasicBinaryOp
1+ #ifndef TMVA_SOFIE_ROperator_BasicBinary
2+ #define TMVA_SOFIE_ROperator_BasicBinary
33
44#include " TMVA/SOFIE_common.hxx"
55#include " TMVA/ROperator.hxx"
@@ -13,51 +13,53 @@ namespace SOFIE{
1313
1414enum EBasicBinaryOperator { Add, Sub, Mul, Div };
1515
16- class RModel_BasicBinaryOp final : public ROperator{
16+ template <typename T, EBasicBinaryOperator Op1>
17+ struct BinaryOperatorTrait {
18+ const char *Name () { return " " ; }
19+ const char *Op () { return " " ; }
20+ };
21+ template <typename T>
22+ struct BinaryOperatorTrait <T, Add> {
23+ static const char *Name () { return " Add" ; }
24+ static const char *Op () { return " +" ; }
25+ };
26+
27+ template <typename T>
28+ struct BinaryOperatorTrait <T, Sub> {
29+ static const char *Name () { return " Sub" ; }
30+ static const char *Op () { return " -" ; }
31+ };
32+
33+ template <typename T>
34+ struct BinaryOperatorTrait <T, Mul> {
35+ static const char *Name () { return " Mul" ; }
36+ static const char *Op () { return " *" ; }
37+ };
38+
39+ template <typename T>
40+ struct BinaryOperatorTrait <T, Div> {
41+ static const char *Name () { return " Div" ; }
42+ static const char *Op () { return " /" ; }
43+ };
44+
45+ template <typename T, EBasicBinaryOperator Op>
46+ class ROperator_BasicBinary final : public ROperator{
1747private:
1848
1949 std::string fNX1 ;
2050 std::string fNX2 ;
2151 std::string fNY ;
2252 std::vector<size_t > fShape ;
23- template <typename T, EBasicBinaryOperator Op1>
24- struct BinaryOperatorTrait {
25- const char * Name () { return " " ;}
26- const char * Op () { return " " ;}
27- };
28- template <typename T>
29- struct BinaryOperatorTrait <T,Add> {
30- const char * Name () { return " Add" ;}
31- const char * Op () { return " +" ;}
32- };
33-
34- template <typename T>
35- struct BinaryOperatorTrait <T,Sub> {
36- const char * Name () { return " Sub" ;}
37- const char * Op () { return " -" ;}
38- };
39-
40- template <typename T>
41- struct BinaryOperatorTrait <T,Mul> {
42- const char * Name () { return " Mul" ;}
43- const char * Op () { return " *" ;}
44- };
45-
46- template <typename T>
47- struct BinaryOperatorTrait <T,Div> {
48- const char * Name () { return " Div" ;}
49- const char * Op () { return " /" ;}
50- };
51-
52- template <typename T, EBasicBinaryOperator Op1>
53- BinaryOperatorTrait<T,Op1> *s;
53+
54+ // template <typename T, EBasicBinaryOperator Op1>
55+ // BinaryOperatorTrait<T,Op1> *s;
5456
5557public:
56- RModel_BasicBinaryOp (){}
57- RModel_BasicBinaryOp (std::string nameX1, std::string nameX2, std::string nameY):
58+ ROperator_BasicBinary (){}
59+ ROperator_BasicBinary (std::string nameX1, std::string nameX2, std::string nameY):
5860 fNX1 (UTILITY::Clean_name(nameX1)), fNX2 (UTILITY::Clean_name(nameX2)), fNY (UTILITY::Clean_name(nameY)){}
5961
60- // type of output given input
62+ // type of output given input
6163 std::vector<ETensorType> TypeInference (std::vector<ETensorType> input){
6264 return input;
6365 }
@@ -79,7 +81,7 @@ public:
7981 }
8082 auto shapeX1 = model.GetTensorShape (fNX1 );
8183 auto shapeX2 = model.GetTensorShape (fNX2 );
82- // assume same shape X1 and X2
84+ // assume same shape X1 and X2
8385 if (shapeX1 != shapeX2) {
8486 std::string msg = " TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
8587 ConvertShapeToString (shapeX1) + " shape 2 is " + ConvertShapeToString (shapeX2);
@@ -102,9 +104,10 @@ public:
102104 // length *= i;
103105 // }
104106 size_t length = ConvertShapeToLength (fShape );
105- out << " \n //------ " + s-> Name ()+" \n " ;
107+ out << " \n //------ " + std::string (BinaryOperatorTrait<T,Op>:: Name () )+" \n " ;
106108 out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
107- out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id]" + s->Op () + " tensor_" << fNX2 << " [id];\n " ;
109+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id]" +
110+ std::string (BinaryOperatorTrait<T,Op>::Op ()) + " tensor_" << fNX2 << " [id];\n " ;
108111 out << SP << " }\n " ;
109112 return out.str ();
110113 }
@@ -116,4 +119,4 @@ public:
116119}// TMVA
117120
118121
119- #endif // TMVA_SOFIE_RModel_BasicBinaryOp
122+ #endif // TMVA_SOFIE_ROperator_BasicBinary
0 commit comments