Skip to content

Commit e448fff

Browse files
[Cherry-pick][OpenCL][kernel] Support pad_left != pad_right or pad_up != pad_down (#5763)
1 parent c5a3d15 commit e448fff

File tree

2 files changed

+12
-26
lines changed

2 files changed

+12
-26
lines changed

lite/backends/opencl/cl_kernel/image/pool_kernel.cl

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,18 @@ __kernel void pool_max(__read_only image2d_t input,
2424
__private const int ksize_w,
2525
__private const int stride_h,
2626
__private const int stride_w,
27-
__private const int pad_top,
28-
__private const int pad_left) {
27+
__private const int4 pad) {
2928
const int out_c = get_global_id(0);
3029
const int out_w = get_global_id(1);
3130
const int out_nh = get_global_id(2);
3231
const int out_n = out_nh / out_height;
3332
const int out_h = out_nh % out_height;
3433

35-
int start_h = out_h * stride_h - pad_top;
34+
int start_h = out_h * stride_h - (pad.x - pad.y);
3635
int end_h = min(start_h + ksize_h, in_height);
3736
start_h = max(start_h, 0);
3837

39-
int start_w = out_w * stride_w - pad_left;
38+
int start_w = out_w * stride_w - (pad.w - pad.z);
4039
int end_w = min(start_w + ksize_w, in_width);
4140
start_w = max(start_w, 0);
4241

@@ -65,19 +64,18 @@ __kernel void pool_avg(__read_only image2d_t input,
6564
__private const int ksize_w,
6665
__private const int stride_h,
6766
__private const int stride_w,
68-
__private const int pad_top,
69-
__private const int pad_left) {
67+
__private const int4 pad) {
7068
const int out_c = get_global_id(0);
7169
const int out_w = get_global_id(1);
7270
const int out_nh = get_global_id(2);
7371
const int out_n = out_nh / out_height;
7472
const int out_h = out_nh % out_height;
7573

76-
int start_h = out_h * stride_h - pad_top;
74+
int start_h = out_h * stride_h - pad.x;
7775
int end_h = min(start_h + ksize_h, in_height);
7876
start_h = max(start_h, 0);
7977

80-
int start_w = out_w * stride_w - pad_left;
78+
int start_w = out_w * stride_w - pad.z;
8179
int end_w = min(start_w + ksize_w, in_width);
8280
start_w = max(start_w, 0);
8381

@@ -96,7 +94,7 @@ __kernel void pool_avg(__read_only image2d_t input,
9694
div = (CL_DTYPE)((end_h - start_h)*(end_w - start_w));
9795
#else
9896
div = (CL_DTYPE)(ksize_w * ksize_h);
99-
#endif
97+
#endif
10098
CL_DTYPE4 avg = sum / div;
10199
const int pos_out_x = mad24(out_c, out_width, out_w);
102100
WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(pos_out_x, out_nh), avg);
@@ -112,8 +110,7 @@ __kernel void pool_avg_global(__read_only image2d_t input,
112110
__private const int ksize_w,
113111
__private const int stride_h,
114112
__private const int stride_w,
115-
__private const int pad_top,
116-
__private const int pad_left) {
113+
__private const int4 pad) {
117114
const int out_c = get_global_id(0);
118115
const int out_w = get_global_id(1); // =1
119116
const int out_nh = get_global_id(2); // = n*1
@@ -182,8 +179,7 @@ __kernel void pool_max_global(__read_only image2d_t input,
182179
__private const int ksize_w,
183180
__private const int stride_h,
184181
__private const int stride_w,
185-
__private const int pad_top,
186-
__private const int pad_left) {
182+
__private const int4 pad) {
187183
const int out_c = get_global_id(0);
188184
const int out_w = get_global_id(1); // =1
189185
const int out_nh = get_global_id(2); // = n*1

lite/kernels/opencl/pool_image_compute.cc

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
#endif
2828
#include "lite/backends/opencl/cl_utility.h"
2929

30-
#undef LITE_WITH_LOG
31-
3230
namespace paddle {
3331
namespace lite {
3432
namespace kernels {
@@ -44,7 +42,6 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
4442

4543
void PrepareForRun() override {
4644
const auto& param = *param_.get_mutable<param_t>();
47-
4845
kernel_func_name_ += param.pooling_type;
4946
const bool global_pooling = param.global_pooling;
5047
const bool exclusive = param.exclusive;
@@ -93,6 +90,8 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
9390
}
9491
}
9592

93+
cl_int4 pad = {paddings[0], paddings[1], paddings[2], paddings[3]};
94+
9695
#ifdef LITE_WITH_LOG
9796
VLOG(4) << "in_dims : [" << in_dims.size() << "]" << in_dims[0] << " "
9897
<< in_dims[1] << " " << in_dims[2] << " " << in_dims[3];
@@ -108,12 +107,6 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
108107
<< paddings[1] << " " << paddings[2] << " " << paddings[3];
109108
#endif
110109

111-
bool pads_equal =
112-
(paddings[0] == paddings[1]) && (paddings[2] == paddings[3]);
113-
if (!pads_equal) {
114-
LOG(FATAL)
115-
<< "padding requires pad_left == pad_right, pad_top == pad_bottom";
116-
}
117110
auto& context = ctx_->As<OpenCLContext>();
118111
CHECK(context.cl_context() != nullptr);
119112

@@ -162,9 +155,7 @@ class PoolComputeImage2D : public KernelLite<TARGET(kOpenCL),
162155
CL_CHECK_FATAL(status);
163156
status = kernel.setArg(++arg_idx, static_cast<const int>(strides[1]));
164157
CL_CHECK_FATAL(status);
165-
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[2]));
166-
CL_CHECK_FATAL(status);
167-
status = kernel.setArg(++arg_idx, static_cast<const int>(paddings[0]));
158+
status = kernel.setArg(++arg_idx, pad);
168159
CL_CHECK_FATAL(status);
169160

170161
status = EnqueueNDRangeKernel(context,
@@ -203,4 +194,3 @@ REGISTER_LITE_KERNEL(pool2d,
203194
PRECISION(kFP16),
204195
DATALAYOUT(kImageDefault))})
205196
.Finalize();
206-
#define LITE_WITH_LOG

0 commit comments

Comments
 (0)