@@ -498,8 +498,8 @@ template class Pool3dGradFunctor<
498
498
* Ksize, strides, paddings are two elements. These two elements represent
499
499
* height and width, respectively.
500
500
*/
501
- template <typename T >
502
- class MaxPool2dWithIndexFunctor <platform::CPUPlace, T > {
501
+ template <typename T1, typename T2 >
502
+ class MaxPool2dWithIndexFunctor <platform::CPUPlace, T1, T2 > {
503
503
public:
504
504
void operator ()(const platform::DeviceContext& context,
505
505
const framework::Tensor& input, std::vector<int >& ksize,
@@ -520,9 +520,9 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
520
520
const int input_stride = input_height * input_width;
521
521
const int output_stride = output_height * output_width;
522
522
523
- const T * input_data = input.data <T >();
524
- T * output_data = output->mutable_data <T >(context.GetPlace ());
525
- T * mask_data = mask->mutable_data <T >(context.GetPlace ());
523
+ const T1 * input_data = input.data <T1 >();
524
+ T1 * output_data = output->mutable_data <T1 >(context.GetPlace ());
525
+ T2 * mask_data = mask->mutable_data <T2 >(context.GetPlace ());
526
526
527
527
for (int i = 0 ; i < batch_size; i++) {
528
528
for (int c = 0 ; c < output_channels; ++c) {
@@ -535,7 +535,7 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
535
535
int wend = std::min (wstart + ksize_width, input_width);
536
536
wstart = std::max (wstart, 0 );
537
537
538
- T ele = static_cast <T >(-FLT_MAX);
538
+ T1 ele = static_cast <T1 >(-FLT_MAX);
539
539
int index = -1 ;
540
540
for (int h = hstart; h < hend; ++h) {
541
541
for (int w = wstart; w < wend; ++w) {
@@ -563,8 +563,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
563
563
* Ksize, strides, paddings are two elements. These two elements represent
564
564
* height and width, respectively.
565
565
*/
566
- template <typename T >
567
- class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, T > {
566
+ template <typename T1, typename T2 >
567
+ class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, T1, T2 > {
568
568
public:
569
569
void operator ()(const platform::DeviceContext& context,
570
570
const framework::Tensor& output_grad,
@@ -580,9 +580,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
580
580
const int input_stride = input_height * input_width;
581
581
const int output_stride = output_height * output_width;
582
582
583
- const T * mask_data = mask.data <T >();
584
- const T * output_grad_data = output_grad.data <T >();
585
- T * input_grad_data = input_grad->mutable_data <T >(context.GetPlace ());
583
+ const T2 * mask_data = mask.data <T2 >();
584
+ const T1 * output_grad_data = output_grad.data <T1 >();
585
+ T1 * input_grad_data = input_grad->mutable_data <T1 >(context.GetPlace ());
586
586
587
587
for (int n = 0 ; n < batch_size; ++n) {
588
588
for (int c = 0 ; c < output_channels; ++c) {
@@ -602,18 +602,18 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
602
602
}
603
603
};
604
604
605
- template class MaxPool2dWithIndexFunctor <platform::CPUPlace, float >;
606
- template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, float >;
607
- template class MaxPool2dWithIndexFunctor <platform::CPUPlace, double >;
608
- template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, double >;
605
+ template class MaxPool2dWithIndexFunctor <platform::CPUPlace, float , int >;
606
+ template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, float , int >;
607
+ template class MaxPool2dWithIndexFunctor <platform::CPUPlace, double , int >;
608
+ template class MaxPool2dWithIndexGradFunctor <platform::CPUPlace, double , int >;
609
609
610
610
/*
611
611
* All tensors are in NCDHW format.
612
612
* Ksize, strides, paddings are three elements. These three elements represent
613
613
* depth, height and width, respectively.
614
614
*/
615
- template <typename T >
616
- class MaxPool3dWithIndexFunctor <platform::CPUPlace, T > {
615
+ template <typename T1, typename T2 >
616
+ class MaxPool3dWithIndexFunctor <platform::CPUPlace, T1, T2 > {
617
617
public:
618
618
void operator ()(const platform::DeviceContext& context,
619
619
const framework::Tensor& input, std::vector<int >& ksize,
@@ -639,9 +639,9 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
639
639
const int input_stride = input_depth * input_height * input_width;
640
640
const int output_stride = output_depth * output_height * output_width;
641
641
642
- const T * input_data = input.data <T >();
643
- T * output_data = output->mutable_data <T >(context.GetPlace ());
644
- T * mask_data = mask->mutable_data <T >(context.GetPlace ());
642
+ const T1 * input_data = input.data <T1 >();
643
+ T1 * output_data = output->mutable_data <T1 >(context.GetPlace ());
644
+ T2 * mask_data = mask->mutable_data <T2 >(context.GetPlace ());
645
645
646
646
for (int i = 0 ; i < batch_size; i++) {
647
647
for (int c = 0 ; c < output_channels; ++c) {
@@ -659,7 +659,7 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
659
659
wstart = std::max (wstart, 0 );
660
660
661
661
int output_idx = (pd * output_height + ph) * output_width + pw;
662
- T ele = static_cast <T >(-FLT_MAX);
662
+ T1 ele = static_cast <T1 >(-FLT_MAX);
663
663
int index = -1 ;
664
664
for (int d = dstart; d < dend; ++d) {
665
665
for (int h = hstart; h < hend; ++h) {
@@ -691,8 +691,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
691
691
* Ksize, strides, paddings are three elements. These three elements represent
692
692
* depth, height and width, respectively.
693
693
*/
694
- template <typename T >
695
- class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, T > {
694
+ template <typename T1, typename T2 >
695
+ class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, T1, T2 > {
696
696
public:
697
697
void operator ()(const platform::DeviceContext& context,
698
698
const framework::Tensor& output_grad,
@@ -710,9 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
710
710
const int input_stride = input_depth * input_height * input_width;
711
711
const int output_stride = output_depth * output_height * output_width;
712
712
713
- const T * mask_data = mask.data <T >();
714
- const T * output_grad_data = output_grad.data <T >();
715
- T * input_grad_data = input_grad->mutable_data <T >(context.GetPlace ());
713
+ const T2 * mask_data = mask.data <T2 >();
714
+ const T1 * output_grad_data = output_grad.data <T1 >();
715
+ T1 * input_grad_data = input_grad->mutable_data <T1 >(context.GetPlace ());
716
716
717
717
for (int n = 0 ; n < batch_size; ++n) {
718
718
for (int c = 0 ; c < output_channels; ++c) {
@@ -735,10 +735,10 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
735
735
}
736
736
};
737
737
738
- template class MaxPool3dWithIndexFunctor <platform::CPUPlace, float >;
739
- template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, float >;
740
- template class MaxPool3dWithIndexFunctor <platform::CPUPlace, double >;
741
- template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, double >;
738
+ template class MaxPool3dWithIndexFunctor <platform::CPUPlace, float , int >;
739
+ template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, float , int >;
740
+ template class MaxPool3dWithIndexFunctor <platform::CPUPlace, double , int >;
741
+ template class MaxPool3dWithIndexGradFunctor <platform::CPUPlace, double , int >;
742
742
} // namespace math
743
743
} // namespace operators
744
744
} // namespace paddle
0 commit comments