1+ #ifndef TMVA_SOFIE_ROPERATOR_Softmax
2+ #define TMVA_SOFIE_ROPERATOR_Softmax
3+
4+ #include " TMVA/SOFIE_common.hxx"
5+ #include " TMVA/ROperator.hxx"
6+ #include " TMVA/RModel.hxx"
7+
8+ #include < sstream>
9+
10+ namespace TMVA {
11+ namespace Experimental {
12+ namespace SOFIE {
13+
14+ template <typename T>
15+ class ROperator_Softmax final : public ROperator
16+ {
17+
18+ private:
19+
20+ std::string fNX ;
21+ std::string fNY ;
22+ std::vector<size_t > fShape ;
23+
24+ public:
25+ ROperator_Softmax (){}
26+ ROperator_Softmax (std::string nameX, std::string nameY):
27+ fNX (UTILITY::Clean_name(nameX)), fNY (UTILITY::Clean_name(nameY)){}
28+
29+ std::vector<ETensorType> TypeInference (std::vector<ETensorType> input){
30+ return input;
31+ }
32+
33+ std::vector<std::vector<size_t >> ShapeInference (std::vector<std::vector<size_t >> input){
34+ auto ret = input; // suggest copy to compiler
35+ return ret;
36+ }
37+
38+ void Initialize (RModel& model){
39+ if (model.CheckIfTensorAlreadyExist (fNX ) == false ){ // input must be a graph input, or already initialized intermediate tensor
40+ throw std::runtime_error (" TMVA SOFIE Softmax Op Input Tensor is not found in model" );
41+ }
42+ fShape = model.GetTensorShape (fNX );
43+ model.AddIntermediateTensor (fNY , model.GetTensorType (fNX ), fShape );
44+ }
45+
46+
47+ std::string Generate (std::string OpName){
48+ OpName = " op_" + OpName;
49+ if (fShape .empty ()){
50+ throw std::runtime_error (" TMVA SOFIE Transpose Softmax called to Generate without being initialized first" );
51+ }
52+ std::stringstream out;
53+ int length = 1 ;
54+ for (auto & i: fShape ){
55+ length *= i;
56+ }
57+ out << " \n //------ SOFTMAX\n " ;
58+ out << SP << " double sum = 0.0;\n " ;
59+ out << SP << " for (int id = 0; id < " << length << " ; id++){\n " ;
60+ out << SP << SP << " tensor_" << fNY << " [id] = std::exp( - tensor_" << fNX << " [id]);\n " ;
61+ out << SP << SP << " sum += tensor_" << fNY << " [id];\n " ;
62+ out << SP << " }\n " ;
63+ out << SP << " for (int id = 0; id < " << length << " ; id++){\n " ;
64+ out << SP << SP << " tensor_" << fNY << " [id] /= sum;\n " ;
65+ out << SP << " }\n " ;
66+ return out.str ();
67+ }
68+
69+ };
70+
71+ }// SOFIE
72+ }// Experimental
73+ }// TMVA
74+
75+
76+ #endif // TMVA_SOFIE_ROPERATOR_Softmax
0 commit comments