Skip to content

Commit 81ecf2d

Browse files
authored
[OpenCL] Add mali conv 1x1 opt: 3 tiling methods (#6062) (#6067)
1 parent 99fafb8 commit 81ecf2d

File tree

5 files changed

+703
-79
lines changed

5 files changed

+703
-79
lines changed

lite/backends/opencl/cl_image_converter.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,70 @@ void CLImageConverterNBlock::ImageToNCHW(void *image,
563563
const DDim &image_dim,
564564
const DDim &tensor_dim) {}
565565

566+
DDim CLImageConverterN2Block::InitImageDimInfoWith(const DDim &tensor_dim) {
567+
CHECK(tensor_dim.size() == 4) << " Tensor dim is not 4.";
568+
size_t N, C, H, W;
569+
N = tensor_dim[0];
570+
C = tensor_dim[1];
571+
H = tensor_dim[2];
572+
W = tensor_dim[3];
573+
size_t width = (C + 3) / 4 * 2 * 4;
574+
size_t height = ((N + 7) / 8) * H * W;
575+
return DDim(
576+
std::vector<DDim::value_type>({static_cast<DDim::value_type>(width),
577+
static_cast<DDim::value_type>(height)}));
578+
}
579+
580+
void CLImageConverterN2Block::NCHWToImage(float *nchw,
581+
void *image,
582+
const DDim &tensor_dim) {
583+
CHECK(tensor_dim.size() == 4) << " Tensor dim is not 4.";
584+
size_t N, C, H, W;
585+
N = tensor_dim[0];
586+
C = tensor_dim[1];
587+
H = tensor_dim[2];
588+
W = tensor_dim[3];
589+
590+
DDim in_image_dim = InitImageDimInfoWith(tensor_dim);
591+
592+
VLOG(3) << " tensor dim: " << tensor_dim;
593+
VLOG(3) << " image dim: " << in_image_dim;
594+
595+
size_t height = in_image_dim[1];
596+
size_t n_block = height / (W * H);
597+
size_t c_block = (C + 3) / 4;
598+
599+
float *image_fp32 = static_cast<float *>(image);
600+
half_t *image_fp16 = static_cast<half_t *>(image);
601+
602+
float *p = nchw;
603+
size_t i0 = 0;
604+
for (size_t n = 0; n < n_block * 8; n++) {
605+
for (size_t c = 0; c < c_block * 4; c++) {
606+
for (size_t h = 0; h < H; h++) {
607+
for (size_t w = 0; w < W; w++) {
608+
size_t img_idx = ((n / 8) * W * H + h * W + w) * c_block * 4 * 8 +
609+
(c / 4) * 32 + ((n % 8) / 4) * 16 + (c % 4) * 4 +
610+
(n % 8) % 4;
611+
if (n < N && c < C) {
612+
fp16_support_ ? image_fp16[img_idx] = Float2Half(*p)
613+
: image_fp32[img_idx] = *p;
614+
p++;
615+
} else {
616+
fp16_support_ ? image_fp16[img_idx] = Float2Half(0.f)
617+
: image_fp32[img_idx] = 0.f;
618+
}
619+
}
620+
}
621+
}
622+
}
623+
}
624+
625+
void CLImageConverterN2Block::ImageToNCHW(void *image,
626+
float *tensor,
627+
const DDim &image_dim,
628+
const DDim &tensor_dim) {}
629+
566630
DDim CLImageConverterDWFilter::InitImageDimInfoWith(const DDim &tensor_dim) {
567631
CHECK(tensor_dim.size() == 4) << " Tensor dim is not 4.";
568632
size_t N, C, H, W;

lite/backends/opencl/cl_image_converter.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,16 @@ class CLImageConverterNBlock : public CLImageConverterBase {
143143
const DDim &tensor_dim) override;
144144
};
145145

146+
class CLImageConverterN2Block : public CLImageConverterBase {
147+
public:
148+
DDim InitImageDimInfoWith(const DDim &tensor_dim) override;
149+
void NCHWToImage(float *tensor, void *image, const DDim &tensor_dim) override;
150+
void ImageToNCHW(void *image,
151+
float *tensor,
152+
const DDim &image_dim,
153+
const DDim &tensor_dim) override;
154+
};
155+
146156
class CLImageConverterDWFilter : public CLImageConverterBase {
147157
public:
148158
DDim InitImageDimInfoWith(const DDim &tensor_dim) override;

0 commit comments

Comments
 (0)