@@ -17,24 +17,47 @@ limitations under the License. */
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
- class RoiPoolOp : public framework ::OperatorWithKernel {
20
+ class ROIPoolOp : public framework ::OperatorWithKernel {
21
21
public:
22
22
using framework::OperatorWithKernel::OperatorWithKernel;
23
23
24
24
void InferShape (framework::InferShapeContext* ctx) const override {
25
25
PADDLE_ENFORCE (ctx->HasInput (" X" ),
26
- " Input(X) of RoiPoolOp should not be null." );
27
- PADDLE_ENFORCE (ctx->HasInput (" Rois " ),
28
- " Input(Rois ) of RoiPoolOp should not be null." );
26
+ " Input(X) of ROIPoolOp should not be null." );
27
+ PADDLE_ENFORCE (ctx->HasInput (" ROIs " ),
28
+ " Input(ROIs ) of ROIPoolOp should not be null." );
29
29
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
30
- " Output(Out) of RoiPoolOp should not be null." );
30
+ " Output(Out) of ROIPoolOp should not be null." );
31
31
PADDLE_ENFORCE (ctx->HasOutput (" Argmax" ),
32
- " Output(Argmax) of RoiPoolOp should not be null." );
32
+ " Output(Argmax) of ROIPoolOp should not be null." );
33
33
auto input_dims = ctx->GetInputDim (" X" );
34
-
35
- // Initialize the output's dims to maximum,
36
- // and re-set to real dims by the value of Rois at kernel
37
- ctx->SetOutputDim (" Out" , input_dims);
34
+ auto rois_dims = ctx->GetInputDim (" ROIs" );
35
+
36
+ PADDLE_ENFORCE (input_dims.size () == 4 ,
37
+ " The format of input tensor is NCHW." );
38
+ PADDLE_ENFORCE (rois_dims.size () == 2 ,
39
+ " ROIs should be a 2-D tensor of shape (num_rois, 5)"
40
+ " given as [[batch_id, x1, y1, x2, y2], …]." );
41
+
42
+ int pooled_height = ctx->Attrs ().Get <int >(" pooled_height" );
43
+ int pooled_width = ctx->Attrs ().Get <int >(" pooled_width" );
44
+ float spatial_scale = ctx->Attrs ().Get <float >(" spatial_scale" );
45
+
46
+ PADDLE_ENFORCE_GT (pooled_height, 0 ,
47
+ " The pooled output height must greater than 0" );
48
+ PADDLE_ENFORCE_GT (pooled_width, 0 ,
49
+ " The pooled output width must greater than 0" );
50
+ PADDLE_ENFORCE_GT (spatial_scale, 0 .0f ,
51
+ " The spatial scale must greater than 0" );
52
+
53
+ auto out_dims = input_dims;
54
+ out_dims[0 ] = rois_dims[0 ];
55
+ out_dims[1 ] = input_dims[1 ];
56
+ out_dims[2 ] = pooled_height;
57
+ out_dims[3 ] = pooled_width;
58
+
59
+ ctx->SetOutputDim (" Out" , out_dims);
60
+ ctx->SetOutputDim (" Argmax" , out_dims);
38
61
}
39
62
40
63
protected:
@@ -46,7 +69,7 @@ class RoiPoolOp : public framework::OperatorWithKernel {
46
69
}
47
70
};
48
71
49
- class RoiPoolGradOp : public framework ::OperatorWithKernel {
72
+ class ROIPoolGradOp : public framework ::OperatorWithKernel {
50
73
public:
51
74
using framework::OperatorWithKernel::OperatorWithKernel;
52
75
@@ -67,44 +90,51 @@ class RoiPoolGradOp : public framework::OperatorWithKernel {
67
90
}
68
91
};
69
92
70
- class RoiPoolOpMaker : public framework ::OpProtoAndCheckerMaker {
93
+ class ROIPoolOpMaker : public framework ::OpProtoAndCheckerMaker {
71
94
public:
72
- RoiPoolOpMaker (framework::OpProto* proto,
95
+ ROIPoolOpMaker (framework::OpProto* proto,
73
96
framework::OpAttrChecker* op_checker)
74
97
: OpProtoAndCheckerMaker(proto, op_checker) {
75
98
AddInput (" X" ,
76
99
" (Tensor), "
77
- " the input of RoiPoolOp." );
78
- AddInput (" Rois" ,
100
+ " the input of ROIPoolOp. "
101
+ " The format of input tensor is NCHW. Where N is batch size, "
102
+ " C is the number of input channels, "
103
+ " H is the height of the feature, and "
104
+ " W is the width of the feature." );
105
+ AddInput (" ROIs" ,
79
106
" (Tensor), "
80
- " RoIs (Regions of Interest) to pool over. "
81
- " Should be a 2-D tensor of shape (num_rois, 5)"
82
- " given as [[batch_id, x1, y1, x2, y2], …]." );
107
+ " ROIs (Regions of Interest) to pool over. "
108
+ " should be a 2-D tensor of shape (num_rois, 5)"
109
+ " given as [[batch_id, x1, y1, x2, y2], …]. "
110
+ " Where batch_id is the id of the data, "
111
+ " (x1, y1) is the top left coordinates, and "
112
+ " (x2, y2) is the bottom right coordinates." );
83
113
AddOutput (" Out" ,
84
114
" (Tensor), "
85
- " RoI pooled output 4-D tensor of shape "
86
- " (num_rois, channels, pooled_h, pooled_w)." );
115
+ " The output of ROIPoolOp is a 4-D tensor with shape "
116
+ " (num_rois, channels, pooled_h, pooled_w)." );
87
117
AddOutput (" Argmax" ,
88
118
" (Tensor), "
89
119
" Argmaxes corresponding to indices in X used "
90
120
" for gradient computation. Only output "
91
121
" if arg “is_test” is false." ).AsIntermediate ();
92
122
AddAttr<float >(" spatial_scale" ,
93
- " (float, default 1.0), "
94
- " Multiplicative spatial scale factor "
95
- " to translate ROI coords from their input scale "
96
- " to the scale used when pooling." )
97
- .SetDefault (1.0 );
123
+ " (float, default 1.0), "
124
+ " Multiplicative spatial scale factor "
125
+ " to translate ROI coords from their input scale "
126
+ " to the scale used when pooling." )
127
+ .SetDefault (1.0 );
98
128
AddAttr<int >(" pooled_height" ,
99
- " (int, default 1), "
100
- " The pooled output height." )
101
- .SetDefault (1 );
129
+ " (int, default 1), "
130
+ " The pooled output height." )
131
+ .SetDefault (1 );
102
132
AddAttr<int >(" pooled_width" ,
103
- " (int, default 1), "
104
- " The pooled output width." )
105
- .SetDefault (1 );
133
+ " (int, default 1), "
134
+ " The pooled output width." )
135
+ .SetDefault (1 );
106
136
AddComment (R"DOC(
107
- RoiPool operator
137
+ ROIPool operator
108
138
109
139
ROI Pooling for Faster-RCNN. The link below is a further introduction:
110
140
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
@@ -116,11 +146,11 @@ ROI Pooling for Faster-RCNN. The link below is a further introduction:
116
146
} // namespace paddle
117
147
118
148
namespace ops = paddle::operators;
119
- REGISTER_OP (roi_pool, ops::RoiPoolOp , ops::RoiPoolOpMaker ,
120
- roi_pool_grad, ops::RoiPoolGradOp );
149
+ REGISTER_OP (roi_pool, ops::ROIPoolOp , ops::ROIPoolOpMaker ,
150
+ roi_pool_grad, ops::ROIPoolGradOp );
121
151
REGISTER_OP_CPU_KERNEL (
122
152
roi_pool,
123
- ops::CPURoiPoolOpKernel <paddle::platform::CPUPlace, float >);
153
+ ops::CPUROIPoolOpKernel <paddle::platform::CPUPlace, float >);
124
154
REGISTER_OP_CPU_KERNEL (
125
155
roi_pool_grad,
126
- ops::CPURoiPoolGradOpKernel <paddle::platform::CPUPlace, float >);
156
+ ops::CPUROIPoolGradOpKernel <paddle::platform::CPUPlace, float >);
0 commit comments