@@ -44,9 +44,12 @@ public:
4444 std::vector<std::vector<size_t >> ShapeInference (std::vector<std::vector<size_t >> input){
4545 if (input.size () > 1 ) throw std::runtime_error (" TMVA SOFIE Tranpose Op Shape Inference only need 1 input tensor" );
4646 auto & data = input[0 ];
47+ if (fAttrPerm .size () != data.size () )
48+ throw std::runtime_error (" TMVA SOFIE Tranpose Op - Invalid axes attributes" );
49+
4750 std::vector<size_t > output_shape (fAttrPerm .size ());
4851 for (size_t i = 0 ; i < fAttrPerm .size (); i++){
49- output_shape[fAttrPerm [i]] = data[i ];
52+ output_shape[i] = data[fAttrPerm [i] ];
5053 }
5154 std::vector<std::vector<size_t >> ret;
5255 ret.push_back (output_shape);
@@ -60,18 +63,14 @@ public:
6063 }
6164 fShapeData = model.GetTensorShape (fNData );
6265 if (fAttrPerm .empty ()){
66+ fAttrPerm .reserve (fShapeData .size ());
6367 for (int i = fShapeData .size () - 1 ; i >= 0 ; i--){
6468 fAttrPerm .push_back (i);
6569 }
6670 }
67-
68- std::vector<size_t > output_shape (fAttrPerm .size ());
69- for (size_t i = 0 ; i < fAttrPerm .size (); i++){
70- output_shape[fAttrPerm [i]] = fShapeData [i];
71- }
72-
73- model.AddIntermediateTensor (fNOutput , model.GetTensorType (fNData ), output_shape);
74- fShapeOutput = output_shape;
71+ std::vector<std::vector<size_t >> inputs = { fShapeData };
72+ fShapeOutput = ShapeInference (inputs).front ();
73+ model.AddIntermediateTensor (fNOutput , model.GetTensorType (fNData ), fShapeOutput );
7574 }
7675
7776 std::string Generate (std::string OpName){
@@ -80,32 +79,48 @@ public:
8079 throw std::runtime_error (" TMVA SOFIE Transpose Op called to Generate without being initialized first" );
8180 }
8281 int dim = fShapeData .size ();
83- int length=1 ;
84- std::vector<int > sizeofindex (dim);
85- for (int i = dim - 1 ; i>=0 ; i--){
86- sizeofindex[i] = length;
87- length *= fShapeData [i];
88- }
89- std::vector<int > index_goto (dim);
90- for (int i = 0 ; i < dim; i++){
91- index_goto[fAttrPerm [i]] = i;
92- }
93- std::vector<int > new_sizeofindex (dim);
94- int t = 1 ;
95- for (int i = dim - 1 ; i>=0 ; i--){
96- new_sizeofindex[i] = t;
97- t *= fShapeOutput [i];
98- }
82+ auto inStrides = UTILITY::ComputeStrideFromShape (fShapeData );
83+ auto outStrides = UTILITY::ComputeStrideFromShape (fShapeOutput );
84+ size_t length = inStrides[0 ]*fShapeData [0 ]; // total tensor size
85+ assert (length == outStrides[0 ]*fShapeOutput [0 ]);
9986
10087 std::stringstream out;
88+ // Implement transpose operator using consecutive read inputs.
89+ // But
90+ // tensorOut[id] = tensorInput[ inStrides[0]*i0 + inStrides[1]*i1 + inStrides[2]*i2 + ...]
91+ // now if (j0,j1,j2) are the output indices
92+ // j0 = id / outStrides[0]
93+ // j1 = (id % outStrides[0])/outStrides[1]
94+ // j2 = (id % outStrides[1])/outStrides[2]
95+ // ......
96+ // and we have j_k = i_fAttrPerm[k]
97+ // since we are using consecutive writes we should find the inverse of fAttrPerm
10198 out << SP << " ///------- Transpose operator\n " << std::endl;
102- out << SP << " for (int id = 0; id < " << length << " ; id++){\n " ;
103- out << SP << SP << " tensor_" << fNOutput << " [" ;
104- for (int i =0 ; i < dim; i++){
105- out << " id / " << sizeofindex[i] << " % " << fShapeData [i] << " * " << new_sizeofindex[index_goto[i]];
106- if (i != dim - 1 ) out << " + " ;
99+ out << SP << " for (size_t id = 0; id < " << length << " ; id++){\n " ;
100+ out << SP << SP << " tensor_" << fNOutput << " [id] = tensor_" << fNData << " [ " ;
101+ // compute output j indices
102+ std::vector<std::string> i_out (dim);
103+ for (int k =0 ; k < dim; k++){
104+ if (k == 0 )
105+ i_out[k] = " id" ;
106+ else
107+ i_out[k] = " (id % " + std::to_string (outStrides[k-1 ]) + " )" ;
108+ if (k < dim-1 )
109+ i_out[k] += " / " + std::to_string (outStrides[k]);
110+ }
111+ // use now them for input tensors
112+ // need to invert the fAttrPerm[k]
113+ for (int k =0 ; k < dim; k++){
114+ // find value in fAtrrPerm corresponding to k
115+ int l = std::find (fAttrPerm .begin (), fAttrPerm .end (), k) - fAttrPerm .begin ();
116+ assert (l > 0 && l < dim);
117+ out << " ( " << i_out[l] << " )" ;
118+ if (k < dim-1 ) {
119+ out << " * " << inStrides[k];
120+ out << " + " ;
121+ }
107122 }
108- out << " ] = " << " tensor_ " << fNData << " [id] ;\n " ;
123+ out << " ];\n " ;
109124 out << SP << " }\n " ;
110125 return out.str ();
111126 }
0 commit comments