Skip to content

Commit daed473

Browse files
authored
Merge pull request #14089 from heavengate/pool_exclude
Add an inclusive/exclusive counting mode (`exclusive` attribute) to average pooling
2 parents 64f3e3e + da8ee1f commit daed473

File tree

12 files changed

+176
-76
lines changed

12 files changed

+176
-76
lines changed

paddle/fluid/API.spec

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
6767
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
6868
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
6969
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
70-
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
71-
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None))
70+
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
71+
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name', 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, True))
7272
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
7373
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
7474
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))

paddle/fluid/operators/math/pooling.cc

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
3131
const framework::Tensor& input, const std::vector<int>& ksize,
3232
const std::vector<int>& strides,
3333
const std::vector<int>& paddings, PoolProcess pool_process,
34-
framework::Tensor* output) {
34+
bool exclusive, framework::Tensor* output) {
3535
const int batch_size = input.dims()[0];
3636
const int input_height = input.dims()[2];
3737
const int input_width = input.dims()[3];
@@ -68,7 +68,8 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
6868
pool_process.compute(input_data[h * input_width + w], &ele);
6969
}
7070
}
71-
int pool_size = (hend - hstart) * (wend - wstart);
71+
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
72+
: ksize_height * ksize_width;
7273
pool_process.finalize(static_cast<T>(pool_size), &ele);
7374
output_data[ph * output_width + pw] = ele;
7475
}
@@ -93,7 +94,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
9394
const framework::Tensor& output, const framework::Tensor& output_grad,
9495
const std::vector<int>& ksize, const std::vector<int>& strides,
9596
const std::vector<int>& paddings, PoolProcess pool_grad_process,
96-
framework::Tensor* input_grad) {
97+
bool exclusive, framework::Tensor* input_grad) {
9798
const int batch_size = input.dims()[0];
9899
const int input_height = input.dims()[2];
99100
const int input_width = input.dims()[3];
@@ -124,7 +125,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
124125
int wstart = pw * stride_width - padding_width;
125126
int wend = std::min(wstart + ksize_width, input_width);
126127
wstart = std::max(wstart, 0);
127-
int pool_size = (hend - hstart) * (wend - wstart);
128+
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
129+
: ksize_height * ksize_width;
128130
float scale = 1.0 / pool_size;
129131
for (int h = hstart; h < hend; ++h) {
130132
for (int w = wstart; w < wend; ++w) {
@@ -249,7 +251,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
249251
const framework::Tensor& input, const std::vector<int>& ksize,
250252
const std::vector<int>& strides,
251253
const std::vector<int>& paddings, PoolProcess pool_process,
252-
framework::Tensor* output) {
254+
bool exclusive, framework::Tensor* output) {
253255
const int batch_size = input.dims()[0];
254256
const int input_depth = input.dims()[2];
255257
const int input_height = input.dims()[3];
@@ -300,7 +302,9 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
300302
}
301303
}
302304
int pool_size =
303-
(dend - dstart) * (hend - hstart) * (wend - wstart);
305+
exclusive
306+
? (dend - dstart) * (hend - hstart) * (wend - wstart)
307+
: ksize_depth * ksize_height * ksize_width;
304308
pool_process.finalize(static_cast<T>(pool_size), &ele);
305309
output_data[output_idx] = ele;
306310
}
@@ -326,7 +330,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
326330
const framework::Tensor& output, const framework::Tensor& output_grad,
327331
const std::vector<int>& ksize, const std::vector<int>& strides,
328332
const std::vector<int>& paddings, PoolProcess pool_grad_process,
329-
framework::Tensor* input_grad) {
333+
bool exclusive, framework::Tensor* input_grad) {
330334
const int batch_size = input.dims()[0];
331335
const int input_depth = input.dims()[2];
332336
const int input_height = input.dims()[3];
@@ -369,7 +373,9 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
369373
wstart = std::max(wstart, 0);
370374

371375
int pool_size =
372-
(dend - dstart) * (hend - hstart) * (wend - wstart);
376+
exclusive
377+
? (dend - dstart) * (hend - hstart) * (wend - wstart)
378+
: ksize_depth * ksize_height * ksize_width;
373379
float scale = 1.0 / pool_size;
374380
for (int d = dstart; d < dend; ++d) {
375381
for (int h = hstart; h < hend; ++h) {

paddle/fluid/operators/math/pooling.cu

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
2929
const int ksize_width, const int stride_height,
3030
const int stride_width, const int padding_height,
3131
const int padding_width, PoolProcess pool_process,
32-
T* output_data) {
32+
bool exclusive, T* output_data) {
3333
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
3434
index += blockDim.x * gridDim.x) {
3535
int pw = index % output_width;
@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
5252
pool_process.compute(input_data[h * input_width + w], &ele);
5353
}
5454
}
55-
int pool_size = (hend - hstart) * (wend - wstart);
55+
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
56+
: ksize_height * ksize_width;
5657
pool_process.finalize(static_cast<T>(pool_size), &ele);
5758
output_data[index] = ele;
5859
}
@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
6566
const int input_width, const int output_height, const int output_width,
6667
const int ksize_height, const int ksize_width, const int stride_height,
6768
const int stride_width, const int padding_height, const int padding_width,
68-
PoolProcess pool_process, T* input_grad) {
69+
PoolProcess pool_process, bool exclusive, T* input_grad) {
6970
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
7071
index += blockDim.x * gridDim.x) {
7172
int offsetW = index % input_width + padding_width;
@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
9596
int wend = min(wstart + ksize_width, input_width);
9697
hstart = max(hstart, 0);
9798
wstart = max(wstart, 0);
98-
int pool_size = (hend - hstart) * (wend - wstart);
99+
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
100+
: ksize_height * ksize_width;
99101
int output_sub_idx = ph * output_width + pw;
100102
pool_process.compute(input, output_data[output_sub_idx],
101103
output_grad[output_sub_idx],
@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
163165
const framework::Tensor& input, const std::vector<int>& ksize,
164166
const std::vector<int>& strides,
165167
const std::vector<int>& paddings, PoolProcess pool_process,
166-
framework::Tensor* output) {
168+
bool exclusive, framework::Tensor* output) {
167169
const int batch_size = input.dims()[0];
168170
const int input_channels = input.dims()[1];
169171
const int input_height = input.dims()[2];
@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
189191
KernelPool2D<PoolProcess, T><<<grid, threads, 0, context.stream()>>>(
190192
nthreads, input_data, input_channels, input_height, input_width,
191193
output_height, output_width, ksize_height, ksize_width, stride_height,
192-
stride_width, padding_height, padding_width, pool_process, output_data);
194+
stride_width, padding_height, padding_width, pool_process, exclusive,
195+
output_data);
193196
}
194197
};
195198

