@@ -128,14 +128,14 @@ public:
128128 }
129129 }
130130
131- std::string Generate (std::string OpName ) override {
131+ std::string Generate (std::string opName ) override {
132132 if (fIsOutputConstant ) {
133133 // no code to generate here for constant output. Tensor output is defined in Session constructor
134134 return " //---------------------------------------\n " ;
135135 }
136- OpName = " op_" + OpName ;
136+ opName = " op_" + opName ;
137137 std::stringstream out;
138- out << " //--------- Gather operator \n " ;
138+ out << " //--------- Gather " << opName << " --> " << ConvertShapeToString ( fShapeY ) << " \n " ;
139139 // The rank (number of dimensions) of the output is q + r - 1
140140 size_t r = fShapeX .size ();
141141 // Indices of shape q
@@ -157,19 +157,17 @@ public:
157157 out << SP << " }\n " ;
158158 }
159159
160-
161160 // Fill the output Y[j_0, j_1, ..., j_{axis - 1}, i_0, i_1, ..., i_{q - 1}, j_{axis + 1}, ..., j_{r - 1}]
162161 // [0 ... axis) [axis ... axis + q) [axis + q ... q + r - 1)
163162 // iterate in [0 ... axis) [0 ... q) [axis ... r - 1)
164163 // for j_0, j_1, ..., j_{axis-1}
164+
165165 for (size_t j = 0 ; j < size_t (fAttrAxis ); j++) {
166166 std::string index = " j_" + std::to_string (j);
167167 for (size_t k = 0 ; k <= j; k++) out << SP;
168168 out << " for (size_t " << index << " = 0; " << index << " < " << fShapeY [j] << " ; " << index << " ++) {\n " ;
169169 }
170170 // for i_0, i_1, ..., i_{q - 1}
171- if (q == 0 && fAttrAxis == 0 )
172- out << SP << " {\n " ; // add a scope for local variables in case is not in a for loop
173171 for (size_t i = 0 ; i < q; i++) {
174172 std::string index = " i_" + std::to_string (i);
175173 for (size_t k = 0 ; k <= i + fAttrAxis ; k++) out << SP;
@@ -182,6 +180,10 @@ public:
182180 out << " for (size_t " << index << " = 0; " << index << " < " << fShapeY [q + j] << " ; " << index << " ++) {\n " ;
183181 }
184182
183+ // add a scope for local variables in case none of the loops above was generated
184+ if (fAttrAxis == 0 && q == 0 && r <= 1 )
185+ out << SP << " { // scalar case \n " ;
186+
185187 // output index
186188 for (size_t k = 0 ; k < q + r; k++) out << SP;
187189 out << " size_t y_index = " ;
@@ -200,6 +202,9 @@ public:
200202 out << " j_" << q+j;
201203 if (stridesY[q+j].dim != 1 ) out << " * " << stridesY[q+j];
202204 }
205+ // empty case
206+ if (fAttrAxis == 0 && q == 0 && r <= 1 )
207+ out << " 0" ;
203208 out << " ;\n " ;
204209
205210 // input Indices
@@ -210,7 +215,11 @@ public:
210215 out << " i_" << i;
211216 if (stridesIndices[i].dim != 1 ) out << " * " << stridesIndices[i];
212217 }
218+ // empty case
219+ if (q == 0 )
220+ out << " 0" ;
213221 out << " ;\n " ;
222+
214223 // K
215224 for (size_t k = 0 ; k < q + r; k++) out << SP;
216225 out << " size_t k = static_cast<size_t>(" << " tensor_" << fNIndices << " [i_index]" << " );\n " ;
@@ -239,6 +248,9 @@ public:
239248 for (size_t k = 0 ; k <j; k++) out << SP;
240249 out << " }\n " ;
241250 }
251+ // close empty scope if it was opened
252+ if (q == 0 && fAttrAxis == 0 && r <= 1 )
253+ out << SP << " } // close Gather scope for scalar case \n " ;
242254
243255
244256 return out.str ();
0 commit comments