1- #ifndef TMVA_SOFIE_ROPERATOR_ADD
2- #define TMVA_SOFIE_ROPERATOR_ADD
1+ #ifndef TMVA_SOFIE_ROperator_BasicBinary
2+ #define TMVA_SOFIE_ROperator_BasicBinary
33
44#include " TMVA/SOFIE_common.hxx"
55#include " TMVA/ROperator.hxx"
@@ -11,23 +11,55 @@ namespace TMVA{
1111namespace Experimental {
1212namespace SOFIE {
1313
14+ enum EBasicBinaryOperator { Add, Sub, Mul, Div };
15+
16+ template <typename T, EBasicBinaryOperator Op1>
17+ struct BinaryOperatorTrait {
18+ const char *Name () { return " " ; }
19+ const char *Op () { return " " ; }
20+ };
21+ template <typename T>
22+ struct BinaryOperatorTrait <T, Add> {
23+ static const char *Name () { return " Add" ; }
24+ static const char *Op () { return " +" ; }
25+ };
26+
27+ template <typename T>
28+ struct BinaryOperatorTrait <T, Sub> {
29+ static const char *Name () { return " Sub" ; }
30+ static const char *Op () { return " -" ; }
31+ };
32+
1433template <typename T>
15- class ROperator_Add final : public ROperator
16- {
34+ struct BinaryOperatorTrait <T, Mul> {
35+ static const char *Name () { return " Mul" ; }
36+ static const char *Op () { return " *" ; }
37+ };
1738
39+ template <typename T>
40+ struct BinaryOperatorTrait <T, Div> {
41+ static const char *Name () { return " Div" ; }
42+ static const char *Op () { return " /" ; }
43+ };
44+
45+ template <typename T, EBasicBinaryOperator Op>
46+ class ROperator_BasicBinary final : public ROperator{
1847private:
1948
2049 std::string fNX1 ;
2150 std::string fNX2 ;
2251 std::string fNY ;
2352 std::vector<size_t > fShape ;
2453
54+ // template <typename T, EBasicBinaryOperator Op1>
55+ // BinaryOperatorTrait<T,Op1> *s;
56+
2557public:
26- ROperator_Add (){}
27- ROperator_Add (std::string nameX1, std::string nameX2, std::string nameY):
58+ ROperator_BasicBinary (){}
59+ ROperator_BasicBinary (std::string nameX1, std::string nameX2, std::string nameY):
2860 fNX1 (UTILITY::Clean_name(nameX1)), fNX2 (UTILITY::Clean_name(nameX2)), fNY (UTILITY::Clean_name(nameY)){}
2961
30- // type of output given input
62+ // type of output given input
3163 std::vector<ETensorType> TypeInference (std::vector<ETensorType> input){
3264 return input;
3365 }
@@ -42,16 +74,16 @@ public:
4274 void Initialize (RModel& model){
4375 // input must be a graph input, or already initialized intermediate tensor
4476 if (model.CheckIfTensorAlreadyExist (fNX1 ) == false ){
45- throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX1 + " is not found in model" );
77+ throw std::runtime_error (std::string (" TMVA SOFIE Binary Op Input Tensor " ) + fNX1 + " is not found in model" );
4678 }
4779 if (model.CheckIfTensorAlreadyExist (fNX2 ) == false ) {
48- throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX1 + " is not found in model" );
80+ throw std::runtime_error (std::string (" TMVA SOFIE Binary Op Input Tensor " ) + fNX2 + " is not found in model" );
4981 }
5082 auto shapeX1 = model.GetTensorShape (fNX1 );
5183 auto shapeX2 = model.GetTensorShape (fNX2 );
52- // assume same shape X1 and X2
84+ // assume same shape X1 and X2
5385 if (shapeX1 != shapeX2) {
54- std::string msg = " TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
86+ std::string msg = " TMVA SOFIE Binary Op: Support only inputs with same shape, shape 1 is " +
5587 ConvertShapeToString (shapeX1) + " shape 2 is " + ConvertShapeToString (shapeX2);
5688 throw std::runtime_error (msg);
5789 }
@@ -62,18 +94,20 @@ public:
6294
6395 std::string Generate (std::string OpName){
6496 OpName = " op_" + OpName;
97+
6598 if (fShape .empty ()) {
66- throw std::runtime_error (" TMVA SOFIE Add called to Generate without being initialized first" );
99+ throw std::runtime_error (" TMVA SOFIE Binary Op called to Generate without being initialized first" );
67100 }
68101 std::stringstream out;
69102 // int length = 1;
70103 // for(auto& i: fShape){
71104 // length *= i;
72105 // }
73106 size_t length = ConvertShapeToLength (fShape );
74- out << " \n //------ Add \n " ;
107+ out << " \n //------ " + std::string (BinaryOperatorTrait<T,Op>:: Name ())+ " \n " ;
75108 out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
76- out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id] + tensor_" << fNX2 << " [id];\n " ;
109+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id]" +
110+ std::string (BinaryOperatorTrait<T,Op>::Op ()) + " tensor_" << fNX2 << " [id];\n " ;
77111 out << SP << " }\n " ;
78112 return out.str ();
79113 }
@@ -85,4 +119,4 @@ public:
85119}// TMVA
86120
87121
88- #endif // TMVA_SOFIE_ROPERATOR_Add
122+ #endif // TMVA_SOFIE_ROperator_BasicBinary
0 commit comments