@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
208211
const std::vector<int>& ksize,
209212
const std::vector<int>& strides,
210213
const std::vector<int>& paddings, PoolProcess pool_process,
211-
framework::Tensor* input_grad) {
214+
bool exclusive, framework::Tensor* input_grad) {
212215
const int batch_size = input.dims()[0];
213216
const int input_channels = input.dims()[1];
214217
const int input_height = input.dims()[2];
@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
236239
nthreads, input_data, output_data, output_grad_data, input_channels,
237240
input_height, input_width, output_height, output_width, ksize_height,
238241
ksize_width, stride_height, stride_width, padding_height, padding_width,
239-
pool_process, input_grad_data);
242+
pool_process, exclusive, input_grad_data);
240243
}
241244
};
242245

@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
313316
double>;
314317

315318
template <typename PoolProcess, typename T>
316-
__global__ void KernelPool3D(const int nthreads, const T* input_data,
317-
const int channels, const int input_depth,
318-
const int input_height, const int input_width,
319-
const int output_depth, const int output_height,
320-
const int output_width, const int ksize_depth,
321-
const int ksize_height, const int ksize_width,
322-
const int stride_depth, const int stride_height,
323-
const int stride_width, const int padding_depth,
324-
const int padding_height, const int padding_width,
325-
PoolProcess pool_process, T* output_data) {
319+
__global__ void KernelPool3D(
320+
const int nthreads, const T* input_data, const int channels,
321+
const int input_depth, const int input_height, const int input_width,
322+
const int output_depth, const int output_height, const int output_width,
323+
const int ksize_depth, const int ksize_height, const int ksize_width,
324+
const int stride_depth, const int stride_height, const int stride_width,
325+
const int padding_depth, const int padding_height, const int padding_width,
326+
PoolProcess pool_process, bool exclusive, T* output_data) {
326327
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
327328
index += blockDim.x * gridDim.x) {
328329
int pw = index % output_width;
@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
351352
}
352353
}
353354
}
354-
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
355+
int pool_size = exclusive
356+
? (dend - dstart) * (hend - hstart) * (wend - wstart)
357+
: ksize_depth * ksize_height * ksize_width;
355358
pool_process.finalize(static_cast<T>(pool_size), &ele);
356359
output_data[index] = ele;
357360
}
@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
366369
const int ksize_height, const int ksize_width, const int stride_depth,
367370
const int stride_height, const int stride_width, const int padding_depth,
368371
const int padding_height, const int padding_width, PoolProcess pool_process,
369-
T* input_grad) {
372+
bool exclusive, T* input_grad) {
370373
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
371374
index += blockDim.x * gridDim.x) {
372375
int offsetW = index % input_width + padding_width;
@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
409412
dstart = max(dstart, 0);
410413
hstart = max(hstart, 0);
411414
wstart = max(wstart, 0);
412-
int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
415+
int pool_size =
416+
exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
417+
: ksize_depth * ksize_height * ksize_width;
413418
int output_sub_idx = (pd * output_height + ph) * output_width + pw;
414419
pool_process.compute(input, output_data[output_sub_idx],
415420
output_grad[output_sub_idx],
@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
484489
const framework::Tensor& input, const std::vector<int>& ksize,
485490
const std::vector<int>& strides,
486491
const std::vector<int>& paddings, PoolProcess pool_process,
487-
framework::Tensor* output) {
492+
bool exclusive, framework::Tensor* output) {
488493
const int batch_size = input.dims()[0];
489494
const int input_channels = input.dims()[1];
490495
const int input_depth = input.dims()[2];
@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
517522
nthreads, input_data, input_channels, input_depth, input_height,
518523
input_width, output_depth, output_height, output_width, ksize_depth,
519524
ksize_height, ksize_width, stride_depth, stride_height, stride_width,
520-
padding_depth, padding_height, padding_width, pool_process,
525+
padding_depth, padding_height, padding_width, pool_process, exclusive,
521526
output_data);
522527
}
523528
};
@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
537542
const std::vector<int>& ksize,
538543
const std::vector<int>& strides,
539544
const std::vector<int>& paddings, PoolProcess pool_process,
540-
framework::Tensor* input_grad) {
545+
bool exclusive, framework::Tensor* input_grad) {
541546
const int batch_size = input.dims()[0];
542547
const int input_channels = input.dims()[1];
543548
const int input_depth = input.dims()[2];
@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
573578
input_depth, input_height, input_width, output_depth, output_height,
574579
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
575580
stride_height, stride_width, padding_depth, padding_height,
576-
padding_width, pool_process, input_grad_data);
581+
padding_width, pool_process, exclusive, input_grad_data);
577582
}
578583
};
579584

