@@ -11,8 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
11
11
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
-
15
14
#include " paddle/fluid/operators/math/pooling.h"
15
+ #include < algorithm>
16
+ #include < vector>
16
17
17
18
namespace paddle {
18
19
namespace operators {
@@ -27,9 +28,10 @@ template <typename PoolProcess, typename T>
27
28
class Pool2dFunctor <platform::CPUDeviceContext, PoolProcess, T> {
28
29
public:
29
30
void operator ()(const platform::CPUDeviceContext& context,
30
- const framework::Tensor& input, std::vector<int >& ksize,
31
- std::vector<int >& strides, std::vector<int >& paddings,
32
- PoolProcess pool_process, framework::Tensor* output) {
31
+ const framework::Tensor& input, const std::vector<int >& ksize,
32
+ const std::vector<int >& strides,
33
+ const std::vector<int >& paddings, PoolProcess pool_process,
34
+ framework::Tensor* output) {
33
35
const int batch_size = input.dims ()[0 ];
34
36
const int input_height = input.dims ()[2 ];
35
37
const int input_width = input.dims ()[3 ];
@@ -63,11 +65,11 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
63
65
T ele = pool_process.initial ();
64
66
for (int h = hstart; h < hend; ++h) {
65
67
for (int w = wstart; w < wend; ++w) {
66
- pool_process.compute (ele, input_data[h * input_width + w]);
68
+ pool_process.compute (input_data[h * input_width + w], &ele );
67
69
}
68
70
}
69
71
int pool_size = (hend - hstart) * (wend - wstart);
70
- pool_process.finalize (ele, ( static_cast <T>(pool_size)) );
72
+ pool_process.finalize (static_cast <T>(pool_size), &ele );
71
73
output_data[ph * output_width + pw] = ele;
72
74
}
73
75
}
@@ -86,13 +88,12 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
86
88
template <typename PoolProcess, class T >
87
89
class Pool2dGradFunctor <platform::CPUDeviceContext, PoolProcess, T> {
88
90
public:
89
- void operator ()(const platform::CPUDeviceContext& context,
90
- const framework::Tensor& input,
91
- const framework::Tensor& output,
92
- const framework::Tensor& output_grad, std::vector<int >& ksize,
93
- std::vector<int >& strides, std::vector<int >& paddings,
94
- PoolProcess pool_grad_process,
95
- framework::Tensor* input_grad) {
91
+ void operator ()(
92
+ const platform::CPUDeviceContext& context, const framework::Tensor& input,
93
+ const framework::Tensor& output, const framework::Tensor& output_grad,
94
+ const std::vector<int >& ksize, const std::vector<int >& strides,
95
+ const std::vector<int >& paddings, PoolProcess pool_grad_process,
96
+ framework::Tensor* input_grad) {
96
97
const int batch_size = input.dims ()[0 ];
97
98
const int input_height = input.dims ()[2 ];
98
99
const int input_width = input.dims ()[3 ];
@@ -131,8 +132,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
131
132
input_data[h * input_width + w],
132
133
output_data[ph * output_width + pw],
133
134
output_grad_data[ph * output_width + pw],
134
- input_grad_data[h * input_width + w] ,
135
- static_cast <T>(scale) );
135
+ static_cast <T>(scale) ,
136
+ input_grad_data + h * input_width + w );
136
137
}
137
138
}
138
139
}
@@ -154,12 +155,11 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
154
155
template <class T >
155
156
class MaxPool2dGradFunctor <platform::CPUDeviceContext, T> {
156
157
public:
157
- void operator ()(const platform::CPUDeviceContext& context,
158
- const framework::Tensor& input,
159
- const framework::Tensor& output,
160
- const framework::Tensor& output_grad, std::vector<int >& ksize,
161
- std::vector<int >& strides, std::vector<int >& paddings,
162
- framework::Tensor* input_grad) {
158
+ void operator ()(
159
+ const platform::CPUDeviceContext& context, const framework::Tensor& input,
160
+ const framework::Tensor& output, const framework::Tensor& output_grad,
161
+ const std::vector<int >& ksize, const std::vector<int >& strides,
162
+ const std::vector<int >& paddings, framework::Tensor* input_grad) {
163
163
const int batch_size = input.dims ()[0 ];
164
164
const int input_height = input.dims ()[2 ];
165
165
const int input_width = input.dims ()[3 ];
@@ -246,9 +246,10 @@ template <typename PoolProcess, class T>
246
246
class Pool3dFunctor <platform::CPUDeviceContext, PoolProcess, T> {
247
247
public:
248
248
void operator ()(const platform::CPUDeviceContext& context,
249
- const framework::Tensor& input, std::vector<int >& ksize,
250
- std::vector<int >& strides, std::vector<int >& paddings,
251
- PoolProcess pool_process, framework::Tensor* output) {
249
+ const framework::Tensor& input, const std::vector<int >& ksize,
250
+ const std::vector<int >& strides,
251
+ const std::vector<int >& paddings, PoolProcess pool_process,
252
+ framework::Tensor* output) {
252
253
const int batch_size = input.dims ()[0 ];
253
254
const int input_depth = input.dims ()[2 ];
254
255
const int input_height = input.dims ()[3 ];
@@ -293,14 +294,14 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
293
294
for (int h = hstart; h < hend; ++h) {
294
295
for (int w = wstart; w < wend; ++w) {
295
296
pool_process.compute (
296
- ele ,
297
- input_data[(d * input_height + h) * input_width + w] );
297
+ input_data[(d * input_height + h) * input_width + w] ,
298
+ &ele );
298
299
}
299
300
}
300
301
}
301
302
int pool_size =
302
303
(dend - dstart) * (hend - hstart) * (wend - wstart);
303
- pool_process.finalize (ele, static_cast <T>(pool_size));
304
+ pool_process.finalize (static_cast <T>(pool_size), &ele );
304
305
output_data[output_idx] = ele;
305
306
}
306
307
}
@@ -320,13 +321,12 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
320
321
template <typename PoolProcess, class T >
321
322
class Pool3dGradFunctor <platform::CPUDeviceContext, PoolProcess, T> {
322
323
public:
323
- void operator ()(const platform::CPUDeviceContext& context,
324
- const framework::Tensor& input,
325
- const framework::Tensor& output,
326
- const framework::Tensor& output_grad, std::vector<int >& ksize,
327
- std::vector<int >& strides, std::vector<int >& paddings,
328
- PoolProcess pool_grad_process,
329
- framework::Tensor* input_grad) {
324
+ void operator ()(
325
+ const platform::CPUDeviceContext& context, const framework::Tensor& input,
326
+ const framework::Tensor& output, const framework::Tensor& output_grad,
327
+ const std::vector<int >& ksize, const std::vector<int >& strides,
328
+ const std::vector<int >& paddings, PoolProcess pool_grad_process,
329
+ framework::Tensor* input_grad) {
330
330
const int batch_size = input.dims ()[0 ];
331
331
const int input_depth = input.dims ()[2 ];
332
332
const int input_height = input.dims ()[3 ];
@@ -379,8 +379,8 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
379
379
(pd * output_height + ph) * output_width + pw;
380
380
pool_grad_process.compute (
381
381
input_data[input_idx], output_data[output_idx],
382
- output_grad_data[output_idx],
383
- input_grad_data[input_idx], static_cast <T>(scale) );
382
+ output_grad_data[output_idx], static_cast <T>(scale),
383
+ input_grad_data + input_idx );
384
384
}
385
385
}
386
386
}
@@ -404,12 +404,11 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
404
404
template <class T >
405
405
class MaxPool3dGradFunctor <platform::CPUDeviceContext, T> {
406
406
public:
407
- void operator ()(const platform::CPUDeviceContext& context,
408
- const framework::Tensor& input,
409
- const framework::Tensor& output,
410
- const framework::Tensor& output_grad, std::vector<int >& ksize,
411
- std::vector<int >& strides, std::vector<int >& paddings,
412
- framework::Tensor* input_grad) {
407
+ void operator ()(
408
+ const platform::CPUDeviceContext& context, const framework::Tensor& input,
409
+ const framework::Tensor& output, const framework::Tensor& output_grad,
410
+ const std::vector<int >& ksize, const std::vector<int >& strides,
411
+ const std::vector<int >& paddings, framework::Tensor* input_grad) {
413
412
const int batch_size = input.dims ()[0 ];
414
413
const int input_depth = input.dims ()[2 ];
415
414
const int input_height = input.dims ()[3 ];
@@ -510,9 +509,10 @@ template <typename T1, typename T2>
510
509
class MaxPool2dWithIndexFunctor <platform::CPUDeviceContext, T1, T2> {
511
510
public:
512
511
void operator ()(const platform::CPUDeviceContext& context,
513
- const framework::Tensor& input, std::vector<int >& ksize,
514
- std::vector<int >& strides, std::vector<int >& paddings,
515
- framework::Tensor* output, framework::Tensor* mask) {
512
+ const framework::Tensor& input, const std::vector<int >& ksize,
513
+ const std::vector<int >& strides,
514
+ const std::vector<int >& paddings, framework::Tensor* output,
515
+ framework::Tensor* mask) {
516
516
const int batch_size = input.dims ()[0 ];
517
517
const int input_height = input.dims ()[2 ];
518
518
const int input_width = input.dims ()[3 ];
@@ -576,8 +576,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
576
576
public:
577
577
void operator ()(const platform::CPUDeviceContext& context,
578
578
const framework::Tensor& output_grad,
579
- const framework::Tensor& mask, std::vector<int >& ksize,
580
- std::vector<int >& strides, std::vector<int >& paddings,
579
+ const framework::Tensor& mask, const std::vector<int >& ksize,
580
+ const std::vector<int >& strides,
581
+ const std::vector<int >& paddings,
581
582
framework::Tensor* input_grad) {
582
583
const int batch_size = input_grad->dims ()[0 ];
583
584
const int input_height = input_grad->dims ()[2 ];
@@ -628,9 +629,10 @@ template <typename T1, typename T2>
628
629
class MaxPool3dWithIndexFunctor <platform::CPUDeviceContext, T1, T2> {
629
630
public:
630
631
void operator ()(const platform::CPUDeviceContext& context,
631
- const framework::Tensor& input, std::vector<int >& ksize,
632
- std::vector<int >& strides, std::vector<int >& paddings,
633
- framework::Tensor* output, framework::Tensor* mask) {
632
+ const framework::Tensor& input, const std::vector<int >& ksize,
633
+ const std::vector<int >& strides,
634
+ const std::vector<int >& paddings, framework::Tensor* output,
635
+ framework::Tensor* mask) {
634
636
const int batch_size = input.dims ()[0 ];
635
637
const int input_depth = input.dims ()[2 ];
636
638
const int input_height = input.dims ()[3 ];
@@ -708,8 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
708
710
public:
709
711
void operator ()(const platform::CPUDeviceContext& context,
710
712
const framework::Tensor& output_grad,
711
- const framework::Tensor& mask, std::vector<int >& ksize,
712
- std::vector<int >& strides, std::vector<int >& paddings,
713
+ const framework::Tensor& mask, const std::vector<int >& ksize,
714
+ const std::vector<int >& strides,
715
+ const std::vector<int >& paddings,
713
716
framework::Tensor* input_grad) {
714
717
const int batch_size = input_grad->dims ()[0 ];
715
718
const int input_depth = input_grad->dims ()[2 ];
0 commit comments