Skip to content

Commit 86887b9

Browse files
authored
Cherry pick fix generate proposals labels (#28165)
* fix generate_proposal_labels in cascade-rcnn series model, test=develop * fix example code & unittest, test=develop * update code from review comments, test=develop
1 parent 5178d9f commit 86887b9

File tree

5 files changed

+303
-147
lines changed

5 files changed

+303
-147
lines changed

paddle/fluid/operators/detection/bbox_util.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,20 @@ void ClipTiledBoxes(const platform::DeviceContext& ctx,
149149
}
150150
}
151151

152+
// Calculate max IoU between each box and ground-truth and
153+
// each row represents one box
154+
template <typename T>
155+
void MaxIoU(const framework::Tensor& iou, framework::Tensor* max_iou) {
156+
const T* iou_data = iou.data<T>();
157+
int row = iou.dims()[0];
158+
int col = iou.dims()[1];
159+
T* max_iou_data = max_iou->data<T>();
160+
for (int i = 0; i < row; ++i) {
161+
const T* v = iou_data + i * col;
162+
T max_v = *std::max_element(v, v + col);
163+
max_iou_data[i] = max_v;
164+
}
165+
}
166+
152167
} // namespace operators
153168
} // namespace paddle

paddle/fluid/operators/detection/generate_proposal_labels_op.cc

Lines changed: 129 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,28 @@ void AppendRois(LoDTensor* out, int64_t offset, Tensor* to_add) {
3333
memcpy(out_data + offset, to_add_data, to_add->numel() * sizeof(T));
3434
}
3535

36+
// Filter the ground-truth in RoIs and the RoIs with non-positive area.
37+
// The ground-truth has max overlap with itself so the max_overlap is 1
38+
// and the corresponding RoI will be removed.
39+
template <typename T>
40+
void FilterRoIs(const platform::DeviceContext& ctx, const Tensor& rpn_rois,
41+
const Tensor& max_overlap, Tensor* keep) {
42+
const T* rpn_rois_dt = rpn_rois.data<T>();
43+
const T* max_overlap_dt = max_overlap.data<T>();
44+
int rois_num = max_overlap.numel();
45+
keep->Resize({rois_num});
46+
int* keep_data = keep->mutable_data<int>(ctx.GetPlace());
47+
int keep_len = 0;
48+
for (int i = 0; i < rois_num; ++i) {
49+
if ((rpn_rois_dt[i * 4 + 2] - rpn_rois_dt[i * 4 + 0] + 1) > 0 &&
50+
(rpn_rois_dt[i * 4 + 3] - rpn_rois_dt[i * 4 + 1] + 1) > 0 &&
51+
max_overlap_dt[i] < 1.) {
52+
keep_data[keep_len++] = i;
53+
}
54+
}
55+
keep->Resize({keep_len});
56+
}
57+
3658
class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
3759
public:
3860
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -98,12 +120,21 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
98120
im_info_dims.size(), im_info_dims));
99121

100122
int class_nums = ctx->Attrs().Get<int>("class_nums");
123+
bool is_cascade_rcnn = ctx->Attrs().Get<bool>("is_cascade_rcnn");
124+
if (is_cascade_rcnn) {
125+
PADDLE_ENFORCE_EQ(
126+
ctx->HasInput("MaxOverlap"), true,
127+
platform::errors::NotFound(
128+
"Input(MaxOverlap) of GenerateProposalLabelsOp "
129+
"should not be null when is_cascade_rcnn is True."));
130+
}
101131

102132
ctx->SetOutputDim("Rois", {-1, 4});
103133
ctx->SetOutputDim("LabelsInt32", {-1, 1});
104134
ctx->SetOutputDim("BboxTargets", {-1, 4 * class_nums});
105135
ctx->SetOutputDim("BboxInsideWeights", {-1, 4 * class_nums});
106136
ctx->SetOutputDim("BboxOutsideWeights", {-1, 4 * class_nums});
137+
ctx->SetOutputDim("MaxOverlapWithGT", {-1});
107138
}
108139