paddle/fluid/operators/math/pooling.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Pool2dFunctor {
8989
const std::vector<int>& ksize,
9090
const std::vector<int>& strides,
9191
const std::vector<int>& paddings, PoolProcess pool_compute,
92-
framework::Tensor* output);
92+
bool exclusive, framework::Tensor* output);
9393
};
9494

9595
template <typename DeviceContext, typename PoolProcess, typename T>
@@ -101,7 +101,7 @@ class Pool2dGradFunctor {
101101
const std::vector<int>& ksize,
102102
const std::vector<int>& strides,
103103
const std::vector<int>& paddings, PoolProcess pool_compute,
104-
framework::Tensor* input_grad);
104+
bool exclusive, framework::Tensor* input_grad);
105105
};
106106

107107
template <typename DeviceContext, class T>
@@ -123,7 +123,7 @@ class Pool3dFunctor {
123123
const std::vector<int>& ksize,
124124
const std::vector<int>& strides,
125125
const std::vector<int>& paddings, PoolProcess pool_compute,
126-
framework::Tensor* output);
126+
bool exclusive, framework::Tensor* output);
127127
};
128128

129129
template <typename DeviceContext, typename PoolProcess, typename T>
@@ -135,7 +135,7 @@ class Pool3dGradFunctor {
135135
const std::vector<int>& ksize,
136136
const std::vector<int>& strides,
137137
const std::vector<int>& paddings, PoolProcess pool_compute,
138-
framework::Tensor* input_grad);
138+
bool exclusive, framework::Tensor* input_grad);
139139
};
140140

141141
template <typename DeviceContext, class T>

paddle/fluid/operators/pool_cudnn_op.cu.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
4141
T *output_data = output->mutable_data<T>(ctx.GetPlace());
4242

4343
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
44+
bool exclusive = ctx.Attr<bool>("exclusive");
4445
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
4546
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
4647
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
@@ -72,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
7273
if (pooling_type == "max") {
7374
pooling_mode = PoolingMode::kMaximum;
7475
} else {
75-
pooling_mode = PoolingMode::kAverage;
76+
pooling_mode = exclusive ? PoolingMode::kAverageExclusive
77+
: PoolingMode::kAverageInclusive;
7678
}
7779

7880
cudnnPoolingDescriptor_t cudnn_pool_desc =
@@ -101,6 +103,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
101103
Tensor *input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
102104

103105
std::string pooling_type = ctx.Attr<std::string>("pooling_type");
106+
bool exclusive = ctx.Attr<bool>("exclusive");
104107
std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
105108
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
106109
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
@@ -141,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
141144
pooling_mode = PoolingMode::kMaximum;
142145
}
143146
} else {
144-
pooling_mode = PoolingMode::kAverage;
147+
pooling_mode = exclusive ? PoolingMode::kAverageExclusive
148+
: PoolingMode::kAverageInclusive;
145149
}
146150

147151
cudnnPoolingDescriptor_t cudnn_pool_desc =

0 commit comments

Comments (0)