1+ #ifndef TMVA_SOFIE_ROPERATOR_ADD
2+ #define TMVA_SOFIE_ROPERATOR_ADD
3+
4+ #include " TMVA/SOFIE_common.hxx"
5+ #include " TMVA/ROperator.hxx"
6+ #include " TMVA/RModel.hxx"
7+
8+ #include < sstream>
9+
10+ namespace TMVA {
11+ namespace Experimental {
12+ namespace SOFIE {
13+
14+ template <typename T>
15+ class ROperator_Add final : public ROperator
16+ {
17+
18+ private:
19+
20+ std::string fNX1 ;
21+ std::string fNX2 ;
22+ std::string fNY ;
23+ std::vector<size_t > fShape ;
24+
25+ public:
26+ ROperator_Add (){}
27+ ROperator_Add (std::string nameX1, std::string nameX2, std::string nameY):
28+ fNX1 (UTILITY::Clean_name(nameX1)), fNX2 (UTILITY::Clean_name(nameX2)), fNY (UTILITY::Clean_name(nameY)){}
29+
30+ // type of output given input
31+ std::vector<ETensorType> TypeInference (std::vector<ETensorType> input){
32+ return input;
33+ }
34+
35+ // shape of output tensors given input tensors
36+ std::vector<std::vector<size_t >> ShapeInference (std::vector<std::vector<size_t >> input){
37+ // assume now inputs have same shape (no broadcasting)
38+ auto ret = std::vector<std::vector<size_t >>(1 , input[0 ]); // return vector size 1 with first input
39+ return ret;
40+ }
41+
42+ void Initialize (RModel& model){
43+ // input must be a graph input, or already initialized intermediate tensor
44+ if (model.CheckIfTensorAlreadyExist (fNX1 ) == false ){
45+ throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX1 + " is not found in model" );
46+ }
47+ if (model.CheckIfTensorAlreadyExist (fNX2 ) == false ) {
48+ throw std::runtime_error (std::string (" TMVA SOFIE Add Op Input Tensor " ) + fNX2 + " is not found in model" );
49+ }
50+ auto shapeX1 = model.GetTensorShape (fNX1 );
51+ auto shapeX2 = model.GetTensorShape (fNX2 );
52+ // assume same shape X1 and X2
53+ if (shapeX1 != shapeX2) {
54+ std::string msg = " TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
55+ ConvertShapeToString (shapeX1) + " and shape 2 is " + ConvertShapeToString (shapeX2);
56+ throw std::runtime_error (msg);
57+ }
58+ fShape = shapeX1;
59+ model.AddIntermediateTensor (fNY , model.GetTensorType (fNX1 ), fShape );
60+ }
61+
62+
63+ std::string Generate (std::string OpName){
64+ OpName = " op_" + OpName;
65+ if (fShape .empty ()) {
66+ throw std::runtime_error (" TMVA SOFIE Add called to Generate without being initialized first" );
67+ }
68+ std::stringstream out;
69+ // int length = 1;
70+ // for(auto& i: fShape){
71+ // length *= i;
72+ // }
73+ size_t length = ConvertShapeToLength (fShape );
74+ out << " \n //------ Add\n " ;
75+ out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
76+ out << SP << SP << " tensor_" << fNY << " [id] = tensor_" << fNX1 << " [id] + tensor_" << fNX2 << " [id];\n " ;
77+ out << SP << " }\n " ;
78+ return out.str ();
79+ }
80+
81+ };
82+
83+ }// SOFIE
84+ }// Experimental
85+ }// TMVA
86+
87+
88+ #endif // TMVA_SOFIE_ROPERATOR_Add
0 commit comments