Skip to content

Commit 71ea6b3

Browse files
committed
[tmva][sofie] Add Softmax operator
Use initial code from GSOC candidate student Neel Shah
1 parent 1c320c1 commit 71ea6b3

File tree

1 file changed

+76
-0
lines changed

1 file changed

+76
-0
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#ifndef TMVA_SOFIE_ROPERATOR_Softmax
2+
#define TMVA_SOFIE_ROPERATOR_Softmax
3+
4+
#include "TMVA/SOFIE_common.hxx"
5+
#include "TMVA/ROperator.hxx"
6+
#include "TMVA/RModel.hxx"
7+
8+
#include <sstream>
9+
10+
namespace TMVA{
11+
namespace Experimental{
12+
namespace SOFIE{
13+
14+
template <typename T>
15+
class ROperator_Softmax final : public ROperator
16+
{
17+
18+
private:
19+
20+
std::string fNX;
21+
std::string fNY;
22+
std::vector<size_t> fShape;
23+
24+
public:
25+
ROperator_Softmax(){}
26+
ROperator_Softmax(std::string nameX, std::string nameY):
27+
fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){}
28+
29+
std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
30+
return input;
31+
}
32+
33+
std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
34+
auto ret = input; //suggest copy to compiler
35+
return ret;
36+
}
37+
38+
void Initialize(RModel& model){
39+
if (model.CheckIfTensorAlreadyExist(fNX) == false){ //input must be a graph input, or already initialized intermediate tensor
40+
throw std::runtime_error("TMVA SOFIE Softmax Op Input Tensor is not found in model");
41+
}
42+
fShape = model.GetTensorShape(fNX);
43+
model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShape);
44+
}
45+
46+
47+
std::string Generate(std::string OpName){
48+
OpName = "op_" + OpName;
49+
if (fShape.empty()){
50+
throw std::runtime_error("TMVA SOFIE Transpose Softmax called to Generate without being initialized first");
51+
}
52+
std::stringstream out;
53+
int length = 1;
54+
for(auto& i: fShape){
55+
length *= i;
56+
}
57+
out << "\n//------ SOFTMAX\n";
58+
out << SP << "double sum = 0.0;\n";
59+
out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
60+
out << SP << SP << "tensor_" << fNY << "[id] = std::exp( - tensor_" << fNX << "[id]);\n";
61+
out << SP << SP << "sum += tensor_" << fNY << "[id];\n";
62+
out << SP << "}\n";
63+
out << SP << "for (int id = 0; id < " << length << " ; id++){\n";
64+
out << SP << SP << "tensor_" << fNY << "[id] /= sum;\n";
65+
out << SP << "}\n";
66+
return out.str();
67+
}
68+
69+
};
70+
71+
}//SOFIE
72+
}//Experimental
73+
}//TMVA
74+
75+
76+
#endif //TMVA_SOFIE_ROPERATOR_Softmax

0 commit comments

Comments
 (0)