Skip to content

Commit 35cabd3

Browse files
authored
[OpenCL]support layout for dims size 5 (#5778)
* support layout for dims size 5 test=develop * add unsupport cue test=develop
1 parent e448fff commit 35cabd3

File tree

1 file changed

+36
-6
lines changed

1 file changed

+36
-6
lines changed

lite/kernels/opencl/layout_image_compute.cc

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,18 @@ class LayoutComputeBufferChwToImageDefault
8080

8181
// out info
8282
std::vector<size_t> new_dims = {1, 1, 1, 1};
83-
for (int tidx = 0; tidx < x_dims.size(); ++tidx) {
84-
new_dims[4 - x_dims.size() + tidx] = x_dims[tidx];
83+
if (x_dims.size() == 5) {
84+
new_dims[4 - x_dims.size() + 1] = x_dims[0] * x_dims[1];
85+
for (int tidx = 2; tidx < x_dims.size(); ++tidx) {
86+
new_dims[4 - x_dims.size() + tidx] = x_dims[tidx];
87+
}
88+
} else if (x_dims.size() < 5) {
89+
for (int tidx = 0; tidx < x_dims.size(); ++tidx) {
90+
new_dims[4 - x_dims.size() + tidx] = x_dims[tidx];
91+
}
92+
} else {
93+
LOG(FATAL) << "unsupported layout tensor dims size, the dims size is:"
94+
<< x_dims.size();
8595
}
8696
const int out_C = new_dims[1];
8797
const int out_H = new_dims[2];
@@ -207,8 +217,18 @@ class LayoutComputeImageDefaultToBufferChw
207217
auto x_image_shape = InitImageDimInfoWith(x_dims);
208218

209219
std::vector<size_t> new_dims = {1, 1, 1, 1};
210-
for (int j = 0; j < x_dims.size(); ++j) {
211-
new_dims[4 - x_dims.size() + j] = x_dims[j];
220+
if (x_dims.size() == 5) {
221+
new_dims[4 - x_dims.size() + 1] = x_dims[0] * x_dims[1];
222+
for (int j = 2; j < x_dims.size(); ++j) {
223+
new_dims[4 - x_dims.size() + j] = x_dims[j];
224+
}
225+
} else if (x_dims.size() < 5) {
226+
for (int j = 0; j < x_dims.size(); ++j) {
227+
new_dims[4 - x_dims.size() + j] = x_dims[j];
228+
}
229+
} else {
230+
LOG(FATAL) << "unsupported layout tensor dims size, the dims size is: "
231+
<< x_dims.size();
212232
}
213233

214234
#ifdef LITE_WITH_LOG
@@ -322,8 +342,18 @@ class LayoutComputeBufferChwToImage2DNw
322342

323343
// out info
324344
std::vector<size_t> new_dims = {1, 1, 1, 1};
325-
for (int tidx = 0; tidx < x_dims.size(); ++tidx) {
326-
new_dims[4 - x_dims.size() + tidx] = x_dims[tidx];
345+
if (x_dims.size() == 5) {
346+
new_dims[4 - x_dims.size() + 1] = x_dims[0] * x_dims[1];
347+
for (int tidx = 2; tidx < x_dims.size(); ++tidx) {
348+
new_dims[4 - x_dims.size() + tidx] = x_dims[tidx];
349+
}
350+
} else if (x_dims.size() < 5) {
351+
for (int tidx = 0; tidx < x_dims.size(); ++tidx) {
352+
new_dims[4 - x_dims.size() + tidx] = x_dims[tidx];
353+
}
354+
} else {
355+
LOG(FATAL) << "unsupported layout tensor dims size, the dims size is:"
356+
<< x_dims.size();
327357
}
328358

329359
const int out_N = new_dims[0];

0 commit comments

Comments
 (0)