Skip to content

Commit fcd4cf7

Browse files
authored
Add matrix_nms_op (#25333)
test=release/1.8
1 parent d171f37 commit fcd4cf7

File tree

4 files changed

+825
-0
lines changed

4 files changed

+825
-0
lines changed

paddle/fluid/operators/detection/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ detection_library(rpn_target_assign_op SRCS rpn_target_assign_op.cc)
3232
detection_library(generate_proposal_labels_op SRCS generate_proposal_labels_op.cc)
3333
detection_library(multiclass_nms_op SRCS multiclass_nms_op.cc DEPS gpc)
3434
detection_library(locality_aware_nms_op SRCS locality_aware_nms_op.cc DEPS gpc)
35+
detection_library(matrix_nms_op SRCS matrix_nms_op.cc DEPS gpc)
3536
detection_library(box_clip_op SRCS box_clip_op.cc box_clip_op.cu)
3637
detection_library(yolov3_loss_op SRCS yolov3_loss_op.cc)
3738
detection_library(yolo_box_op SRCS yolo_box_op.cc yolo_box_op.cu)
Lines changed: 389 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,389 @@
1+
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
limitations under the License. */
13+
14+
#include "paddle/fluid/framework/op_registry.h"
15+
#include "paddle/fluid/operators/detection/nms_util.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
using Tensor = framework::Tensor;
21+
using LoDTensor = framework::LoDTensor;
22+
23+
class MatrixNMSOp : public framework::OperatorWithKernel {
24+
public:
25+
using framework::OperatorWithKernel::OperatorWithKernel;
26+
27+
void InferShape(framework::InferShapeContext* ctx) const override {
28+
OP_INOUT_CHECK(ctx->HasInput("BBoxes"), "Input", "BBoxes", "MatrixNMS");
29+
OP_INOUT_CHECK(ctx->HasInput("Scores"), "Input", "Scores", "MatrixNMS");
30+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "MatrixNMS");
31+
auto box_dims = ctx->GetInputDim("BBoxes");
32+
auto score_dims = ctx->GetInputDim("Scores");
33+
auto score_size = score_dims.size();
34+
35+
if (ctx->IsRuntime()) {
36+
PADDLE_ENFORCE_EQ(score_size == 3, true,
37+
platform::errors::InvalidArgument(
38+
"The rank of Input(Scores) must be 3. "
39+
"But received rank = %d.",
40+
score_size));
41+
PADDLE_ENFORCE_EQ(box_dims.size(), 3,
42+
platform::errors::InvalidArgument(
43+
"The rank of Input(BBoxes) must be 3."
44+
"But received rank = %d.",
45+
box_dims.size()));
46+
PADDLE_ENFORCE_EQ(box_dims[2] == 4, true,
47+
platform::errors::InvalidArgument(
48+
"The last dimension of Input (BBoxes) must be 4, "
49+
"represents the layout of coordinate "
50+
"[xmin, ymin, xmax, ymax]."));
51+
PADDLE_ENFORCE_EQ(
52+
box_dims[1], score_dims[2],
53+
platform::errors::InvalidArgument(
54+
"The 2nd dimension of Input(BBoxes) must be equal to "
55+
"last dimension of Input(Scores), which represents the "
56+
"predicted bboxes."
57+
"But received box_dims[1](%s) != socre_dims[2](%s)",
58+
box_dims[1], score_dims[2]));
59+
}
60+
ctx->SetOutputDim("Out", {box_dims[1], box_dims[2] + 2});
61+
ctx->SetOutputDim("Index", {box_dims[1], 1});
62+
if (!ctx->IsRuntime()) {
63+
ctx->SetLoDLevel("Out", std::max(ctx->GetLoDLevel("BBoxes"), 1));
64+
ctx->SetLoDLevel("Index", std::max(ctx->GetLoDLevel("BBoxes"), 1));
65+
}
66+
}
67+
68+
protected:
69+
framework::OpKernelType GetExpectedKernelType(
70+
const framework::ExecutionContext& ctx) const override {
71+
return framework::OpKernelType(
72+
OperatorWithKernel::IndicateVarDataType(ctx, "Scores"),
73+
platform::CPUPlace());
74+
}
75+
};
76+
77+
// Score-decay functor used by NMSMatrix, selected at compile time.
//
// Template parameters:
//   T        - floating point type of the scores / IoU values.
//   gaussian - true selects the exponential decay, false the linear one.
//
// Call operator arguments:
//   iou     - overlap between the box being decayed and a higher-scored box.
//   max_iou - largest overlap of that higher-scored box with the boxes ranked
//             above it.
//   sigma   - scale factor of the exponential decay (ignored by the linear
//             specialization).
// Returns the multiplicative decay factor for the score.
template <typename T, bool gaussian>
struct decay_score;

