Skip to content

Commit 4f71488

Browse files
committed
Added the modified ROperator_Add.hxx to resolve merge conflicts.
1 parent 44c20d2 commit 4f71488

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#ifndef TMVA_SOFIE_ROPERATOR_ADD
2+
#define TMVA_SOFIE_ROPERATOR_ADD
3+
4+
#include "TMVA/SOFIE_common.hxx"
5+
#include "TMVA/ROperator.hxx"
6+
#include "TMVA/RModel.hxx"
7+
8+
#include <sstream>
9+
10+
namespace TMVA{
11+
namespace Experimental{
12+
namespace SOFIE{
13+
14+
template <typename T>
15+
class ROperator_Add final : public ROperator
16+
{
17+
18+
private:
19+
20+
std::string fNX1;
21+
std::string fNX2;
22+
std::string fNY;
23+
std::vector<size_t> fShape;
24+
25+
public:
26+
ROperator_Add(){}
27+
ROperator_Add(std::string nameX1, std::string nameX2, std::string nameY):
28+
fNX1(UTILITY::Clean_name(nameX1)), fNX2(UTILITY::Clean_name(nameX2)), fNY(UTILITY::Clean_name(nameY)){}
29+
30+
// type of output given input
31+
std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
32+
return input;
33+
}
34+
35+
// shape of output tensors given input tensors
36+
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
37+
// assume now inputs have same shape (no broadcasting)
38+
auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
39+
return ret;
40+
}
41+
42+
void Initialize(RModel& model){
43+
// input must be a graph input, or already initialized intermediate tensor
44+
if (model.CheckIfTensorAlreadyExist(fNX1) == false){
45+
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX1 + "is not found in model");
46+
}
47+
if (model.CheckIfTensorAlreadyExist(fNX2) == false) {
48+
throw std::runtime_error(std::string("TMVA SOFIE Add Op Input Tensor ") + fNX2 + "is not found in model");
49+
}
50+
auto shapeX1 = model.GetTensorShape(fNX1);
51+
auto shapeX2 = model.GetTensorShape(fNX2);
52+
// assume same shape X1 and X2
53+
if (shapeX1 != shapeX2) {
54+
std::string msg = "TMVA SOFIE Add Op: Support only inputs with same shape, shape 1 is " +
55+
ConvertShapeToString(shapeX1) + "and shape 2 is " + ConvertShapeToString(shapeX2);
56+
throw std::runtime_error(msg);
57+
}
58+
fShape = shapeX1;
59+
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX1), fShape);
60+
}
61+
62+
63+
std::string Generate(std::string OpName){
64+
OpName = "op_" + OpName;
65+
if (fShape.empty()) {
66+
throw std::runtime_error("TMVA SOFIE Add called to Generate without being initialized first");
67+
}
68+
std::stringstream out;
69+
// int length = 1;
70+
// for(auto& i: fShape){
71+
// length *= i;
72+
// }
73+
size_t length = ConvertShapeToLength(fShape);
74+
out << "\n//------ Add\n";
75+
out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n";
76+
out << SP << SP << "tensor_" << fNY << "[id] = tensor_" << fNX1 << "[id] + tensor_" << fNX2 << "[id];\n";
77+
out << SP << "}\n";
78+
return out.str();
79+
}
80+
81+
};
82+
83+
}//SOFIE
84+
}//Experimental
85+
}//TMVA
86+
87+
88+
#endif //TMVA_SOFIE_ROPERATOR_Add

0 commit comments

Comments
 (0)