Skip to content

Commit 765085d

Browse files
authored
Merge pull request #13904 from jerrywgz/roialign
Add RoI align operator.
2 parents da722d6 + 9a14ca9 commit 765085d

File tree

7 files changed

+1081
-0
lines changed

7 files changed

+1081
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ paddle.fluid.layers.pad ArgSpec(args=['x', 'paddings', 'pad_value', 'name'], var
116116
paddle.fluid.layers.pad_constant_like ArgSpec(args=['x', 'y', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0.0, None))
117117
paddle.fluid.layers.label_smooth ArgSpec(args=['label', 'prior_dist', 'epsilon', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 0.1, 'float32', None))
118118
paddle.fluid.layers.roi_pool ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale'], varargs=None, keywords=None, defaults=(1, 1, 1.0))
119+
paddle.fluid.layers.roi_align ArgSpec(args=['input', 'rois', 'pooled_height', 'pooled_width', 'spatial_scale', 'sampling_ratio', 'name'], varargs=None, keywords=None, defaults=(1, 1, 1.0, -1, None))
119120
paddle.fluid.layers.dice_loss ArgSpec(args=['input', 'label', 'epsilon'], varargs=None, keywords=None, defaults=(1e-05,))
120121
paddle.fluid.layers.image_resize ArgSpec(args=['input', 'out_shape', 'scale', 'name', 'resample'], varargs=None, keywords=None, defaults=(None, None, None, 'BILINEAR'))
121122
paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', 'resample'], varargs=None, keywords=None, defaults=('BILINEAR',))
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/operators/roi_align_op.h"
13+
14+
namespace paddle {
15+
namespace operators {
16+
17+
using Tensor = framework::Tensor;
18+
using LoDTensor = framework::LoDTensor;
19+
20+
class ROIAlignOp : public framework::OperatorWithKernel {
21+
public:
22+
using framework::OperatorWithKernel::OperatorWithKernel;
23+
24+
void InferShape(framework::InferShapeContext* ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"Input(X) of ROIAlignOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
28+
"Input(ROIs) of ROIAlignOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30+
"Output(Out) of ROIAlignOp should not be null.");
31+
auto input_dims = ctx->GetInputDim("X");
32+
auto rois_dims = ctx->GetInputDim("ROIs");
33+
34+
PADDLE_ENFORCE(input_dims.size() == 4,
35+
"The format of input tensor is NCHW.");
36+
PADDLE_ENFORCE(rois_dims.size() == 2,
37+
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
38+
"given as [[x1, y1, x2, y2], …].");
39+
PADDLE_ENFORCE(rois_dims[1] == 4,
40+
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
41+
"given as [[x1, y1, x2, y2], …].");
42+
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
43+
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
44+
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
45+
46+
PADDLE_ENFORCE_GT(pooled_height, 0,
47+
"The pooled output height must greater than 0");
48+
PADDLE_ENFORCE_GT(pooled_width, 0,
49+
"The pooled output width must greater than 0");
50+
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
51+
"The spatial scale must greater than 0");
52+
53+
auto out_dims = input_dims;
54+
out_dims[0] = rois_dims[0];
55+
out_dims[1] = input_dims[1];
56+
out_dims[2] = pooled_height;
57+
out_dims[3] = pooled_width;
58+
59+
ctx->SetOutputDim("Out", out_dims);
60+
}
61+
62+
protected:
63+
framework::OpKernelType GetExpectedKernelType(
64+
const framework::ExecutionContext& ctx) const override {
65+
return framework::OpKernelType(
66+
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
67+
ctx.device_context());
68+
}
69+
};
70+
71+
class ROIAlignGradOp : public framework::OperatorWithKernel {
72+
public:
73+
using framework::OperatorWithKernel::OperatorWithKernel;
74+
75+
void InferShape(framework::InferShapeContext* ctx) const override {
76+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
77+
"The GRAD@Out of ROIAlignGradOp should not be null.");
78+
PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName("X")),
79+
"The GRAD@X of ROIAlignGradOp should not be null.");
80+
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
81+
}
82+
83+
protected:
84+
framework::OpKernelType GetExpectedKernelType(
85+
const framework::ExecutionContext& ctx) const override {
86+
return framework::OpKernelType(
87+
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
88+
ctx.device_context());
89+
}
90+
};
91+
92+
class ROIAlignOpMaker : public framework::OpProtoAndCheckerMaker {
93+
public:
94+
void Make() override {
95+
AddInput("X",
96+
"(Tensor), "
97+
"The input of ROIAlignOp. "
98+
"The format of input tensor is NCHW. Where N is batch size, "
99+
"C is the number of input channels, "
100+
"H is the height of the feature, and "
101+
"W is the width of the feature.");
102+
AddInput("ROIs",
103+
"(LoDTensor), "
104+
"ROIs (Regions of Interest) to pool over. "
105+
"should be a 2-D LoDTensor of shape (num_rois, 4)"
106+
"given as [[x1, y1, x2, y2], …]. "
107+
"(x1, y1) is the top left coordinates, and "
108+
"(x2, y2) is the bottom right coordinates.");
109+
AddOutput("Out",
110+
"(Tensor), "
111+
"The output of ROIAlignOp is a 4-D tensor with shape "
112+
"(num_rois, channels, pooled_h, pooled_w).");
113+
AddAttr<float>("spatial_scale",
114+
"(float, default 1.0), "
115+
"Multiplicative spatial scale factor "
116+
"to translate ROI coords from their input scale "
117+
"to the scale used when pooling.")
118+
.SetDefault(1.0);
119+
AddAttr<int>("pooled_height",
120+
"(int, default 1), "
121+
"The pooled output height.")
122+
.SetDefault(1);
123+
AddAttr<int>("pooled_width",
124+
"(int, default 1), "
125+
"The pooled output width.")
126+
.SetDefault(1);
127+
AddAttr<int>("sampling_ratio",
128+
"(int,default -1),"
129+
"number of sampling points in the interpolation grid"
130+
"If <=0, then grid points are adaptive to roi_width "
131+
"and pooled_w, likewise for height")
132+
.SetDefault(-1);
133+
AddComment(R"DOC(
134+
**RoIAlign Operator**
135+
136+
Region of interest align (also known as RoI align) is to perform
137+
bilinear interpolation on inputs of nonuniform sizes to obtain
138+
fixed-size feature maps (e.g. 7*7)
139+
140+
Dividing each region proposal into equal-sized sections with
141+
the pooled_width and pooled_height. Location remains the origin
142+
result.
143+
144+
In each ROI bin, the value of the four regularly sampled locations
145+
are computed directly through bilinear interpolation. The output is
146+
the mean of four locations.
147+
Thus avoid the misaligned problem.
148+
)DOC");
149+
}
150+
};
151+
152+
} // namespace operators
153+
} // namespace paddle
154+
155+
namespace ops = paddle::operators;
156+
REGISTER_OPERATOR(roi_align, ops::ROIAlignOp, ops::ROIAlignOpMaker,
157+
paddle::framework::DefaultGradOpDescMaker<true>);
158+
REGISTER_OPERATOR(roi_align_grad, ops::ROIAlignGradOp);
159+
REGISTER_OP_CPU_KERNEL(
160+
roi_align,
161+
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, float>,
162+
ops::CPUROIAlignOpKernel<paddle::platform::CPUDeviceContext, double>);
163+
REGISTER_OP_CPU_KERNEL(
164+
roi_align_grad,
165+
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, float>,
166+
ops::CPUROIAlignGradOpKernel<paddle::platform::CPUDeviceContext, double>);

0 commit comments

Comments
 (0)