1+ #ifndef TMVA_SOFIE_RModel_BasicBinaryOp
2+ #define TMVA_SOFIE_RModel_BasicBinaryOp
3+
4+ #include " TMVA/SOFIE_common.hxx"
5+ #include " TMVA/ROperator.hxx"
6+ #include " TMVA/RModel.hxx"
7+
8+ #include < sstream>
9+
10+ namespace TMVA {
11+ namespace Experimental {
12+ namespace SOFIE {
13+
14+ enum EBasicBinaryOperator { Add, Sub, Mul, Div };
15+ template <typename T, EBasicBinaryOperator Op>
16+ struct BinaryOperatorTrait {
17+ const char * Name () { return " " ;}
18+ const char * Op () { return " " ;}
19+ };
20+
21+ class RModel_BasicBinaryOp final : public ROperator{
22+ private:
23+
24+ std::string fNX1 ;
25+ std::string fNX2 ;
26+ std::string fNY ;
27+ std::vector<size_t > fShape ;
28+ template <typename T>
29+ struct BinaryOperatorTrait <T,Add> {
30+ const char * Name () { return " Add" ;}
31+ const char * Op () { return " +" ;}
32+ };
33+
34+ template <typename T>
35+ struct BinaryOperatorTrait <T,Sub> {
36+ const char * Name () { return " Sub" ;}
37+ const char * Op () { return " -" ;}
38+ };
39+
40+ template <typename T>
41+ struct BinaryOperatorTrait <T,Mul> {
42+ const char * Name () { return " Mul" ;}
43+ const char * Op () { return " *" ;}
44+ };
45+
46+ template <typename T>
47+ struct BinaryOperatorTrait <T,Div> {
48+ const char * Name () { return " Div" ;}
49+ const char * Op () { return " /" ;}
50+ };
51+ template <typename T, EBasicBinaryOperator Op>
52+ BinaryOperatorTrait<T,Op> *s;
53+
54+ public:
55+ RModel_BasicBinaryOp (){}
56+ RModel_BasicBinaryOp (std::string nameX1, std::string nameX2, std::string nameY):
57+ fNX1 (UTILITY::Clean_name(nameX1)), fNX2 (UTILITY::Clean_name(nameX2)), fNY (UTILITY::Clean_name(nameY)){}
58+
59+ // type of output given input
60+ std::vector<ETensorType> TypeInference (std::vector<ETensorType> input){
61+ return input;
62+ }
63+
64+ // shape of output tensors given input tensors
65+ std::vector<std::vector<size_t >> ShapeInference (std::vector<std::vector<size_t >> input){
66+ // assume now inputs have same shape (no broadcasting)
67+ auto ret = std::vector<std::vector<size_t >>(1 , input[0 ]); // return vector size 1 with first input
68+ return ret;
69+ }
70+
71+ void Initialize (RModel& model){
72+ // input must be a graph input, or already initialized intermediate tensor
73+ if (model.CheckIfTensorAlreadyExist (fNX1 ) == false ){
74+ throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX1 + " is not found in model" );
75+ }
76+ if (model.CheckIfTensorAlreadyExist (fNX2 ) == false ) {
77+ throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX2 + " is not found in model" );
78+ }
79+ auto shapeX1 = model.GetTensorShape (fNX1 );
80+ auto shapeX2 = model.GetTensorShape (fNX2 );
81+ // assume same shape X1 and X2
82+ if (shapeX1 != shapeX2) {
83+ std::string msg = " TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
84+ ConvertShapeToString (shapeX1) + " shape 2 is " + ConvertShapeToString (shapeX2);
85+ throw std::runtime_error (msg);
86+ }
87+ fShape = shapeX1;
88+ model.AddIntermediateTensor (fNY , model.GetTensorType (fNX1 ), fShape );
89+ }
90+
91+
92+ std::string Generate (std::string OpName){
93+ OpName = " op_" + OpName;
94+
95+ if (fShape .empty ()) {
96+ throw std::runtime_error (" TMVA SOFIE Binary Op called to Generate without being initialized first" );
97+ }
98+ std::stringstream out;
99+ // int length = 1;
100+ // for(auto& i: fShape){
101+ // length *= i;
102+ // }
103+ size_t length = ConvertShapeToLength (fShape );
104+ out << " \n //------ " + s->Name ()+" \n " ;
105+ out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
106+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id]" + s->Op () + " tensor_" << fNX2 << " [id];\n " ;
107+ out << SP << " }\n " ;
108+ return out.str ();
109+ }
110+
111+ };
112+
113+ }// SOFIE
114+ }// Experimental
115+ }// TMVA
116+
117+
118+ #endif // TMVA_SOFIE_RModel_BasicBinaryOp
0 commit comments