Skip to content

Commit 1c320c1

Browse files
committed
[tmva][sofie] Add Indentity operator
1 parent 5772991 commit 1c320c1

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#ifndef TMVA_SOFIE_ROPERATOR_IDENTITY
2+
#define TMVA_SOFIE_ROPERATOR_IDENTITY
3+
4+
#include "TMVA/SOFIE_common.hxx"
5+
#include "TMVA/ROperator.hxx"
6+
#include "TMVA/RModel.hxx"
7+
8+
#include <sstream>
9+
10+
namespace TMVA{
11+
namespace Experimental{
12+
namespace SOFIE{
13+
14+
template <typename T>
15+
class ROperator_Identity final : public ROperator
16+
{
17+
18+
private:
19+
20+
std::string fNX;
21+
std::string fNY;
22+
std::vector<size_t> fShape;
23+
24+
public:
25+
ROperator_Identity(){}
26+
ROperator_Identity(std::string nameX, std::string nameY):
27+
fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){}
28+
29+
std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
30+
return input;
31+
}
32+
33+
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
34+
auto ret = input; //suggest copy to compiler
35+
return ret;
36+
}
37+
38+
void Initialize(RModel& model){
39+
//input must be a graph input, or already initialized intermediate tensor
40+
if (model.CheckIfTensorAlreadyExist(fNX) == false){
41+
throw std::runtime_error("TMVA SOFIE Identity Op Input Tensor is not found in model");
42+
}
43+
fShape = model.GetTensorShape(fNX);
44+
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
45+
}
46+
47+
48+
std::string Generate(std::string OpName){
49+
OpName = "op_" + OpName;
50+
if (fShape.empty()) {
51+
throw std::runtime_error("TMVA SOFIE Transpose Identity called to Generate without being initialized first");
52+
}
53+
std::stringstream out;
54+
int length = 1;
55+
for(auto& i: fShape){
56+
length *= i;
57+
}
58+
out << "\n//------ IDENTITY\n";
59+
// just copy the tensor pointers
60+
out << SP << SP << "tensor_" << fNY << " = tensor_" << fNX << ";\n";
61+
return out.str();
62+
}
63+
64+
};
65+
66+
}//SOFIE
67+
}//Experimental
68+
}//TMVA
69+
70+
71+
#endif //TMVA_SOFIE_ROPERATOR_IDENTITY

0 commit comments

Comments
 (0)