Skip to content

Commit cf5b598

Browse files
committed
fix some issues
1 parent ef90559 commit cf5b598

File tree

4 files changed

+37
-36
lines changed

4 files changed

+37
-36
lines changed

paddle/operators/roi_pool_op.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
21+
22+
static constexpr int kROISize = 5;
23+
2024
class ROIPoolOp : public framework::OperatorWithKernel {
2125
public:
2226
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -38,6 +42,9 @@ class ROIPoolOp : public framework::OperatorWithKernel {
3842
PADDLE_ENFORCE(rois_dims.size() == 2,
3943
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
4044
"given as [[batch_id, x1, y1, x2, y2], …].");
45+
PADDLE_ENFORCE(rois_dims[1] == kROISize,
46+
"ROIs should be a 2-D tensor of shape (num_rois, 5)"
47+
"given as [[batch_id, x1, y1, x2, y2], …].");
4148

4249
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
4350
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
@@ -150,7 +157,9 @@ REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
150157
roi_pool_grad, ops::ROIPoolGradOp);
151158
REGISTER_OP_CPU_KERNEL(
152159
roi_pool,
153-
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>);
160+
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
161+
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
154162
REGISTER_OP_CPU_KERNEL(
155163
roi_pool_grad,
156-
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, float>);
164+
ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, float>,
165+
ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);

paddle/operators/roi_pool_op.cu

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21+
using Tensor = framework::Tensor;
22+
2123
static constexpr int kNumCUDAThreads = 512;
2224
static constexpr int kNumMaxinumNumBlocks = 4096;
2325
static constexpr int kROISize = 5;
2426

2527
static inline int NumBlocks(const int N) {
2628
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
2729
kNumMaxinumNumBlocks);
28-
}
30+
}
2931

