@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
29
29
const int ksize_width, const int stride_height,
30
30
const int stride_width, const int padding_height,
31
31
const int padding_width, PoolProcess pool_process,
32
- T* output_data) {
32
+ bool exclusive, T* output_data) {
33
33
for (int index = blockIdx .x * blockDim .x + threadIdx .x ; index < nthreads;
34
34
index += blockDim .x * gridDim .x ) {
35
35
int pw = index % output_width;
@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
52
52
pool_process.compute (input_data[h * input_width + w], &ele);
53
53
}
54
54
}
55
- int pool_size = (hend - hstart) * (wend - wstart);
55
+ int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
56
+ : ksize_height * ksize_width;
56
57
pool_process.finalize (static_cast <T>(pool_size), &ele);
57
58
output_data[index] = ele;
58
59
}
@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
65
66
const int input_width, const int output_height, const int output_width,
66
67
const int ksize_height, const int ksize_width, const int stride_height,
67
68
const int stride_width, const int padding_height, const int padding_width,
68
- PoolProcess pool_process, T* input_grad) {
69
+ PoolProcess pool_process, bool exclusive, T* input_grad) {
69
70
for (int index = blockIdx .x * blockDim .x + threadIdx .x ; index < nthreads;
70
71
index += blockDim .x * gridDim .x ) {
71
72
int offsetW = index % input_width + padding_width;
@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
95
96
int wend = min (wstart + ksize_width, input_width);
96
97
hstart = max (hstart, 0 );
97
98
wstart = max (wstart, 0 );
98
- int pool_size = (hend - hstart) * (wend - wstart);
99
+ int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
100
+ : ksize_height * ksize_width;
99
101
int output_sub_idx = ph * output_width + pw;
100
102
pool_process.compute (input, output_data[output_sub_idx],
101
103
output_grad[output_sub_idx],
@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
163
165
const framework::Tensor& input, const std::vector<int >& ksize,
164
166
const std::vector<int >& strides,
165
167
const std::vector<int >& paddings, PoolProcess pool_process,
166
- framework::Tensor* output) {
168
+ bool exclusive, framework::Tensor* output) {
167
169
const int batch_size = input.dims ()[0 ];
168
170
const int input_channels = input.dims ()[1 ];
169
171
const int input_height = input.dims ()[2 ];
@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
189
191
KernelPool2D<PoolProcess, T><<<grid, threads, 0 , context.stream()>>> (
190
192
nthreads, input_data, input_channels, input_height, input_width,
191
193
output_height, output_width, ksize_height, ksize_width, stride_height,
192
- stride_width, padding_height, padding_width, pool_process, output_data);
194
+ stride_width, padding_height, padding_width, pool_process, exclusive,
195
+ output_data);
193
196
}
194
197
};
195
198
@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
208
211
const std::vector<int >& ksize,
209
212
const std::vector<int >& strides,
210
213
const std::vector<int >& paddings, PoolProcess pool_process,
211
- framework::Tensor* input_grad) {
214
+ bool exclusive, framework::Tensor* input_grad) {
212
215
const int batch_size = input.dims ()[0 ];
213
216
const int input_channels = input.dims ()[1 ];
214
217
const int input_height = input.dims ()[2 ];
@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
236
239
nthreads, input_data, output_data, output_grad_data, input_channels,
237
240
input_height, input_width, output_height, output_width, ksize_height,
238
241
ksize_width, stride_height, stride_width, padding_height, padding_width,
239
- pool_process, input_grad_data);
242
+ pool_process, exclusive, input_grad_data);
240
243
}
241
244
};
242
245
@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
313
316
double >;
314
317
315
318
template <typename PoolProcess, typename T>
316
- __global__ void KernelPool3D (const int nthreads, const T* input_data,
317
- const int channels, const int input_depth,
318
- const int input_height, const int input_width,
319
- const int output_depth, const int output_height,
320
- const int output_width, const int ksize_depth,
321
- const int ksize_height, const int ksize_width,
322
- const int stride_depth, const int stride_height,
323
- const int stride_width, const int padding_depth,
324
- const int padding_height, const int padding_width,
325
- PoolProcess pool_process, T* output_data) {
319
+ __global__ void KernelPool3D (
320
+ const int nthreads, const T* input_data, const int channels,
321
+ const int input_depth, const int input_height, const int input_width,
322
+ const int output_depth, const int output_height, const int output_width,
323
+ const int ksize_depth, const int ksize_height, const int ksize_width,
324
+ const int stride_depth, const int stride_height, const int stride_width,
325
+ const int padding_depth, const int padding_height, const int padding_width,
326
+ PoolProcess pool_process, bool exclusive, T* output_data) {
326
327
for (int index = blockIdx .x * blockDim .x + threadIdx .x ; index < nthreads;
327
328
index += blockDim .x * gridDim .x ) {
328
329
int pw = index % output_width;
@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
351
352
}
352
353
}
353
354
}
354
- int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
355
+ int pool_size = exclusive
356
+ ? (dend - dstart) * (hend - hstart) * (wend - wstart)
357
+ : ksize_depth * ksize_height * ksize_width;
355
358
pool_process.finalize (static_cast <T>(pool_size), &ele);
356
359
output_data[index] = ele;
357
360
}
@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
366
369
const int ksize_height, const int ksize_width, const int stride_depth,
367
370
const int stride_height, const int stride_width, const int padding_depth,
368
371
const int padding_height, const int padding_width, PoolProcess pool_process,
369
- T* input_grad) {
372
+ bool exclusive, T* input_grad) {
370
373
for (int index = blockIdx .x * blockDim .x + threadIdx .x ; index < nthreads;
371
374
index += blockDim .x * gridDim .x ) {
372
375
int offsetW = index % input_width + padding_width;
@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
409
412
dstart = max (dstart, 0 );
410
413
hstart = max (hstart, 0 );
411
414
wstart = max (wstart, 0 );
412
- int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
415
+ int pool_size =
416
+ exclusive ? (dend - dstart) * (hend - hstart) * (wend - wstart)
417
+ : ksize_depth * ksize_height * ksize_width;
413
418
int output_sub_idx = (pd * output_height + ph) * output_width + pw;
414
419
pool_process.compute (input, output_data[output_sub_idx],
415
420
output_grad[output_sub_idx],
@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
484
489
const framework::Tensor& input, const std::vector<int >& ksize,
485
490
const std::vector<int >& strides,
486
491
const std::vector<int >& paddings, PoolProcess pool_process,
487
- framework::Tensor* output) {
492
+ bool exclusive, framework::Tensor* output) {
488
493
const int batch_size = input.dims ()[0 ];
489
494
const int input_channels = input.dims ()[1 ];
490
495
const int input_depth = input.dims ()[2 ];
@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
517
522
nthreads, input_data, input_channels, input_depth, input_height,
518
523
input_width, output_depth, output_height, output_width, ksize_depth,
519
524
ksize_height, ksize_width, stride_depth, stride_height, stride_width,
520
- padding_depth, padding_height, padding_width, pool_process,
525
+ padding_depth, padding_height, padding_width, pool_process, exclusive,
521
526
output_data);
522
527
}
523
528
};
@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
537
542
const std::vector<int >& ksize,
538
543
const std::vector<int >& strides,
539
544
const std::vector<int >& paddings, PoolProcess pool_process,
540
- framework::Tensor* input_grad) {
545
+ bool exclusive, framework::Tensor* input_grad) {
541
546
const int batch_size = input.dims ()[0 ];
542
547
const int input_channels = input.dims ()[1 ];
543
548
const int input_depth = input.dims ()[2 ];
@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
573
578
input_depth, input_height, input_width, output_depth, output_height,
574
579
output_width, ksize_depth, ksize_height, ksize_width, stride_depth,
575
580
stride_height, stride_width, padding_depth, padding_height,
576
- padding_width, pool_process, input_grad_data);
581
+ padding_width, pool_process, exclusive, input_grad_data);
577
582
}
578
583
};
579
584
0 commit comments