@@ -80,8 +80,18 @@ class LayoutComputeBufferChwToImageDefault
80
80
81
81
// out info
82
82
std::vector<size_t > new_dims = {1 , 1 , 1 , 1 };
83
- for (int tidx = 0 ; tidx < x_dims.size (); ++tidx) {
84
- new_dims[4 - x_dims.size () + tidx] = x_dims[tidx];
83
+ if (x_dims.size () == 5 ) {
84
+ new_dims[4 - x_dims.size () + 1 ] = x_dims[0 ] * x_dims[1 ];
85
+ for (int tidx = 2 ; tidx < x_dims.size (); ++tidx) {
86
+ new_dims[4 - x_dims.size () + tidx] = x_dims[tidx];
87
+ }
88
+ } else if (x_dims.size () < 5 ) {
89
+ for (int tidx = 0 ; tidx < x_dims.size (); ++tidx) {
90
+ new_dims[4 - x_dims.size () + tidx] = x_dims[tidx];
91
+ }
92
+ } else {
93
+ LOG (FATAL) << " unsupported layout tensor dims size, the dims size is:"
94
+ << x_dims.size ();
85
95
}
86
96
const int out_C = new_dims[1 ];
87
97
const int out_H = new_dims[2 ];
@@ -207,8 +217,18 @@ class LayoutComputeImageDefaultToBufferChw
207
217
auto x_image_shape = InitImageDimInfoWith (x_dims);
208
218
209
219
std::vector<size_t > new_dims = {1 , 1 , 1 , 1 };
210
- for (int j = 0 ; j < x_dims.size (); ++j) {
211
- new_dims[4 - x_dims.size () + j] = x_dims[j];
220
+ if (x_dims.size () == 5 ) {
221
+ new_dims[4 - x_dims.size () + 1 ] = x_dims[0 ] * x_dims[1 ];
222
+ for (int j = 2 ; j < x_dims.size (); ++j) {
223
+ new_dims[4 - x_dims.size () + j] = x_dims[j];
224
+ }
225
+ } else if (x_dims.size () < 5 ) {
226
+ for (int j = 0 ; j < x_dims.size (); ++j) {
227
+ new_dims[4 - x_dims.size () + j] = x_dims[j];
228
+ }
229
+ } else {
230
+ LOG (FATAL) << " unsupported layout tensor dims size, the dims size is: "
231
+ << x_dims.size ();
212
232
}
213
233
214
234
#ifdef LITE_WITH_LOG
@@ -322,8 +342,18 @@ class LayoutComputeBufferChwToImage2DNw
322
342
323
343
// out info
324
344
std::vector<size_t > new_dims = {1 , 1 , 1 , 1 };
325
- for (int tidx = 0 ; tidx < x_dims.size (); ++tidx) {
326
- new_dims[4 - x_dims.size () + tidx] = x_dims[tidx];
345
+ if (x_dims.size () == 5 ) {
346
+ new_dims[4 - x_dims.size () + 1 ] = x_dims[0 ] * x_dims[1 ];
347
+ for (int tidx = 2 ; tidx < x_dims.size (); ++tidx) {
348
+ new_dims[4 - x_dims.size () + tidx] = x_dims[tidx];
349
+ }
350
+ } else if (x_dims.size () < 5 ) {
351
+ for (int tidx = 0 ; tidx < x_dims.size (); ++tidx) {
352
+ new_dims[4 - x_dims.size () + tidx] = x_dims[tidx];
353
+ }
354
+ } else {
355
+ LOG (FATAL) << " unsupported layout tensor dims size, the dims size is:"
356
+ << x_dims.size ();
327
357
}
328
358
329
359
const int out_N = new_dims[0 ];
0 commit comments