@@ -15,23 +15,18 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
-#include "paddle/operators/strided_memcpy.h"
 
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
-using LoDTensor = framework::LoDTensor;
-using LoD = framework::LoD;
-
 template <typename Place, typename T>
 class CPUROIPoolOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<Tensor>("X");
-    auto* rois = ctx.Input<Tensor>("ROIs");
-    auto* out = ctx.Output<Tensor>("Out");
-    auto* argmax = ctx.Output<Tensor>("Argmax");
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* rois = ctx.Input<framework::Tensor>("ROIs");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    auto* argmax = ctx.Output<framework::Tensor>("Argmax");
 
     auto pooled_height = ctx.Attr<int>("pooled_height");
     auto pooled_width = ctx.Attr<int>("pooled_width");
@@ -54,11 +49,6 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
     T* output_data = out->mutable_data<T>(ctx.GetPlace());
     int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
 
-    math::SetConstant<Place, T> set_zero;
-    set_zero(ctx.device_context(), out, static_cast<T>(0));
-    math::SetConstant<Place, int64_t> set_init;
-    set_init(ctx.device_context(), argmax, static_cast<int64_t>(-1));
-
     for (int n = 0; n < rois_num; ++n) {
       int roi_batch_id = rois_data[0];
       PADDLE_ENFORCE_GE(roi_batch_id, 0);
@@ -83,7 +73,7 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
       const float bin_size_w =
           static_cast<float>(roi_width) / static_cast<float>(pooled_width);
 
-      const float* batch_data = input_data + roi_batch_id * in_stride[0];
+      const T* batch_data = input_data + roi_batch_id * in_stride[0];
 
       for (int c = 0; c < channels; ++c) {
         for (int ph = 0; ph < pooled_height; ++ph) {
@@ -110,7 +100,8 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
             // Define an empty pooling region to be zero
             bool is_empty = (hend <= hstart) || (wend <= wstart);
             output_data[pool_index] =
-                is_empty ? 0 : -std::numeric_limits<float>::max();
+                is_empty ? 0 : -std::numeric_limits<T>::max();
+            argmax_data[pool_index] = -1;
 
             for (int h = hstart; h < hend; ++h) {
               for (int w = wstart; w < wend; ++w) {
@@ -139,14 +130,14 @@ template <typename Place, typename T>
 class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<Tensor>("X");
-    auto* rois = ctx.Input<Tensor>("ROIs");
-    auto* argmax = ctx.Input<Tensor>("Argmax");
+    auto* in = ctx.Input<framework::Tensor>("X");
+    auto* rois = ctx.Input<framework::Tensor>("ROIs");
+    auto* argmax = ctx.Input<framework::Tensor>("Argmax");
 
     auto* out_grad =
-        ctx.Input<Tensor>(framework::GradVarName("Out"));
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
     auto* x_grad =
-        ctx.Output<Tensor>(framework::GradVarName("X"));
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
 
     auto pooled_height = ctx.Attr<int>("pooled_height");
     auto pooled_width = ctx.Attr<int>("pooled_width");