Skip to content

Commit 7960928

Browse files
author
wanghaox
committed
add roi pool operator
1 parent 9216da3 commit 7960928

File tree

3 files changed

+604
-0
lines changed

3 files changed

+604
-0
lines changed

paddle/operators/roi_pool_op.cc

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/roi_pool_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class RoiPoolOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"Input(X) of RoiPoolOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasInput("Rois"),
28+
"Input(Rois) of RoiPoolOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30+
"Output(Out) of RoiPoolOp should not be null.");
31+
PADDLE_ENFORCE(ctx->HasOutput("Argmax"),
32+
"Output(Argmax) of RoiPoolOp should not be null.");
33+
auto input_dims = ctx->GetInputDim("X");
34+
35+
// Initialize the output's dims to maximum,
36+
// and re-set to real dims by the value of Rois at kernel
37+
ctx->SetOutputDim("Out", input_dims);
38+
}
39+
40+
protected:
41+
framework::OpKernelType GetKernelType(
42+
const framework::ExecutionContext& ctx) const override {
43+
return framework::OpKernelType(
44+
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
45+
ctx.device_context());
46+
}
47+
};
48+
49+
class RoiPoolGradOp : public framework::OperatorWithKernel {
50+
public:
51+
using framework::OperatorWithKernel::OperatorWithKernel;
52+
53+
void InferShape(framework::InferShapeContext* ctx) const override {
54+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
55+
"The gradient of Out should not be null.");
56+
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
57+
"The gradient of X should not be null.");
58+
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
59+
}
60+
61+
protected:
62+
framework::OpKernelType GetKernelType(
63+
const framework::ExecutionContext& ctx) const override {
64+
return framework::OpKernelType(
65+
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
66+
ctx.device_context());
67+
}
68+
};
69+
70+
/**
 * Declares the inputs, outputs, attributes and documentation of the
 * roi_pool operator.
 *
 * Attributes: spatial_scale (float, default 1.0), pooled_height (int,
 * default 1) and pooled_width (int, default 1).
 */
class RoiPoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  RoiPoolOpMaker(framework::OpProto* proto,
                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor), "
             "the input of RoiPoolOp.");
    // Note the trailing space after "(num_rois, 5)": adjacent literals are
    // concatenated verbatim, so without it the doc reads "5)given as".
    AddInput("Rois",
             "(Tensor), "
             "RoIs (Regions of Interest) to pool over. "
             "Should be a 2-D tensor of shape (num_rois, 5) "
             "given as [[batch_id, x1, y1, x2, y2], …].");
    AddOutput("Out",
              "(Tensor), "
              "RoI pooled output 4-D tensor of shape "
              "(num_rois, channels, pooled_h, pooled_w).");
    // NOTE(review): the text mentions an "is_test" argument, but no such
    // attribute is declared in this maker -- confirm the intended wording.
    AddOutput("Argmax",
              "(Tensor), "
              "Argmaxes corresponding to indices in X used "
              "for gradient computation. Only output "
              "if arg “is_test” is false.").AsIntermediate();
    AddAttr<float>("spatial_scale",
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
        .SetDefault(1.0);
    AddAttr<int>("pooled_height",
                 "(int, default 1), "
                 "The pooled output height.")
        .SetDefault(1);
    AddAttr<int>("pooled_width",
                 "(int, default 1), "
                 "The pooled output width.")
        .SetDefault(1);
    AddComment(R"DOC(
RoiPool operator

ROI Pooling for Faster-RCNN. The link below is a further introduction:
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
)DOC");
  }
};
114+
115+
} // namespace operators
116+
} // namespace paddle
117+
118+
namespace ops = paddle::operators;
119+
REGISTER_OP(roi_pool, ops::RoiPoolOp, ops::RoiPoolOpMaker,
120+
roi_pool_grad, ops::RoiPoolGradOp);
121+
REGISTER_OP_CPU_KERNEL(
122+
roi_pool,
123+
ops::CPURoiPoolOpKernel<paddle::platform::CPUPlace, float>);
124+
REGISTER_OP_CPU_KERNEL(
125+
roi_pool_grad,
126+
ops::CPURoiPoolGradOpKernel<paddle::platform::CPUPlace, float>);