3032
template <typename T>
3133
__global__ void GPUROIPoolForward(
@@ -64,7 +66,7 @@ static inline int NumBlocks(const int N) {
6466
wend = min(max(wend + roi_start_w, 0), width);
6567
bool is_empty = (hend <= hstart) || (wend <= wstart);
6668

67-
T maxval = is_empty ? 0 : -std::numeric_limits<float>::max();
69+
T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
6870
int maxidx = -1;
6971
const T* offset_input_data =
7072
input_data + (roi_batch_ind * channels + c) * height * width;
@@ -143,14 +145,6 @@ class GPUROIPoolOpKernel : public framework::OpKernel<T> {
143145
int width = in_dims[3];
144146

145147
size_t rois_num = rois->dims()[0];
146-
147-
out->mutable_data<T>(ctx.GetPlace());
148-
math::SetConstant<Place, T> set_zero;
149-
set_zero(ctx.device_context(), out, static_cast<T>(0));
150-
argmax->mutable_data<int64_t>(ctx.GetPlace());
151-
math::SetConstant<Place, int64_t> set_init;
152-
set_init(ctx.device_context(), argmax, static_cast<int64_t>(-1));
153-
154148
if (rois_num== 0) return;
155149

156150
int output_size = out->numel();
@@ -230,7 +224,9 @@ class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
230224
namespace ops = paddle::operators;
231225
REGISTER_OP_GPU_KERNEL(
232226
roi_pool,
233-
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>);
227+
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
228+
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
234229
REGISTER_OP_GPU_KERNEL(
235230
roi_pool_grad,
236-
ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, float>);
231+
ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, float>,
232+
ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);

paddle/operators/roi_pool_op.h

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,18 @@ limitations under the License. */
1515
#pragma once
1616
#include "paddle/framework/op_registry.h"
1717
#include "paddle/operators/math/math_function.h"
18-
#include "paddle/operators/strided_memcpy.h"
1918

2019
namespace paddle {
2120
namespace operators {
2221

23-
using Tensor = framework::Tensor;
24-
using LoDTensor = framework::LoDTensor;
25-
using LoD = framework::LoD;
26-
2722
template <typename Place, typename T>
2823
class CPUROIPoolOpKernel : public framework::OpKernel<T> {
2924
public:
3025
void Compute(const framework::ExecutionContext& ctx) const override {
31-
auto* in = ctx.Input<Tensor>("X");
32-
auto* rois = ctx.Input<Tensor>("ROIs");
33-
auto* out = ctx.Output<Tensor>("Out");
34-
auto* argmax = ctx.Output<Tensor>("Argmax");
26+
auto* in = ctx.Input<framework::Tensor>("X");
27+
auto* rois = ctx.Input<framework::Tensor>("ROIs");
28+
auto* out = ctx.Output<framework::Tensor>("Out");
29+
auto* argmax = ctx.Output<framework::Tensor>("Argmax");
3530

3631
auto pooled_height = ctx.Attr<int>("pooled_height");
3732
auto pooled_width = ctx.Attr<int>("pooled_width");
@@ -54,11 +49,6 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
5449
T* output_data = out->mutable_data<T>(ctx.GetPlace());
5550
int64_t* argmax_data = argmax->mutable_data<int64_t>(ctx.GetPlace());
5651

57-
math::SetConstant<Place, T> set_zero;
58-
set_zero(ctx.device_context(), out, static_cast<T>(0));
59-
math::SetConstant<Place, int64_t> set_init;
60-
set_init(ctx.device_context(), argmax, static_cast<int64_t>(-1));
61-
6252
for (int n = 0; n < rois_num; ++n) {
6353
int roi_batch_id = rois_data[0];
6454
PADDLE_ENFORCE_GE(roi_batch_id, 0);
@@ -83,7 +73,7 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
8373
const float bin_size_w =
8474
static_cast<float>(roi_width) / static_cast<float>(pooled_width);
8575

86-
const float* batch_data = input_data + roi_batch_id * in_stride[0];
76+
const T* batch_data = input_data + roi_batch_id * in_stride[0];
8777

8878
for (int c = 0; c < channels; ++c) {
8979
for (int ph = 0; ph < pooled_height; ++ph) {
@@ -110,7 +100,8 @@ class CPUROIPoolOpKernel : public framework::OpKernel<T> {
110100
// Define an empty pooling region to be zero
111101
bool is_empty = (hend <= hstart) || (wend <= wstart);
112102
output_data[pool_index] =
113-
is_empty ? 0 : -std::numeric_limits<float>::max();
103+
is_empty ? 0 : -std::numeric_limits<T>::max();
104+
argmax_data[pool_index] = -1;
114105

115106
for (int h = hstart; h < hend; ++h) {
116107
for (int w = wstart; w < wend; ++w) {
@@ -139,14 +130,14 @@ template <typename Place, typename T>
139130
class CPUROIPoolGradOpKernel : public framework::OpKernel<T> {
140131
public:
141132
void Compute(const framework::ExecutionContext& ctx) const override {
142-
auto* in = ctx.Input<Tensor>("X");
143-
auto* rois = ctx.Input<Tensor>("ROIs");
144-
auto* argmax = ctx.Input<Tensor>("Argmax");
133+
auto* in = ctx.Input<framework::Tensor>("X");
134+
auto* rois = ctx.Input<framework::Tensor>("ROIs");
135+
auto* argmax = ctx.Input<framework::Tensor>("Argmax");
145136

146137
auto* out_grad =
147-
ctx.Input<Tensor>(framework::GradVarName("Out"));
138+
ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
148139
auto* x_grad =
149-
ctx.Output<Tensor>(framework::GradVarName("X"));
140+
ctx.Output<framework::Tensor>(framework::GradVarName("X"));
150141

151142
auto pooled_height = ctx.Attr<int>("pooled_height");
152143
auto pooled_width = ctx.Attr<int>("pooled_width");

python/paddle/v2/fluid/tests/test_roi_pool_op.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ def calc_roi_pool(self):
7777
wstart = min(max(wstart + roi_start_w, 0), self.width)
7878
wend = min(max(wend + roi_start_w, 0), self.width)
7979

80-
out_data[i, c, ph, pw] = 0
80+
is_empty = (hend <= hstart) or (wend <= wstart)
81+
if is_empty:
82+
out_data[i, c, ph, pw] = 0
83+
else:
84+
out_data[i, c, ph, pw] = -sys.float_info.max
85+
8186
argmax_data[i, c, ph, pw] = -1
8287

8388
for h in range(hstart, hend):

0 commit comments

Comments
 (0)