// Exponential decay: exp((max_iou^2 - iou^2) * sigma).
template <typename T>
struct decay_score<T, true> {
  T operator()(T iou, T max_iou, T sigma) const {
    return std::exp((max_iou * max_iou - iou * iou) * sigma);
  }
};

// Linear decay: (1 - iou) / (1 - max_iou).
// NOTE: max_iou == 1 makes the denominator zero; no guard is applied here.
template <typename T>
struct decay_score<T, false> {
  T operator()(T iou, T max_iou, T /*sigma*/) const {
    return (1. - iou) / (1. - max_iou);
  }
};
93+
94+
template <typename T, bool gaussian>
95+
void NMSMatrix(const Tensor& bbox, const Tensor& scores,
96+
const T score_threshold, const T post_threshold,
97+
const float sigma, const int64_t top_k, const bool normalized,
98+
std::vector<int>* selected_indices,
99+
std::vector<T>* decayed_scores) {
100+
int64_t num_boxes = bbox.dims()[0];
101+
int64_t box_size = bbox.dims()[1];
102+
103+
auto score_ptr = scores.data<T>();
104+
auto bbox_ptr = bbox.data<T>();
105+
106+
std::vector<int32_t> perm(num_boxes);
107+
std::iota(perm.begin(), perm.end(), 0);
108+
auto end = std::remove_if(perm.begin(), perm.end(),
109+
[&score_ptr, score_threshold](int32_t idx) {
110+
return score_ptr[idx] <= score_threshold;
111+
});
112+
113+
auto sort_fn = [&score_ptr](int32_t lhs, int32_t rhs) {
114+
return score_ptr[lhs] > score_ptr[rhs];
115+
};
116+
117+
int64_t num_pre = std::distance(perm.begin(), end);
118+
if (num_pre <= 0) {
119+
return;
120+
}
121+
if (top_k > -1 && num_pre > top_k) {
122+
num_pre = top_k;
123+
}
124+
std::partial_sort(perm.begin(), perm.begin() + num_pre, end, sort_fn);
125+
126+
std::vector<T> iou_matrix((num_pre * (num_pre - 1)) >> 1);
127+
std::vector<T> iou_max(num_pre);
128+
129+
iou_max[0] = 0.;
130+
for (int64_t i = 1; i < num_pre; i++) {
131+
T max_iou = 0.;
132+
auto idx_a = perm[i];
133+
for (int64_t j = 0; j < i; j++) {
134+
auto idx_b = perm[j];
135+
auto iou = JaccardOverlap<T>(bbox_ptr + idx_a * box_size,
136+
bbox_ptr + idx_b * box_size, normalized);
137+
max_iou = std::max(max_iou, iou);
138+
iou_matrix[i * (i - 1) / 2 + j] = iou;
139+
}
140+
iou_max[i] = max_iou;
141+
}
142+
143+
if (score_ptr[perm[0]] > post_threshold) {
144+
selected_indices->push_back(perm[0]);
145+
decayed_scores->push_back(score_ptr[perm[0]]);
146+
}
147+
148+
decay_score<T, gaussian> decay_fn;
149+
for (int64_t i = 1; i < num_pre; i++) {
150+
T min_decay = 1.;
151+
for (int64_t j = 0; j < i; j++) {
152+
auto max_iou = iou_max[j];
153+
auto iou = iou_matrix[i * (i - 1) / 2 + j];
154+
auto decay = decay_fn(iou, max_iou, sigma);
155+
min_decay = std::min(min_decay, decay);
156+
}
157+
auto ds = min_decay * score_ptr[perm[i]];
158+
if (ds <= post_threshold) continue;
159+
selected_indices->push_back(perm[i]);
160+
decayed_scores->push_back(ds);
161+
}
162+
}
163+
164+
template <typename T>
165+
class MatrixNMSKernel : public framework::OpKernel<T> {
166+
public:
167+
size_t MultiClassMatrixNMS(const Tensor& scores, const Tensor& bboxes,
168+
std::vector<T>* out, std::vector<int>* indices,
169+
int start, int64_t background_label,
170+
int64_t nms_top_k, int64_t keep_top_k,
171+
bool normalized, T score_threshold,
172+
T post_threshold, bool use_gaussian,
173+
float gaussian_sigma) const {
174+
std::vector<int> all_indices;
175+
std::vector<T> all_scores;
176+
std::vector<T> all_classes;
177+
all_indices.reserve(scores.numel());
178+
all_scores.reserve(scores.numel());
179+
all_classes.reserve(scores.numel());
180+
181+
size_t num_det = 0;
182+
auto class_num = scores.dims()[0];
183+
Tensor score_slice;
184+
for (int64_t c = 0; c < class_num; ++c) {
185+
if (c == background_label) continue;
186+
score_slice = scores.Slice(c, c + 1);
187+
if (use_gaussian) {
188+
NMSMatrix<T, true>(bboxes, score_slice, score_threshold, post_threshold,
189+
gaussian_sigma, nms_top_k, normalized, &all_indices,
190+
&all_scores);
191+
} else {
192+
NMSMatrix<T, false>(bboxes, score_slice, score_threshold,
193+
post_threshold, gaussian_sigma, nms_top_k,
194+
normalized, &all_indices, &all_scores);
195+
}
196+
for (size_t i = 0; i < all_indices.size() - num_det; i++) {
197+
all_classes.push_back(static_cast<T>(c));
198+
}
199+
num_det = all_indices.size();
200+
}
201+
202+
if (num_det <= 0) {
203+
return num_det;
204+
}
205+
206+
if (keep_top_k > -1) {
207+
auto k = static_cast<size_t>(keep_top_k);
208+
if (num_det > k) num_det = k;
209+
}
210+
211+
std::vector<int32_t> perm(all_indices.size());
212+
std::iota(perm.begin(), perm.end(), 0);
213+
214+
std::partial_sort(perm.begin(), perm.begin() + num_det, perm.end(),
215+
[&all_scores](int lhs, int rhs) {
216+
return all_scores[lhs] > all_scores[rhs];
217+
});
218+
219+
for (size_t i = 0; i < num_det; i++) {
220+
auto p = perm[i];
221+
auto idx = all_indices[p];
222+
auto cls = all_classes[p];
223+
auto score = all_scores[p];
224+
auto bbox = bboxes.data<T>() + idx * bboxes.dims()[1];
225+
(*indices).push_back(start + idx);
226+
(*out).push_back(cls);
227+
(*out).push_back(score);
228+
for (int j = 0; j < bboxes.dims()[1]; j++) {
229+
(*out).push_back(bbox[j]);
230+
}
231+
}
232+
233+
return num_det;
234+
}
235+
236+
void Compute(const framework::ExecutionContext& ctx) const override {
237+
auto* boxes = ctx.Input<LoDTensor>("BBoxes");
238+
auto* scores = ctx.Input<LoDTensor>("Scores");
239+
auto* outs = ctx.Output<LoDTensor>("Out");
240+
auto* index = ctx.Output<LoDTensor>("Index");
241+
242+
auto background_label = ctx.Attr<int>("background_label");
243+
auto nms_top_k = ctx.Attr<int>("nms_top_k");
244+
auto keep_top_k = ctx.Attr<int>("keep_top_k");
245+
auto normalized = ctx.Attr<bool>("normalized");
246+
auto score_threshold = ctx.Attr<float>("score_threshold");
247+
auto post_threshold = ctx.Attr<float>("post_threshold");
248+
auto use_gaussian = ctx.Attr<bool>("use_gaussian");
249+
auto gaussian_sigma = ctx.Attr<float>("gaussian_sigma");
250+
251+
auto score_dims = scores->dims();
252+
auto batch_size = score_dims[0];
253+
auto num_boxes = score_dims[2];
254+
auto box_dim = boxes->dims()[2];
255+
auto out_dim = box_dim + 2;
256+
257+
Tensor boxes_slice, scores_slice;
258+
size_t num_out = 0;
259+
std::vector<size_t> offsets = {0};
260+
std::vector<T> detections;
261+
std::vector<int> indices;
262+
detections.reserve(out_dim * num_boxes * batch_size);
263+
indices.reserve(num_boxes * batch_size);
264+
for (int i = 0; i < batch_size; ++i) {
265+
scores_slice = scores->Slice(i, i + 1);
266+
scores_slice.Resize({score_dims[1], score_dims[2]});
267+
boxes_slice = boxes->Slice(i, i + 1);
268+
boxes_slice.Resize({score_dims[2], box_dim});
269+
int start = i * score_dims[2];
270+
num_out = MultiClassMatrixNMS(
271+
scores_slice, boxes_slice, &detections, &indices, start,
272+
background_label, nms_top_k, keep_top_k, normalized, score_threshold,
273+
post_threshold, use_gaussian, gaussian_sigma);
274+
offsets.push_back(offsets.back() + num_out);
275+
}
276+
277+
int64_t num_kept = offsets.back();
278+
if (num_kept == 0) {
279+
outs->mutable_data<T>({0, out_dim}, ctx.GetPlace());
280+
index->mutable_data<int>({0, 1}, ctx.GetPlace());
281+
} else {
282+
outs->mutable_data<T>({num_kept, out_dim}, ctx.GetPlace());
283+
index->mutable_data<int>({num_kept, 1}, ctx.GetPlace());
284+
std::copy(detections.begin(), detections.end(), outs->data<T>());
285+
std::copy(indices.begin(), indices.end(), index->data<int>());
286+
}
287+
288+
framework::LoD lod;
289+
lod.emplace_back(offsets);
290+
outs->set_lod(lod);
291+
index->set_lod(lod);
292+
}
293+
};
294+
295+
// Declares the inputs, outputs, attributes and documentation of the
// matrix_nms operator.
class MatrixNMSOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("BBoxes",
             "(Tensor) A 3-D Tensor with shape "
             "[N, M, 4] represents the predicted locations of M bounding boxes"
             ", N is the batch size. "
             "Each bounding box has four coordinate values and the layout is "
             "[xmin, ymin, xmax, ymax], when box size equals to 4.");
    AddInput("Scores",
             "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
             "predicted confidence predictions. N is the batch size, C is the "
             "class number, M is number of bounding boxes. For each category "
             "there are total M scores which corresponding M bounding boxes. "
             " Please note, M is equal to the 2nd dimension of BBoxes. ");
    AddAttr<int>(
        "background_label",
        "(int, default: 0) "
        "The index of background label, the background label will be ignored. "
        "If set to -1, then all categories will be considered.")
        .SetDefault(0);
    AddAttr<float>("score_threshold",
                   "(float) "
                   "Threshold to filter out bounding boxes with low "
                   "confidence score.");
    AddAttr<float>("post_threshold",
                   "(float, default 0.) "
                   "Threshold to filter out bounding boxes with low "
                   "confidence score AFTER decaying.")
        .SetDefault(0.);
    // The attribute is declared as int, so the description says (int).
    AddAttr<int>("nms_top_k",
                 "(int) "
                 "Maximum number of detections to be kept according to the "
                 "confidences after the filtering detections based on "
                 "score_threshold");
    AddAttr<int>("keep_top_k",
                 "(int) "
                 "Number of total bboxes to be kept per image after NMS "
                 "step. -1 means keeping all bboxes after NMS step.");
    AddAttr<bool>("normalized",
                  "(bool, default true) "
                  "Whether detections are normalized.")
        .SetDefault(true);
    AddAttr<bool>("use_gaussian",
                  "(bool, default false) "
                  "Whether to use Gaussian as decreasing function.")
        .SetDefault(false);
    // The description is a single string built from adjacent literals; there
    // must be no comma between them, otherwise the tail is passed as a
    // separate AddAttr argument instead of being part of the description.
    AddAttr<float>("gaussian_sigma",
                   "(float) "
                   "Sigma for Gaussian decreasing function, only takes effect "
                   "when 'use_gaussian' is enabled.")
        .SetDefault(2.);
    AddOutput("Out",
              "(LoDTensor) A 2-D LoDTensor with shape [No, 6] represents the "
              "detections. Each row has 6 values: "
              "[label, confidence, xmin, ymin, xmax, ymax]. "
              "the offsets in first dimension are called LoD, the number of "
              "offset is N + 1, if LoD[i + 1] - LoD[i] == 0, means there is "
              "no detected bbox.");
    AddOutput("Index",
              "(LoDTensor) A 2-D LoDTensor with shape [No, 1] represents the "
              "index of selected bbox. The index is the absolute index cross "
              "batches.");
    AddComment(R"DOC(
This operator does multi-class matrix non maximum suppression (NMS) on batched
boxes and scores.
In the NMS step, this operator greedily selects a subset of detection bounding
boxes that have high scores larger than score_threshold, if providing this
threshold, then selects the largest nms_top_k confidences scores if nms_top_k
is larger than -1. Then this operator decays boxes score according to the
Matrix NMS scheme.
Aftern NMS step, at most keep_top_k number of total bboxes are to be kept
per image if keep_top_k is larger than -1.
This operator support multi-class and batched inputs. It applying NMS
independently for each class. The outputs is a 2-D LoDTenosr, for each
image, the offsets in first dimension of LoDTensor are called LoD, the number
of offset is N + 1, where N is the batch size. If LoD[i + 1] - LoD[i] == 0,
means there is no detected bbox for this image.

For more information on Matrix NMS, please refer to:
https://arxiv.org/abs/2003.10152
)DOC");
  }
};
379+
380+
} // namespace operators
381+
} // namespace paddle
382+
383+
// Short alias for the operator namespace, used by the registration macros
// below.
namespace ops = paddle::operators;
384+
// Registers the matrix_nms operator with its definition class and proto
// maker. EmptyGradOpMaker is supplied for both OpDesc and OpBase.
// NOTE(review): presumably this means the operator has no gradient op -
// confirm against the framework's EmptyGradOpMaker definition.
REGISTER_OPERATOR(
385+
matrix_nms, ops::MatrixNMSOp, ops::MatrixNMSOpMaker,
386+
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
387+
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
388+
// Registers the CPU kernel, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(matrix_nms, ops::MatrixNMSKernel<float>,
389+
ops::MatrixNMSKernel<double>);

0 commit comments

Comments
 (0)