
Commit 0690cca

Merge pull request #5831 from wanghaox/roi_pool
Roi pool operator
2 parents 65c859d + cf5b598 commit 0690cca

File tree

4 files changed (+714, −0 lines)


paddle/operators/roi_pool_op.cc

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/roi_pool_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

static constexpr int kROISize = 5;

class ROIPoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of ROIPoolOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("ROIs"),
                   "Input(ROIs) of ROIPoolOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ROIPoolOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Argmax"),
                   "Output(Argmax) of ROIPoolOp should not be null.");
    auto input_dims = ctx->GetInputDim("X");
    auto rois_dims = ctx->GetInputDim("ROIs");

    PADDLE_ENFORCE(input_dims.size() == 4,
                   "The format of input tensor is NCHW.");
    PADDLE_ENFORCE(rois_dims.size() == 2,
                   "ROIs should be a 2-D tensor of shape (num_rois, 5) "
                   "given as [[batch_id, x1, y1, x2, y2], ...].");
    PADDLE_ENFORCE(rois_dims[1] == kROISize,
                   "ROIs should be a 2-D tensor of shape (num_rois, 5) "
                   "given as [[batch_id, x1, y1, x2, y2], ...].");

    int pooled_height = ctx->Attrs().Get<int>("pooled_height");
    int pooled_width = ctx->Attrs().Get<int>("pooled_width");
    float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

    PADDLE_ENFORCE_GT(pooled_height, 0,
                      "The pooled output height must be greater than 0.");
    PADDLE_ENFORCE_GT(pooled_width, 0,
                      "The pooled output width must be greater than 0.");
    PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
                      "The spatial scale must be greater than 0.");

    auto out_dims = input_dims;
    out_dims[0] = rois_dims[0];
    out_dims[1] = input_dims[1];
    out_dims[2] = pooled_height;
    out_dims[3] = pooled_width;

    ctx->SetOutputDim("Out", out_dims);
    ctx->SetOutputDim("Argmax", out_dims);
  }

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
};

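// Worked example of the inferred shapes (illustrative sizes, not taken from
// this diff): with X of shape (2, 256, 38, 50), ROIs of shape (128, 5), and
// pooled_height = pooled_width = 7, both Out and Argmax come out as
// (128, 256, 7, 7) -- one pooled C x 7 x 7 feature map per ROI, regardless of
// which batch image each ROI refers to.
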
class ROIPoolGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "The gradient of Out should not be null.");
    PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
                   "The gradient of X should not be null.");
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }

 protected:
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
        ctx.device_context());
  }
};

class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ROIPoolOpMaker(framework::OpProto* proto,
                 framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor), "
             "the input of ROIPoolOp. "
             "The format of the input tensor is NCHW, where N is batch size, "
             "C is the number of input channels, "
             "H is the height of the feature, and "
             "W is the width of the feature.");
    AddInput("ROIs",
             "(Tensor), "
             "ROIs (Regions of Interest) to pool over. "
             "Should be a 2-D tensor of shape (num_rois, 5) "
             "given as [[batch_id, x1, y1, x2, y2], ...], "
             "where batch_id is the id of the data, "
             "(x1, y1) are the top-left coordinates, and "
             "(x2, y2) are the bottom-right coordinates.");
    AddOutput("Out",
              "(Tensor), "
              "The output of ROIPoolOp is a 4-D tensor with shape "
              "(num_rois, channels, pooled_h, pooled_w).");
    AddOutput("Argmax",
              "(Tensor), "
              "Argmaxes corresponding to indices in X used "
              "for gradient computation. Only output "
              "if arg \"is_test\" is false.")
        .AsIntermediate();
    AddAttr<float>("spatial_scale",
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
        .SetDefault(1.0);
    AddAttr<int>("pooled_height",
                 "(int, default 1), "
                 "The pooled output height.")
        .SetDefault(1);
    AddAttr<int>("pooled_width",
                 "(int, default 1), "
                 "The pooled output width.")
        .SetDefault(1);
    AddComment(R"DOC(
ROIPool operator

ROI Pooling for Faster-RCNN. The link below is a further introduction:
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
            roi_pool_grad, ops::ROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(
    roi_pool,
    ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, float>,
    ops::CPUROIPoolOpKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
    roi_pool_grad,
    ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, float>,
    ops::CPUROIPoolGradOpKernel<paddle::platform::CPUPlace, double>);
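
The CPU kernels registered above (CPUROIPoolOpKernel and CPUROIPoolGradOpKernel) are defined in paddle/operators/roi_pool_op.h, which is not shown in this excerpt. As a rough, self-contained sketch of the forward computation those kernels have to perform for a single ROI and a single channel (all names and shapes below are illustrative, not the header's actual code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// feature: one H x W channel slice in row-major order.
// roi: {x1, y1, x2, y2} in input-image coordinates.
// Writes pooled_h x pooled_w maxima and the flat argmax index of each bin
// into the H x W slice (-1 for empty bins), mirroring the Out/Argmax outputs.
void RoiMaxPoolChannel(const std::vector<float>& feature, int H, int W,
                       const int64_t roi[4], float spatial_scale,
                       int pooled_h, int pooled_w,
                       std::vector<float>* out, std::vector<int64_t>* argmax) {
  // Project the ROI onto the feature map and force a minimum size of 1.
  int x1 = static_cast<int>(std::round(roi[0] * spatial_scale));
  int y1 = static_cast<int>(std::round(roi[1] * spatial_scale));
  int x2 = static_cast<int>(std::round(roi[2] * spatial_scale));
  int y2 = static_cast<int>(std::round(roi[3] * spatial_scale));
  int roi_w = std::max(x2 - x1 + 1, 1);
  int roi_h = std::max(y2 - y1 + 1, 1);
  float bin_w = static_cast<float>(roi_w) / pooled_w;
  float bin_h = static_cast<float>(roi_h) / pooled_h;

  out->assign(pooled_h * pooled_w, 0.f);
  argmax->assign(pooled_h * pooled_w, -1);
  for (int ph = 0; ph < pooled_h; ++ph) {
    for (int pw = 0; pw < pooled_w; ++pw) {
      // Bin boundaries on the feature map, clamped to [0, H) x [0, W).
      int hs = std::min(std::max(static_cast<int>(std::floor(ph * bin_h)) + y1, 0), H);
      int he = std::min(std::max(static_cast<int>(std::ceil((ph + 1) * bin_h)) + y1, 0), H);
      int ws = std::min(std::max(static_cast<int>(std::floor(pw * bin_w)) + x1, 0), W);
      int we = std::min(std::max(static_cast<int>(std::ceil((pw + 1) * bin_w)) + x1, 0), W);
      bool empty = (he <= hs) || (we <= ws);
      float best = empty ? 0.f : -std::numeric_limits<float>::max();
      int64_t best_idx = -1;
      // Take the maximum (and its flat index) over the bin.
      for (int h = hs; h < he; ++h) {
        for (int w = ws; w < we; ++w) {
          if (feature[h * W + w] > best) {
            best = feature[h * W + w];
            best_idx = h * W + w;
          }
        }
      }
      (*out)[ph * pooled_w + pw] = best;
      (*argmax)[ph * pooled_w + pw] = best_idx;
    }
  }
}

The backward pass then only needs the stored argmax: each output gradient is added to the input element whose index was recorded for that bin, which is exactly what the GPU kernels in roi_pool_op.cu below do.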

paddle/operators/roi_pool_op.cu

Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/roi_pool_op.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static constexpr int kROISize = 5;

static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}

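// Launch-size arithmetic, with illustrative numbers (not from this diff):
// for 100 ROIs, 64 channels and a 7 x 7 pooled output, the forward kernel has
// output_size = 100 * 64 * 7 * 7 = 313,600 elements, so
// NumBlocks(313600) = min((313600 + 511) / 512, 4096) = min(613, 4096) = 613
// blocks of 512 threads. The cap of 4096 blocks is what makes the
// grid-stride loops below necessary for very large outputs.
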
template <typename T>
__global__ void GPUROIPoolForward(
    const int nthreads, const T* input_data, const int64_t* input_rois,
    const float spatial_scale, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    T* output_data, int64_t* argmax_data) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = offset_input_rois[0];
    int roi_start_w = round(offset_input_rois[1] * spatial_scale);
    int roi_start_h = round(offset_input_rois[2] * spatial_scale);
    int roi_end_w = round(offset_input_rois[3] * spatial_scale);
    int roi_end_h = round(offset_input_rois[4] * spatial_scale);

    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    int hstart = static_cast<int>(floor(static_cast<T>(ph) * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<T>(pw) * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<T>(ph + 1) * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<T>(pw + 1) * bin_size_w));

    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    T maxval = is_empty ? 0 : -std::numeric_limits<T>::max();
    int maxidx = -1;
    const T* offset_input_data =
        input_data + (roi_batch_ind * channels + c) * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        int input_data_index = h * width + w;
        if (offset_input_data[input_data_index] > maxval) {
          maxval = offset_input_data[input_data_index];
          maxidx = input_data_index;
        }
      }
    }
    output_data[i] = maxval;
    if (argmax_data) {
      argmax_data[i] = maxidx;
    }
  }
}

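// Worked example of the bin arithmetic above (illustrative numbers, not from
// this diff): with spatial_scale = 0.5 and an ROI row [0, 10, 10, 29, 29],
// the ROI maps to feature-map coordinates (5, 5)-(15, 15), so
// roi_width = roi_height = 11. With pooled_width = 2, bin_size_w = 5.5:
// bin pw = 0 covers columns [5, 11) and bin pw = 1 covers columns [10, 16)
// (assuming width >= 16, so no clamping) -- adjacent bins may overlap by one
// row or column because of the floor/ceil rounding.
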
template <typename T>
__global__ void GPUROIPoolBackward(
    const int nthreads,
    const int64_t* input_rois,
    const T* output_grad,
    const int64_t* argmax_data,
    const int num_rois,
    const float spatial_scale,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    T* input_grad) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int pw = i % pooled_width;
    int ph = (i / pooled_width) % pooled_height;
    int c = (i / pooled_width / pooled_height) % channels;
    int n = i / pooled_width / pooled_height / channels;

    const int64_t* offset_input_rois = input_rois + n * kROISize;
    int roi_batch_ind = offset_input_rois[0];
    int input_offset = (roi_batch_ind * channels + c) * height * width;
    int output_offset = (n * channels + c) * pooled_height * pooled_width;
    const T* offset_output_grad = output_grad + output_offset;
    T* offset_input_grad = input_grad + input_offset;
    const int64_t* offset_argmax_data = argmax_data + output_offset;

    int argmax = offset_argmax_data[ph * pooled_width + pw];
    if (argmax != -1) {
      platform::CudaAtomicAdd(
          offset_input_grad + argmax,
          static_cast<T>(offset_output_grad[ph * pooled_width + pw]));
    }
  }
}

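// Note on the atomic update above: different pooled outputs can record the
// same argmax index (neighbouring bins overlap, and different ROIs can cover
// the same input region), so the scatter-add into input_grad must go through
// platform::CudaAtomicAdd rather than a plain +=.
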
template <typename Place, typename T>
class GPUROIPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("ROIs");
    auto* out = ctx.Output<Tensor>("Out");
    auto* argmax = ctx.Output<Tensor>("Argmax");

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    auto in_dims = in->dims();
    auto in_stride = framework::stride(in_dims);
    int channels = in_dims[1];
    int height = in_dims[2];
    int width = in_dims[3];

    size_t rois_num = rois->dims()[0];
    if (rois_num == 0) return;

    int output_size = out->numel();
    int blocks = NumBlocks(output_size);
    int threads = kNumCUDAThreads;

    GPUROIPoolForward<T>
        <<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
            output_size,
            in->data<T>(),
            rois->data<int64_t>(),
            spatial_scale,
            channels,
            height,
            width,
            pooled_height,
            pooled_width,
            out->mutable_data<T>(ctx.GetPlace()),
            argmax->mutable_data<int64_t>(ctx.GetPlace()));
  }
};

template <typename Place, typename T>
class GPUROIPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in = ctx.Input<Tensor>("X");
    auto* rois = ctx.Input<Tensor>("ROIs");
    auto* argmax = ctx.Input<Tensor>("Argmax");

    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

    auto pooled_height = ctx.Attr<int>("pooled_height");
    auto pooled_width = ctx.Attr<int>("pooled_width");
    auto spatial_scale = ctx.Attr<float>("spatial_scale");

    size_t rois_num = rois->dims()[0];
    int channels = in->dims()[1];
    int height = in->dims()[2];
    int width = in->dims()[3];

    if (x_grad) {
      x_grad->mutable_data<T>(ctx.GetPlace());
      math::SetConstant<Place, T> set_zero;
      set_zero(ctx.device_context(), x_grad, static_cast<T>(0));

      int output_grad_size = out_grad->numel();
      int blocks = NumBlocks(output_grad_size);
      int threads = kNumCUDAThreads;

      if (output_grad_size > 0) {
        GPUROIPoolBackward<T>
            <<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
                output_grad_size,
                rois->data<int64_t>(),
                out_grad->data<T>(),
                argmax->data<int64_t>(),
                rois_num,
                spatial_scale,
                channels,
                height,
                width,
                pooled_height,
                pooled_width,
                x_grad->mutable_data<T>(ctx.GetPlace()));
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
    roi_pool,
    ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, float>,
    ops::GPUROIPoolOpKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
    roi_pool_grad,
    ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, float>,
    ops::GPUROIPoolGradOpKernel<paddle::platform::GPUPlace, double>);
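
Both kernels read the ROIs tensor as a flat int64 buffer indexed as input_rois + n * kROISize, i.e. five consecutive values per ROI. A small host-side illustration of that layout (all values hypothetical, not taken from this commit):

#include <cstdint>
#include <vector>

int main() {
  // Two ROIs, both on batch image 0: rows are [batch_id, x1, y1, x2, y2]
  // in input-image coordinates.
  std::vector<int64_t> rois = {
      0, 0, 0, 15, 15,   // ROI 0
      0, 8, 8, 31, 31,   // ROI 1
  };
  // With spatial_scale = 0.5, ROI 1 maps to the feature-map rectangle
  // (4, 4)-(16, 16) after rounding, i.e. a 13 x 13 region that is split
  // into pooled_height x pooled_width bins by the forward kernel.
  return 0;
}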