109140
protected:
@@ -142,7 +173,6 @@ std::vector<std::vector<int>> SampleFgBgGt(
142173
int64_t row = iou->dims()[0];
143174
int64_t col = iou->dims()[1];
144175
float epsilon = 0.00001;
145-
const T* rpn_rois_dt = rpn_rois.data<T>();
146176
// Follow the Faster RCNN's implementation
147177
for (int64_t i = 0; i < row; ++i) {
148178
const T* v = proposal_to_gt_overlaps + i * col;
@@ -151,11 +181,6 @@ std::vector<std::vector<int>> SampleFgBgGt(
151181
if ((i < gt_num) && (crowd_data[i])) {
152182
max_overlap = -1.0;
153183
}
154-
if (is_cascade_rcnn &&
155-
((rpn_rois_dt[i * 4 + 2] - rpn_rois_dt[i * 4 + 0] + 1) <= 0 ||
156-
(rpn_rois_dt[i * 4 + 3] - rpn_rois_dt[i * 4 + 1] + 1) <= 0)) {
157-
continue;
158-
}
159184
if (max_overlap >= fg_thresh) {
160185
// fg mapped gt label index
161186
for (int64_t j = 0; j < col; ++j) {
@@ -232,12 +257,13 @@ std::vector<std::vector<int>> SampleFgBgGt(
232257

233258
template <typename T>
234259
void GatherBoxesLabels(const platform::CPUDeviceContext& context,
235-
const Tensor& boxes, const Tensor& gt_boxes,
236-
const Tensor& gt_classes,
260+
const Tensor& boxes, const Tensor& max_overlap,
261+
const Tensor& gt_boxes, const Tensor& gt_classes,
237262
const std::vector<int>& fg_inds,
238263
const std::vector<int>& bg_inds,
239264
const std::vector<int>& gt_inds, Tensor* sampled_boxes,
240-
Tensor* sampled_labels, Tensor* sampled_gts) {
265+
Tensor* sampled_labels, Tensor* sampled_gts,
266+
Tensor* sampled_max_overlap) {
241267
int fg_num = fg_inds.size();
242268
int bg_num = bg_inds.size();
243269
Tensor fg_inds_t, bg_inds_t, gt_box_inds_t, gt_label_inds_t;
@@ -264,6 +290,13 @@ void GatherBoxesLabels(const platform::CPUDeviceContext& context,
264290
bg_labels.mutable_data<int>({bg_num}, context.GetPlace());
265291
math::set_constant(context, &bg_labels, 0);
266292
Concat<int>(context, fg_labels, bg_labels, sampled_labels);
293+
294+
Tensor fg_max_overlap, bg_max_overlap;
295+
fg_max_overlap.mutable_data<T>({fg_num}, context.GetPlace());
296+
CPUGather<T>(context, max_overlap, fg_inds_t, &fg_max_overlap);
297+
bg_max_overlap.mutable_data<T>({bg_num}, context.GetPlace());
298+
CPUGather<T>(context, max_overlap, bg_inds_t, &bg_max_overlap);
299+
Concat<T>(context, fg_max_overlap, bg_max_overlap, sampled_max_overlap);
267300
}
268301

269302
template <typename T>
@@ -274,43 +307,58 @@ std::vector<Tensor> SampleRoisForOneImage(
274307
const float fg_thresh, const float bg_thresh_hi, const float bg_thresh_lo,
275308
const std::vector<float>& bbox_reg_weights, const int class_nums,
276309
std::minstd_rand engine, bool use_random, bool is_cascade_rcnn,
277-
bool is_cls_agnostic) {
310+
bool is_cls_agnostic, const Tensor& max_overlap) {
278311
// 1.1 map to original image
279312
auto im_scale = im_info.data<T>()[2];
280-
281313
Tensor rpn_rois;
282314
rpn_rois.mutable_data<T>(rpn_rois_in.dims(), context.GetPlace());
283315
const T* rpn_rois_in_dt = rpn_rois_in.data<T>();
284316
T* rpn_rois_dt = rpn_rois.data<T>();
285-
int gt_num = gt_boxes.dims()[0] * 4;
317+
286318
for (int i = 0; i < rpn_rois.numel(); ++i) {
287-
if (i < gt_num && is_cascade_rcnn) {
288-
rpn_rois_dt[i] = rpn_rois_in_dt[i];
319+
rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale;
320+
}
321+
322+
int proposals_num = 1;
323+
324+
if (is_cascade_rcnn) {
325+
Tensor keep;
326+
FilterRoIs<T>(context, rpn_rois, max_overlap, &keep);
327+
Tensor roi_filter;
328+
// Tensor box_filter;
329+
if (keep.numel() == 0) {
330+
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
331+
roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
332+
set_zero(context, &roi_filter, static_cast<T>(0));
289333
} else {
290-
rpn_rois_dt[i] = rpn_rois_in_dt[i] / im_scale;
334+
proposals_num = keep.numel();
335+
roi_filter.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
336+
CPUGather<T>(context, rpn_rois, keep, &roi_filter);
291337
}
338+
T* roi_filter_dt = roi_filter.data<T>();
339+
memcpy(rpn_rois_dt, roi_filter_dt, roi_filter.numel() * sizeof(T));
340+
rpn_rois.Resize(roi_filter.dims());
341+
} else {
342+
proposals_num = rpn_rois.dims()[0];
292343
}
293-
294344
// 1.2 compute overlaps
295-
int proposals_num = rpn_rois.dims()[0];
296-
if (!is_cascade_rcnn) {
297-
proposals_num += gt_boxes.dims()[0];
298-
}
345+
proposals_num += gt_boxes.dims()[0];
346+
299347
Tensor proposal_to_gt_overlaps;
300348
proposal_to_gt_overlaps.mutable_data<T>({proposals_num, gt_boxes.dims()[0]},
301349
context.GetPlace());
302350

303351
Tensor boxes;
304352
boxes.mutable_data<T>({proposals_num, kBoxDim}, context.GetPlace());
305-
if (!is_cascade_rcnn) {
306-
Concat<T>(context, gt_boxes, rpn_rois, &boxes);
307-
} else {
308-
T* boxes_dt = boxes.data<T>();
309-
for (int i = 0; i < boxes.numel(); ++i) {
310-
boxes_dt[i] = rpn_rois_dt[i];
311-
}
312-
}
353+
Concat<T>(context, gt_boxes, rpn_rois, &boxes);
313354
BboxOverlaps<T>(boxes, gt_boxes, &proposal_to_gt_overlaps);
355+
356+
Tensor proposal_with_max_overlap;
357+
proposal_with_max_overlap.mutable_data<T>({proposals_num},
358+
context.GetPlace());
359+
360+
MaxIoU<T>(proposal_to_gt_overlaps, &proposal_with_max_overlap);
361+
314362
// Generate proposal index
315363
std::vector<std::vector<int>> fg_bg_gt =
316364
SampleFgBgGt<T>(context, &proposal_to_gt_overlaps, is_crowd,
@@ -321,17 +369,19 @@ std::vector<Tensor> SampleRoisForOneImage(
321369
std::vector<int> mapped_gt_inds = fg_bg_gt[2]; // mapped_gt_labels
322370

323371
// Gather boxes and labels
324-
Tensor sampled_boxes, sampled_labels, sampled_gts;
372+
Tensor sampled_boxes, sampled_labels, sampled_gts, sampled_max_overlap;
325373
int fg_num = fg_inds.size();
326374
int bg_num = bg_inds.size();
327375
int boxes_num = fg_num + bg_num;
328376
framework::DDim bbox_dim({boxes_num, kBoxDim});
329377
sampled_boxes.mutable_data<T>(bbox_dim, context.GetPlace());
330378
sampled_labels.mutable_data<int>({boxes_num}, context.GetPlace());
331379
sampled_gts.mutable_data<T>({fg_num, kBoxDim}, context.GetPlace());
332-
GatherBoxesLabels<T>(context, boxes, gt_boxes, gt_classes, fg_inds, bg_inds,
333-
mapped_gt_inds, &sampled_boxes, &sampled_labels,
334-
&sampled_gts);
380+
sampled_max_overlap.mutable_data<T>({boxes_num}, context.GetPlace());
381+
GatherBoxesLabels<T>(context, boxes, proposal_with_max_overlap, gt_boxes,
382+
gt_classes, fg_inds, bg_inds, mapped_gt_inds,
383+
&sampled_boxes, &sampled_labels, &sampled_gts,
384+
&sampled_max_overlap);
335385

336386
// Compute targets
337387
Tensor bbox_targets_single;
@@ -390,6 +440,7 @@ std::vector<Tensor> SampleRoisForOneImage(
390440
res.emplace_back(bbox_targets);
391441
res.emplace_back(bbox_inside_weights);
392442
res.emplace_back(bbox_outside_weights);
443+
res.emplace_back(sampled_max_overlap);
393444
return res;
394445
}
395446

@@ -409,6 +460,7 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
409460
auto* bbox_inside_weights = context.Output<LoDTensor>("BboxInsideWeights");
410461
auto* bbox_outside_weights =
411462
context.Output<LoDTensor>("BboxOutsideWeights");
463+
auto* max_overlap_with_gt = context.Output<LoDTensor>("MaxOverlapWithGT");
412464

413465
int batch_size_per_im = context.Attr<int>("batch_size_per_im");
414466
float fg_fraction = context.Attr<float>("fg_fraction");
@@ -446,16 +498,21 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
446498
"received level of LoD is [%d], LoD is [%s].",
447499
gt_boxes->lod().size(), gt_boxes->lod()));
448500
int64_t n = static_cast<int64_t>(rpn_rois->lod().back().size() - 1);
449-
450-
rois->mutable_data<T>({n * batch_size_per_im, kBoxDim}, context.GetPlace());
451-
labels_int32->mutable_data<int>({n * batch_size_per_im, 1},
452-
context.GetPlace());
453-
bbox_targets->mutable_data<T>({n * batch_size_per_im, kBoxDim * class_nums},
501+
int64_t rois_num = rpn_rois->dims()[0];
502+
int64_t gts_num = gt_boxes->dims()[0];
503+
int64_t init_num =
504+
is_cascade_rcnn ? rois_num + gts_num : n * batch_size_per_im;
505+
506+
rois->mutable_data<T>({init_num, kBoxDim}, context.GetPlace());
507+
labels_int32->mutable_data<int>({init_num, 1}, context.GetPlace());
508+
bbox_targets->mutable_data<T>({init_num, kBoxDim * class_nums},
454509
context.GetPlace());
455-
bbox_inside_weights->mutable_data<T>(
456-
{n * batch_size_per_im, kBoxDim * class_nums}, context.GetPlace());
457-
bbox_outside_weights->mutable_data<T>(
458-
{n * batch_size_per_im, kBoxDim * class_nums}, context.GetPlace());
510+
bbox_inside_weights->mutable_data<T>({init_num, kBoxDim * class_nums},
511+
context.GetPlace());
512+
bbox_outside_weights->mutable_data<T>({init_num, kBoxDim * class_nums},
513+
context.GetPlace());
514+
max_overlap_with_gt->Resize({init_num});
515+
max_overlap_with_gt->mutable_data<T>(context.GetPlace());
459516

460517
std::random_device rnd;
461518
std::minstd_rand engine;
@@ -486,25 +543,36 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
486543
Tensor gt_boxes_slice =
487544
gt_boxes->Slice(gt_boxes_lod[i], gt_boxes_lod[i + 1]);
488545
Tensor im_info_slice = im_info->Slice(i, i + 1);
546+
Tensor max_overlap_slice;
547+
if (is_cascade_rcnn) {
548+
auto* max_overlap = context.Input<Tensor>("MaxOverlap");
549+
max_overlap_slice =
550+
max_overlap->Slice(rpn_rois_lod[i], rpn_rois_lod[i + 1]);
551+
} else {
552+
max_overlap_slice.mutable_data<T>({rpn_rois_slice.dims()[0]},
553+
context.GetPlace());
554+
}
489555
std::vector<Tensor> tensor_output = SampleRoisForOneImage<T>(
490556
dev_ctx, rpn_rois_slice, gt_classes_slice, is_crowd_slice,
491557
gt_boxes_slice, im_info_slice, batch_size_per_im, fg_fraction,
492558
fg_thresh, bg_thresh_hi, bg_thresh_lo, bbox_reg_weights, class_nums,
493-
engine, use_random, is_cascade_rcnn, is_cls_agnostic);
559+
engine, use_random, is_cascade_rcnn, is_cls_agnostic,
560+
max_overlap_slice);
494561
Tensor sampled_rois = tensor_output[0];
495562
Tensor sampled_labels_int32 = tensor_output[1];
496563
Tensor sampled_bbox_targets = tensor_output[2];
497564
Tensor sampled_bbox_inside_weights = tensor_output[3];
498565
Tensor sampled_bbox_outside_weights = tensor_output[4];
566+
Tensor sampled_max_overlap = tensor_output[5];
499567

500568
AppendRois<T>(rois, kBoxDim * num_rois, &sampled_rois);
501569
AppendRois<int>(labels_int32, num_rois, &sampled_labels_int32);
502-
AppendRois<T>(bbox_targets, kBoxDim * num_rois * class_nums,
503-
&sampled_bbox_targets);
504-
AppendRois<T>(bbox_inside_weights, kBoxDim * num_rois * class_nums,
505-
&sampled_bbox_inside_weights);
506-
AppendRois<T>(bbox_outside_weights, kBoxDim * num_rois * class_nums,
570+
int64_t offset = kBoxDim * num_rois * class_nums;
571+
AppendRois<T>(bbox_targets, offset, &sampled_bbox_targets);
572+
AppendRois<T>(bbox_inside_weights, offset, &sampled_bbox_inside_weights);
573+
AppendRois<T>(bbox_outside_weights, offset,
507574
&sampled_bbox_outside_weights);
575+
AppendRois<T>(max_overlap_with_gt, num_rois, &sampled_max_overlap);
508576

509577
num_rois += sampled_rois.dims()[0];
510578
lod0.emplace_back(num_rois);
@@ -521,6 +589,8 @@ class GenerateProposalLabelsKernel : public framework::OpKernel<T> {
521589
bbox_targets->Resize({num_rois, kBoxDim * class_nums});
522590
bbox_inside_weights->Resize({num_rois, kBoxDim * class_nums});
523591
bbox_outside_weights->Resize({num_rois, kBoxDim * class_nums});
592+
max_overlap_with_gt->Resize({num_rois});
593+
max_overlap_with_gt->set_lod(lod);
524594
}
525595
};
526596

@@ -550,6 +620,12 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
550620
"(Tensor), This input is a 2D Tensor with shape [B, 3]. "
551621
"B is the number of input images, "
552622
"each element consists of im_height, im_width, im_scale.");
623+
AddInput("MaxOverlap",
624+
"(LoDTensor), This input is a 1D LoDTensor with shape [N]."
625+
"N is the number of Input(RpnRois), "
626+
"each element is the maximum overlap between "
627+
"the proposal RoI and ground-truth.")
628+
.AsDispensable();
553629

554630
AddOutput(
555631
"Rois",
@@ -573,6 +649,12 @@ class GenerateProposalLabelsOpMaker : public framework::OpProtoAndCheckerMaker {
573649
"(LoDTensor), This output is a 2D LoDTensor with shape [P, 4 * "
574650
"class_nums], "
575651
"each element indicates whether a box should contribute to loss.");
652+
AddOutput("MaxOverlapWithGT",
653+
"(LoDTensor), This output is a 1D LoDTensor with shape [P], "
654+
"each element indicates the maxoverlap "
655+
"between output RoIs and ground-truth. "
656+
"The output RoIs may include ground-truth "
657+
"and the output maxoverlap may contain 1.");
576658

577659
AddAttr<int>("batch_size_per_im", "Batch size of rois per images.");
578660
AddAttr<float>("fg_fraction",

0 commit comments

Comments
 (0)