Skip to content

Commit e906c8e

Browse files
authored
Merge pull request #14022 from jerrywgz/fix_rpn_target_assign_op
fix random fail in rpn target assign
2 parents c7379a7 + e35fd3b commit e906c8e

File tree

4 files changed

+102
-37
lines changed

4 files changed

+102
-37
lines changed

paddle/fluid/operators/detection/rpn_target_assign_op.cc

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
5252
PADDLE_ENFORCE(
5353
ctx->HasOutput("TargetBBox"),
5454
"Output(TargetBBox) of RpnTargetAssignOp should not be null");
55+
PADDLE_ENFORCE(
56+
ctx->HasOutput("BBoxInsideWeight"),
57+
"Output(BBoxInsideWeight) of RpnTargetAssignOp should not be null");
5558

5659
auto anchor_dims = ctx->GetInputDim("Anchor");
5760
auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
@@ -68,6 +71,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
6871
ctx->SetOutputDim("ScoreIndex", {-1});
6972
ctx->SetOutputDim("TargetLabel", {-1, 1});
7073
ctx->SetOutputDim("TargetBBox", {-1, 4});
74+
ctx->SetOutputDim("BBoxInsideWeight", {-1, 4});
7175
}
7276

7377
protected:
@@ -169,6 +173,7 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
169173
const float rpn_positive_overlap,
170174
const float rpn_negative_overlap, std::vector<int>* fg_inds,
171175
std::vector<int>* bg_inds, std::vector<int>* tgt_lbl,
176+
std::vector<int>* fg_fake, std::vector<T>* bbox_inside_weight,
172177
std::minstd_rand engine, bool use_random) {
173178
float epsilon = 0.00001;
174179
int anchor_num = anchor_to_gt_max.dims()[0];
@@ -201,25 +206,41 @@ void ScoreAssign(const T* anchor_by_gt_overlap_data,
201206
// Reservoir Sampling
202207
int fg_num = static_cast<int>(rpn_fg_fraction * rpn_batch_size_per_im);
203208
ReservoirSampling(fg_num, &fg_inds_fake, engine, use_random);
204-
fg_num = static_cast<int>(fg_inds_fake.size());
205-
for (int64_t i = 0; i < fg_num; ++i) {
209+
int fg_fake_num = static_cast<int>(fg_inds_fake.size());
210+
for (int64_t i = 0; i < fg_fake_num; ++i) {
206211
target_label[fg_inds_fake[i]] = 1;
207212
}
208213

209-
int bg_num = rpn_batch_size_per_im - fg_num;
214+
int bg_num = rpn_batch_size_per_im - fg_fake_num;
210215
for (int64_t i = 0; i < anchor_num; ++i) {
211216
if (anchor_to_gt_max_data[i] < rpn_negative_overlap) {
212217
bg_inds_fake.push_back(i);
213218
}
214219
}
215220
ReservoirSampling(bg_num, &bg_inds_fake, engine, use_random);
216221
bg_num = static_cast<int>(bg_inds_fake.size());
222+
int fake_num = 0;
217223
for (int64_t i = 0; i < bg_num; ++i) {
224+
// fg fake found
225+
if (target_label[bg_inds_fake[i]] == 1) {
226+
fake_num++;
227+
fg_fake->emplace_back(fg_inds_fake[0]);
228+
for (int j = 0; j < 4; ++j) {
229+
bbox_inside_weight->emplace_back(T(0.));
230+
}
231+
}
218232
target_label[bg_inds_fake[i]] = 0;
219233
}
220234

235+
for (int64_t i = 0; i < (fg_fake_num - fake_num) * 4; ++i) {
236+
bbox_inside_weight->emplace_back(T(1.));
237+
}
238+
221239
for (int64_t i = 0; i < anchor_num; ++i) {
222-
if (target_label[i] == 1) fg_inds->emplace_back(i);
240+
if (target_label[i] == 1) {
241+
fg_inds->emplace_back(i);
242+
fg_fake->emplace_back(i);
243+
}
223244
if (target_label[i] == 0) bg_inds->emplace_back(i);
224245
}
225246
fg_num = fg_inds->size();
@@ -248,7 +269,8 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
248269
std::vector<int> bg_inds;
249270
std::vector<int> gt_inds;
250271
std::vector<int> tgt_lbl;
251-
272+
std::vector<int> fg_fake;
273+
std::vector<T> bbox_inside_weight;
252274
// Calculate the max IoU between anchors and gt boxes
253275
// Map from anchor to gt box that has highest overlap
254276
auto place = ctx.GetPlace();
@@ -275,32 +297,37 @@ std::vector<Tensor> SampleRpnFgBgGt(const platform::CPUDeviceContext& ctx,
275297
// Follow the Faster RCNN's implementation
276298
ScoreAssign(anchor_by_gt_overlap_data, anchor_to_gt_max, gt_to_anchor_max,
277299
rpn_batch_size_per_im, rpn_fg_fraction, rpn_positive_overlap,
278-
rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, engine,
279-
use_random);
300+
rpn_negative_overlap, &fg_inds, &bg_inds, &tgt_lbl, &fg_fake,
301+
&bbox_inside_weight, engine, use_random);
280302

281303
int fg_num = fg_inds.size();
282304
int bg_num = bg_inds.size();
283-
gt_inds.reserve(fg_num);
284-
for (int i = 0; i < fg_num; ++i) {
285-
gt_inds.emplace_back(argmax[fg_inds[i]]);
305+
int fg_fake_num = fg_fake.size();
306+
gt_inds.reserve(fg_fake_num);
307+
for (int i = 0; i < fg_fake_num; ++i) {
308+
gt_inds.emplace_back(argmax[fg_fake[i]]);
286309
}
287-
288-
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t;
289-
int* loc_index_data = loc_index_t.mutable_data<int>({fg_num}, place);
310+
Tensor loc_index_t, score_index_t, tgt_lbl_t, gt_inds_t, bbox_inside_weight_t;
311+
int* loc_index_data = loc_index_t.mutable_data<int>({fg_fake_num}, place);
290312
int* score_index_data =
291313
score_index_t.mutable_data<int>({fg_num + bg_num}, place);
292314
int* tgt_lbl_data = tgt_lbl_t.mutable_data<int>({fg_num + bg_num}, place);
293-
int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_num}, place);
294-
std::copy(fg_inds.begin(), fg_inds.end(), loc_index_data);
315+
int* gt_inds_data = gt_inds_t.mutable_data<int>({fg_fake_num}, place);
316+
T* bbox_inside_weight_data =
317+
bbox_inside_weight_t.mutable_data<T>({fg_fake_num, 4}, place);
318+
std::copy(fg_fake.begin(), fg_fake.end(), loc_index_data);
295319
std::copy(fg_inds.begin(), fg_inds.end(), score_index_data);
296320
std::copy(bg_inds.begin(), bg_inds.end(), score_index_data + fg_num);
297321
std::copy(tgt_lbl.begin(), tgt_lbl.end(), tgt_lbl_data);
298322
std::copy(gt_inds.begin(), gt_inds.end(), gt_inds_data);
323+
std::copy(bbox_inside_weight.begin(), bbox_inside_weight.end(),
324+
bbox_inside_weight_data);
299325
std::vector<Tensor> loc_score_tgtlbl_gt;
300326
loc_score_tgtlbl_gt.emplace_back(loc_index_t);
301327
loc_score_tgtlbl_gt.emplace_back(score_index_t);
302328
loc_score_tgtlbl_gt.emplace_back(tgt_lbl_t);
303329
loc_score_tgtlbl_gt.emplace_back(gt_inds_t);
330+
loc_score_tgtlbl_gt.emplace_back(bbox_inside_weight_t);
304331

305332
return loc_score_tgtlbl_gt;
306333
}
@@ -318,6 +345,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
318345
auto* score_index = context.Output<LoDTensor>("ScoreIndex");
319346
auto* tgt_bbox = context.Output<LoDTensor>("TargetBBox");
320347
auto* tgt_lbl = context.Output<LoDTensor>("TargetLabel");
348+
auto* bbox_inside_weight = context.Output<LoDTensor>("BBoxInsideWeight");
321349

322350
PADDLE_ENFORCE_EQ(gt_boxes->lod().size(), 1UL,
323351
"RpnTargetAssignOp gt_boxes needs 1 level of LoD");
@@ -340,7 +368,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
340368
score_index->mutable_data<int>({max_num}, place);
341369
tgt_bbox->mutable_data<T>({max_num, 4}, place);
342370
tgt_lbl->mutable_data<int>({max_num, 1}, place);
343-
371+
bbox_inside_weight->mutable_data<T>({max_num, 4}, place);
344372
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
345373

346374
std::random_device rnd;
@@ -394,6 +422,7 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
394422
Tensor sampled_score_index = loc_score_tgtlbl_gt[1];
395423
Tensor sampled_tgtlbl = loc_score_tgtlbl_gt[2];
396424
Tensor sampled_gt_index = loc_score_tgtlbl_gt[3];
425+
Tensor sampled_bbox_inside_weight = loc_score_tgtlbl_gt[4];
397426

398427
int loc_num = sampled_loc_index.dims()[0];
399428
int score_num = sampled_score_index.dims()[0];
@@ -432,6 +461,8 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
432461
AppendRpns<int>(score_index, total_score_num, &sampled_score_index_unmap);
433462
AppendRpns<T>(tgt_bbox, total_loc_num * 4, &sampled_tgt_bbox);
434463
AppendRpns<int>(tgt_lbl, total_score_num, &sampled_tgtlbl);
464+
AppendRpns<T>(bbox_inside_weight, total_loc_num * 4,
465+
&sampled_bbox_inside_weight);
435466
total_loc_num += loc_num;
436467

437468
total_score_num += score_num;
@@ -448,10 +479,12 @@ class RpnTargetAssignKernel : public framework::OpKernel<T> {
448479
score_index->set_lod(loc_score);
449480
tgt_bbox->set_lod(lod_loc);
450481
tgt_lbl->set_lod(loc_score);
482+
bbox_inside_weight->set_lod(lod_loc);
451483
loc_index->Resize({total_loc_num});
452484
score_index->Resize({total_score_num});
453485
tgt_bbox->Resize({total_loc_num, 4});
454486
tgt_lbl->Resize({total_score_num, 1});
487+
bbox_inside_weight->Resize({total_loc_num, 4});
455488
}
456489
};
457490