paddle/operators/roi_pool_op.cu

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/platform/cuda_helper.h"
16+
#include "paddle/operators/roi_pool_op.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
// Used by the forward kernel as the initial value of the running maximum.
// Guarded so that a translation unit that already provides FLT_MAX (for
// example through <cfloat>) does not get a conflicting redefinition.
#ifndef FLT_MAX
#define FLT_MAX __FLT_MAX__
#endif

// Threads per block used for every RoI-pool kernel launch.
constexpr int PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS = 512;
// Upper bound on the number of blocks in a launch; the kernels loop to
// cover any elements beyond blocks * threads.
constexpr int PADDLE_OPERATORS_ROIPOOL_MAXIMUM_NUM_BLOCKS = 4096;

/**
 * Returns the number of blocks to launch for N work items:
 * ceil(N / PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS), capped at
 * PADDLE_OPERATORS_ROIPOOL_MAXIMUM_NUM_BLOCKS.
 *
 * The rounding is done in 64-bit arithmetic so that N close to INT_MAX
 * does not overflow when the thread count is added.
 */
inline int PADDLE_OPERATORS_ROIPOOL_GET_BLOCKS(const int N) {
  const int64_t threads = PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS;
  const int64_t max_blocks = PADDLE_OPERATORS_ROIPOOL_MAXIMUM_NUM_BLOCKS;
  const int64_t needed = (static_cast<int64_t>(N) + threads - 1) / threads;
  return static_cast<int>(needed < max_blocks ? needed : max_blocks);
}
31+
32+
template <typename T>
33+
__global__ void GPURoiPoolForward(
34+
const int nthreads,
35+
const T* input_data,
36+
const int64_t* input_rois,
37+
const float spatial_scale,
38+
const int channels,
39+
const int height,
40+
const int width,
41+
const int pooled_height,
42+
const int pooled_width,
43+
T* output_data,
44+
int64_t* argmax_data) {
45+
int index = blockIdx.x * blockDim.x + threadIdx.x;
46+
int offset = blockDim.x * gridDim.x;
47+
for (size_t i = index; i < nthreads; i += offset) {
48+
int pw = index % pooled_width;
49+
int ph = (index / pooled_width) % pooled_height;
50+
int c = (index / pooled_width / pooled_height) % channels;
51+
int n = index / pooled_width / pooled_height / channels;
52+
53+
const int64_t* offset_input_rois = input_rois + n * 5;
54+
int roi_batch_ind = offset_input_rois[0];
55+
int roi_start_w = round(offset_input_rois[1] * spatial_scale);
56+
int roi_start_h = round(offset_input_rois[2] * spatial_scale);
57+
int roi_end_w = round(offset_input_rois[3] * spatial_scale);
58+
int roi_end_h = round(offset_input_rois[4] * spatial_scale);
59+
60+
int roi_width = max(roi_end_w - roi_start_w + 1, 1);
61+
int roi_height = max(roi_end_h - roi_start_h + 1, 1);
62+
T bin_size_h = static_cast<T>(roi_height)
63+
/ static_cast<T>(pooled_height);
64+
T bin_size_w = static_cast<T>(roi_width)
65+
/ static_cast<T>(pooled_width);
66+
67+
int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
68+
int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
69+
int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
70+
int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));
71+
72+
hstart = min(max(hstart + roi_start_h, 0), height);
73+
hend = min(max(hend + roi_start_h, 0), height);
74+
wstart = min(max(wstart + roi_start_w, 0), width);
75+
wend = min(max(wend + roi_start_w, 0), width);
76+
bool is_empty = (hend <= hstart) || (wend <= wstart);
77+
78+
T maxval = is_empty ? 0 : -FLT_MAX;
79+
int maxidx = -1;
80+
const T* offset_input_data =
81+
input_data + (roi_batch_ind * channels + c) * height * width;
82+
for (int h = hstart; h < hend; ++h) {
83+
for (int w = wstart; w < wend; ++w) {
84+
int input_data_index = h * width + w;
85+
if (offset_input_data[input_data_index] > maxval) {
86+
maxval = offset_input_data[input_data_index];
87+
maxidx = input_data_index;
88+
}
89+
}
90+
}
91+
output_data[index] = maxval;
92+
if (argmax_data) {
93+
argmax_data[index] = maxidx;
94+
}
95+
}
96+
}
97+
98+
template <typename T>
99+
__global__ void GPURoiPoolBackward(
100+
const int nthreads,
101+
const int64_t* input_rois,
102+
const T* output_grad,
103+
const int64_t* argmax_data,
104+
const int num_rois,
105+
const float spatial_scale,
106+
const int channels,
107+
const int height,
108+
const int width,
109+
const int pooled_height,
110+
const int pooled_width,
111+
T* input_grad) {
112+
int index = blockIdx.x * blockDim.x + threadIdx.x;
113+
int offset = blockDim.x * gridDim.x;
114+
for (int i = index; i < nthreads; i += offset) {
115+
int pw = index % pooled_width;
116+
int ph = (index / pooled_width) % pooled_height;
117+
int c = (index / pooled_width / pooled_height) % channels;
118+
int n = index / pooled_width / pooled_height / channels;
119+
120+
const int64_t* offset_input_rois = input_rois + n * 5;
121+
int roi_batch_ind = offset_input_rois[0];
122+
int input_offset = (roi_batch_ind * channels + c) * height * width;
123+
int output_offset = (n * channels + c) * pooled_height * pooled_width;
124+
const T* offset_output_grad = output_grad + output_offset;
125+
T* offset_input_grad = input_grad + input_offset;
126+
const int64_t* offset_argmax_data = argmax_data + output_offset;
127+
128+
int argmax = offset_argmax_data[ph * pooled_width + pw];
129+
if (argmax != -1) {
130+
platform::CudaAtomicAdd(offset_input_grad + argmax,
131+
static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
132+
}
133+
}
134+
}
135+
136+
137+
template <typename Place, typename T>
138+
class GPURoiPoolOpKernel : public framework::OpKernel<T> {
139+
public:
140+
void Compute(const framework::ExecutionContext& ctx) const override {
141+
auto* in = ctx.Input<Tensor>("X");
142+
auto* rois = ctx.Input<Tensor>("Rois");
143+
auto* out = ctx.Output<Tensor>("Out");
144+
auto* argmax = ctx.Output<Tensor>("Argmax");
145+
146+
auto pooled_height = ctx.Attr<int>("pooled_height");
147+
auto pooled_width = ctx.Attr<int>("pooled_width");
148+
auto spatial_scale = ctx.Attr<float>("spatial_scale");
149+
150+
PADDLE_ENFORCE_GT(pooled_height, 0,
151+
"The pooled output height must greater than 0");
152+
PADDLE_ENFORCE_GT(pooled_width, 0,
153+
"The pooled output width must greater than 0");
154+
PADDLE_ENFORCE_GT(spatial_scale, 0,
155+
"The spatial scale must greater than 0");
156+
157+
auto in_dims = in->dims();
158+
auto in_stride = framework::stride(in_dims);
159+
int channels = in_dims[1];
160+
int height = in_dims[2];
161+
int width = in_dims[3];
162+
163+
int rois_num = rois->dims()[0];
164+
auto out_dims = in_dims;
165+
out_dims[0] = rois_num;
166+
out_dims[1] = in_dims[1];
167+
out_dims[2] = pooled_height;
168+
out_dims[3] = pooled_width;
169+
170+
out->Resize(out_dims);
171+
out->mutable_data<T>(ctx.GetPlace());
172+
math::SetConstant<Place, T> set_zero;
173+
set_zero(ctx.device_context(), out, static_cast<T>(0));
174+
argmax->Resize(out->dims());
175+
argmax->mutable_data<int64_t>(ctx.GetPlace());
176+
math::SetConstant<Place, int64_t> set_init;
177+
set_init(ctx.device_context(), argmax, static_cast<int64_t>(-1));
178+
179+
if (rois_num== 0) return;
180+
181+
int output_size = out->numel();
182+
int blocks = PADDLE_OPERATORS_ROIPOOL_GET_BLOCKS(output_size);
183+
int threads = PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS;
184+
185+
GPURoiPoolForward<T>
186+
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
187+
output_size,
188+
in->data<T>(),
189+
rois->data<int64_t>(),
190+
spatial_scale,
191+
channels,
192+
height,
193+
width,
194+
pooled_height,
195+
pooled_width,
196+
out->mutable_data<T>(ctx.GetPlace()),
197+
argmax->mutable_data<int64_t>(ctx.GetPlace()));
198+
199+
return;
200+
}
201+
};
202+
203+
template <typename Place, typename T>
204+
class GPURoiPoolGradOpKernel : public framework::OpKernel<T> {
205+
public:
206+
void Compute(const framework::ExecutionContext& ctx) const override {
207+
auto* in = ctx.Input<Tensor>("X");
208+
auto* rois = ctx.Input<Tensor>("Rois");
209+
auto* argmax = ctx.Input<Tensor>("Argmax");
210+
211+
auto* out_grad =
212+
ctx.Input<Tensor>(framework::GradVarName("Out"));
213+
auto* x_grad =
214+
ctx.Output<Tensor>(framework::GradVarName("X"));
215+
216+
auto pooled_height = ctx.Attr<int>("pooled_height");
217+
auto pooled_width = ctx.Attr<int>("pooled_width");
218+
auto spatial_scale = ctx.Attr<float>("spatial_scale");
219+
220+
int rois_num = rois->dims()[0];
221+
int channels = in->dims()[1];
222+
int height = in->dims()[2];
223+
int width = in->dims()[3];
224+
225+
if (x_grad) {
226+
x_grad->Resize(in->dims());
227+
x_grad->mutable_data<T>(ctx.GetPlace());
228+
math::SetConstant<Place, T> set_zero;
229+
set_zero(ctx.device_context(), x_grad, static_cast<T>(0));
230+
231+
int output_grad_size = out_grad->numel();
232+
int blocks = PADDLE_OPERATORS_ROIPOOL_GET_BLOCKS(output_grad_size);
233+
int threads = PADDLE_OPERATORS_ROIPOOL_CUDA_NUM_THREADS;
234+
235+
if (output_grad_size > 0) {
236+
GPURoiPoolBackward<T>
237+
<<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
238+
output_grad_size,
239+
rois->data<int64_t>(),
240+
out_grad->data<T>(),
241+
argmax->data<int64_t>(),
242+
rois_num,
243+
spatial_scale,
244+
channels,
245+
height,
246+
width,
247+
pooled_height,
248+
pooled_width,
249+
x_grad->mutable_data<T>(ctx.GetPlace()));
250+
}
251+
return;
252+
}
253+
}
254+
};
255+
256+
} // namespace operators
257+
} // namespace paddle
258+
259+
namespace ops = paddle::operators;
260+
REGISTER_OP_GPU_KERNEL(
261+
roi_pool,
262+
ops::GPURoiPoolOpKernel<paddle::platform::GPUPlace, float>);
263+
REGISTER_OP_GPU_KERNEL(
264+
roi_pool_grad,
265+
ops::GPURoiPoolGradOpKernel<paddle::platform::GPUPlace, float>);

0 commit comments

Comments
 (0)