Skip to content

Commit faea248

Browse files
author
chengduo
authored
Merge pull request #5749 from chengduoZH/fix_pool_with_index_op
Fix pool max with index.(Mask type should be int, not float)
2 parents 134eaf2 + bc3ec53 commit faea248

File tree

7 files changed

+180
-198
lines changed

7 files changed

+180
-198
lines changed

paddle/operators/math/pooling.cc

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
498498
* Ksize, strides, paddings are two elements. These two elements represent
499499
* height and width, respectively.
500500
*/
501-
template <typename T>
502-
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
501+
template <typename T1, typename T2>
502+
class MaxPool2dWithIndexFunctor<platform::CPUPlace, T1, T2> {
503503
public:
504504
void operator()(const platform::DeviceContext& context,
505505
const framework::Tensor& input, std::vector<int>& ksize,
@@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
520520
const int input_stride = input_height * input_width;
521521
const int output_stride = output_height * output_width;
522522

523-
const T* input_data = input.data<T>();
524-
T* output_data = output->mutable_data<T>(context.GetPlace());
525-
T* mask_data = mask->mutable_data<T>(context.GetPlace());
523+
const T1* input_data = input.data<T1>();
524+
T1* output_data = output->mutable_data<T1>(context.GetPlace());
525+
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
526526

527527
for (int i = 0; i < batch_size; i++) {
528528
for (int c = 0; c < output_channels; ++c) {
@@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
535535
int wend = std::min(wstart + ksize_width, input_width);
536536
wstart = std::max(wstart, 0);
537537

538-
T ele = static_cast<T>(-FLT_MAX);
538+
T1 ele = static_cast<T1>(-FLT_MAX);
539539
int index = -1;
540540
for (int h = hstart; h < hend; ++h) {
541541
for (int w = wstart; w < wend; ++w) {
@@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
563563
* Ksize, strides, paddings are two elements. These two elements represent
564564
* height and width, respectively.
565565
*/
566-
template <typename T>
567-
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
566+
template <typename T1, typename T2>
567+
class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
568568
public:
569569
void operator()(const platform::DeviceContext& context,
570570
const framework::Tensor& output_grad,
@@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
580580
const int input_stride = input_height * input_width;
581581
const int output_stride = output_height * output_width;
582582

583-
const T* mask_data = mask.data<T>();
584-
const T* output_grad_data = output_grad.data<T>();
585-
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
583+
const T2* mask_data = mask.data<T2>();
584+
const T1* output_grad_data = output_grad.data<T1>();
585+
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
586586

587587
for (int n = 0; n < batch_size; ++n) {
588588
for (int c = 0; c < output_channels; ++c) {
@@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
602602
}
603603
};
604604

605-
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float>;
606-
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float>;
607-
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double>;
608-
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double>;
605+
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, float, int>;
606+
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, float, int>;
607+
template class MaxPool2dWithIndexFunctor<platform::CPUPlace, double, int>;
608+
template class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, double, int>;
609609

610610
/*
611611
* All tensors are in NCDHW format.
612612
* Ksize, strides, paddings are three elements. These three elements represent
613613
* depth, height and width, respectively.
614614
*/
615-
template <typename T>
616-
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
615+
template <typename T1, typename T2>
616+
class MaxPool3dWithIndexFunctor<platform::CPUPlace, T1, T2> {
617617
public:
618618
void operator()(const platform::DeviceContext& context,
619619
const framework::Tensor& input, std::vector<int>& ksize,
@@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
639639
const int input_stride = input_depth * input_height * input_width;
640640
const int output_stride = output_depth * output_height * output_width;
641641

642-
const T* input_data = input.data<T>();
643-
T* output_data = output->mutable_data<T>(context.GetPlace());
644-
T* mask_data = mask->mutable_data<T>(context.GetPlace());
642+
const T1* input_data = input.data<T1>();
643+
T1* output_data = output->mutable_data<T1>(context.GetPlace());
644+
T2* mask_data = mask->mutable_data<T2>(context.GetPlace());
645645

646646
for (int i = 0; i < batch_size; i++) {
647647
for (int c = 0; c < output_channels; ++c) {
@@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
659659
wstart = std::max(wstart, 0);
660660

661661
int output_idx = (pd * output_height + ph) * output_width + pw;
662-
T ele = static_cast<T>(-FLT_MAX);
662+
T1 ele = static_cast<T1>(-FLT_MAX);
663663
int index = -1;
664664
for (int d = dstart; d < dend; ++d) {
665665
for (int h = hstart; h < hend; ++h) {
@@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
691691
* Ksize, strides, paddings are three elements. These three elements represent
692692
* depth, height and width, respectively.
693693
*/
694-
template <typename T>
695-
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
694+
template <typename T1, typename T2>
695+
class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T1, T2> {
696696
public:
697697
void operator()(const platform::DeviceContext& context,
698698
const framework::Tensor& output_grad,
@@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
710710
const int input_stride = input_depth * input_height * input_width;
711711
const int output_stride = output_depth * output_height * output_width;
712712

713-
const T* mask_data = mask.data<T>();
714-
const T* output_grad_data = output_grad.data<T>();
715-
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
713+
const T2* mask_data = mask.data<T2>();
714+
const T1* output_grad_data = output_grad.data<T1>();
715+
T1* input_grad_data = input_grad->mutable_data<T1>(context.GetPlace());
716716

717717
for (int n = 0; n < batch_size; ++n) {
718718
for (int c = 0; c < output_channels; ++c) {
@@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
735735
}
736736
};
737737

738-
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float>;
739-
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float>;
740-
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double>;
741-
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double>;
738+
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, float, int>;
739+
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, float, int>;
740+
template class MaxPool3dWithIndexFunctor<platform::CPUPlace, double, int>;
741+
template class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, double, int>;
742742
} // namespace math
743743
} // namespace operators
744744
} // namespace paddle

0 commit comments

Comments
 (0)