@@ -514,6 +547,9 @@ class RpnTargetAssignOpMaker : public framework::OpProtoAndCheckerMaker {
514547
"TargetLabel",
515548
"(Tensor<int>), The target labels of each anchor with shape "
516549
"[F + B, 1], F and B are sampled foreground and backgroud number.");
550+
AddOutput("BBoxInsideWeight",
551+
"(Tensor), The bbox inside weight with shape "
552+
"[F, 4], F is the sampled foreground number.");
517553
AddComment(R"DOC(
518554
This operator can be, for a given set of ground truth bboxes and the
519555
anchors, to assign classification and regression targets to each prediction.

python/paddle/fluid/layers/detection.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def rpn_target_assign(bbox_pred,
116116
Returns:
117117
tuple:
118118
A tuple(predicted_scores, predicted_location, target_label,
119-
target_bbox) is returned. The predicted_scores and
120-
predicted_location is the predicted result of the RPN.
119+
target_bbox, bbox_inside_weight) is returned. The predicted_scores
120+
and predicted_location is the predicted result of the RPN.
121121
The target_label and target_bbox is the ground truth,
122122
respectively. The predicted_location is a 2D Tensor with shape
123123
[F, 4], and the shape of target_bbox is same as the shape of
@@ -126,6 +126,8 @@ def rpn_target_assign(bbox_pred,
126126
[F + B, 1], and the shape of target_label is same as the shape
127127
of the predicted_scores, B is the number of the background
128128
anchors, the F and B is depends on the input of this operator.
129+
Bbox_inside_weight represents whether the predicted loc is fake_fg
130+
or not and the shape is [F, 4].
129131
130132
Examples:
131133
.. code-block:: python
@@ -138,7 +140,7 @@ def rpn_target_assign(bbox_pred,
138140
append_batch_size=False, dtype='float32')
139141
gt_boxes = layers.data(name='gt_boxes', shape=[10, 4],
140142
append_batch_size=False, dtype='float32')
141-
loc_pred, score_pred, loc_target, score_target =
143+
loc_pred, score_pred, loc_target, score_target, bbox_inside_weight =
142144
fluid.layers.rpn_target_assign(bbox_pred=bbox_pred,
143145
cls_logits=cls_logits,
144146
anchor_box=anchor_box,
@@ -152,6 +154,8 @@ def rpn_target_assign(bbox_pred,
152154
target_label = helper.create_variable_for_type_inference(dtype='int32')
153155
target_bbox = helper.create_variable_for_type_inference(
154156
dtype=anchor_box.dtype)
157+
bbox_inside_weight = helper.create_variable_for_type_inference(
158+
dtype=anchor_box.dtype)
155159
helper.append_op(
156160
type="rpn_target_assign",
157161
inputs={
@@ -164,7 +168,8 @@ def rpn_target_assign(bbox_pred,
164168
'LocationIndex': loc_index,
165169
'ScoreIndex': score_index,
166170
'TargetLabel': target_label,
167-
'TargetBBox': target_bbox
171+
'TargetBBox': target_bbox,
172+
'BBoxInsideWeight': bbox_inside_weight
168173
},
169174
attrs={
170175
'rpn_batch_size_per_im': rpn_batch_size_per_im,
@@ -179,13 +184,14 @@ def rpn_target_assign(bbox_pred,
179184
score_index.stop_gradient = True
180185
target_label.stop_gradient = True
181186
target_bbox.stop_gradient = True
187+
bbox_inside_weight.stop_gradient = True
182188

183189
cls_logits = nn.reshape(x=cls_logits, shape=(-1, 1))
184190
bbox_pred = nn.reshape(x=bbox_pred, shape=(-1, 4))
185191
predicted_cls_logits = nn.gather(cls_logits, score_index)
186192
predicted_bbox_pred = nn.gather(bbox_pred, loc_index)
187193

188-
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox
194+
return predicted_cls_logits, predicted_bbox_pred, target_label, target_bbox, bbox_inside_weight
189195

190196

191197
def detection_output(loc,

python/paddle/fluid/tests/test_detection.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_rpn_target_assign(self):
301301
dtype='float32',
302302
lod_level=1,
303303
append_batch_size=False)
304-
pred_scores, pred_loc, tgt_lbl, tgt_bbox = layers.rpn_target_assign(
304+
pred_scores, pred_loc, tgt_lbl, tgt_bbox, bbox_inside_weight = layers.rpn_target_assign(
305305
bbox_pred=bbox_pred,
306306
cls_logits=cls_logits,
307307
anchor_box=anchor_box,
@@ -313,15 +313,18 @@ def test_rpn_target_assign(self):
313313
rpn_straddle_thresh=0.0,
314314
rpn_fg_fraction=0.5,
315315
rpn_positive_overlap=0.7,
316-
rpn_negative_overlap=0.3)
316+
rpn_negative_overlap=0.3,
317+
use_random=False)
317318

318319
self.assertIsNotNone(pred_scores)
319320
self.assertIsNotNone(pred_loc)
320321
self.assertIsNotNone(tgt_lbl)
321322
self.assertIsNotNone(tgt_bbox)
323+
self.assertIsNotNone(bbox_inside_weight)
322324
assert pred_scores.shape[1] == 1
323325
assert pred_loc.shape[1] == 4
324326
assert pred_loc.shape[1] == tgt_bbox.shape[1]
327+
print(str(program))
325328

326329

327330
class TestGenerateProposals(unittest.TestCase):

python/paddle/fluid/tests/unittests/test_rpn_target_assign_op.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,38 @@ def rpn_target_assign(anchor_by_gt_overlap,
5050
fg_inds, size=(len(fg_inds) - num_fg), replace=False)
5151
else:
5252
disable_inds = fg_inds[num_fg:]
53+
5354
labels[disable_inds] = -1
5455
fg_inds = np.where(labels == 1)[0]
56+
bbox_inside_weight = np.zeros((len(fg_inds), 4), dtype=np.float32)
5557

5658
num_bg = rpn_batch_size_per_im - np.sum(labels == 1)
5759
bg_inds = np.where(anchor_to_gt_max < rpn_negative_overlap)[0]
5860
if len(bg_inds) > num_bg and use_random:
5961
enable_inds = bg_inds[np.random.randint(len(bg_inds), size=num_bg)]
6062
else:
6163
enable_inds = bg_inds[:num_bg]
64+
65+
fg_fake_inds = np.array([], np.int32)
66+
fg_value = np.array([fg_inds[0]], np.int32)
67+
fake_num = 0
68+
for bg_id in enable_inds:
69+
if bg_id in fg_inds:
70+
fake_num += 1
71+
fg_fake_inds = np.hstack([fg_fake_inds, fg_value])
6272
labels[enable_inds] = 0
73+
74+
bbox_inside_weight[fake_num:, :] = 1
6375
fg_inds = np.where(labels == 1)[0]
6476
bg_inds = np.where(labels == 0)[0]
65-
66-
loc_index = fg_inds
67-
score_index = np.hstack((fg_inds, bg_inds))
77+
loc_index = np.hstack([fg_fake_inds, fg_inds])
78+
score_index = np.hstack([fg_inds, bg_inds])
6879
labels = labels[score_index]
6980
assert not np.any(labels == -1), "Wrong labels with -1"
7081

71-
gt_inds = anchor_to_gt_argmax[fg_inds]
82+
gt_inds = anchor_to_gt_argmax[loc_index]
7283

73-
return loc_index, score_index, labels, gt_inds
84+
return loc_index, score_index, labels, gt_inds, bbox_inside_weight
7485

7586

7687
def get_anchor(n, c, h, w):
@@ -123,9 +134,12 @@ def rpn_target_assign_in_python(all_anchors,
123134
gt_boxes_slice = gt_boxes_slice[not_crowd_inds]
124135
iou = _bbox_overlaps(inside_anchors, gt_boxes_slice)
125136

126-
loc_inds, score_inds, labels, gt_inds = rpn_target_assign(
127-
iou, rpn_batch_size_per_im, rpn_positive_overlap,
128-
rpn_negative_overlap, rpn_fg_fraction, use_random)
137+
loc_inds, score_inds, labels, gt_inds, bbox_inside_weight = \
138+
rpn_target_assign(iou, rpn_batch_size_per_im,
139+
rpn_positive_overlap,
140+
rpn_negative_overlap,
141+
rpn_fg_fraction,
142+
use_random)
129143
# unmap to all anchor
130144
loc_inds = inds_inside[loc_inds]
131145
score_inds = inds_inside[score_inds]
@@ -139,15 +153,18 @@ def rpn_target_assign_in_python(all_anchors,
139153
score_indexes = score_inds
140154
tgt_labels = labels
141155
tgt_bboxes = box_deltas
156+
bbox_inside_weights = bbox_inside_weight
142157
else:
143158
loc_indexes = np.concatenate(
144159
[loc_indexes, loc_inds + i * anchor_num])
145160
score_indexes = np.concatenate(
146161
[score_indexes, score_inds + i * anchor_num])
147162
tgt_labels = np.concatenate([tgt_labels, labels])
148163
tgt_bboxes = np.vstack([tgt_bboxes, box_deltas])
164+
bbox_inside_weights = np.vstack([bbox_inside_weights, \
165+
bbox_inside_weight])
149166

150-
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels
167+
return loc_indexes, score_indexes, tgt_bboxes, tgt_labels, bbox_inside_weights
151168

152169

153170
class TestRpnTargetAssignOp(OpTest):
@@ -182,10 +199,12 @@ def setUp(self):
182199
rpn_fg_fraction = 0.5
183200
use_random = False
184201

185-
loc_index, score_index, tgt_bbox, labels = rpn_target_assign_in_python(
186-
all_anchors, gt_boxes, is_crowd, im_info, lod, rpn_straddle_thresh,
187-
rpn_batch_size_per_im, rpn_positive_overlap, rpn_negative_overlap,
188-
rpn_fg_fraction, use_random)
202+
loc_index, score_index, tgt_bbox, labels, bbox_inside_weights = \
203+
rpn_target_assign_in_python(all_anchors, gt_boxes, is_crowd,
204+
im_info, lod, rpn_straddle_thresh,
205+
rpn_batch_size_per_im, rpn_positive_overlap,
206+
rpn_negative_overlap,
207+
rpn_fg_fraction, use_random)
189208
labels = labels[:, np.newaxis]
190209

191210
self.op_type = "rpn_target_assign"
@@ -207,7 +226,8 @@ def setUp(self):
207226
'LocationIndex': loc_index.astype('int32'),
208227
'ScoreIndex': score_index.astype('int32'),
209228
'TargetBBox': tgt_bbox.astype('float32'),
210-
'TargetLabel': labels.astype('int32')
229+
'TargetLabel': labels.astype('int32'),
230+
'BBoxInsideWeight': bbox_inside_weights.astype('float32')
211231
}
212232

213233
def test_check_output(self):

0 commit comments

Comments
 (0)