Skip to content

Commit 5178d9f

Browse files
authored
support multiclass nms for multi-batch, test=develop (#28164)
1 parent 5d94a5c commit 5178d9f

File tree

2 files changed

+61
-1
lines changed

2 files changed

+61
-1
lines changed

paddle/fluid/operators/detection/multiclass_nms_op.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
290290
} else {
291291
sdata = scores_data + label * predict_dim;
292292
}
293+
293294
for (size_t j = 0; j < indices.size(); ++j) {
294295
int idx = indices[j];
295296
odata[count * out_dim] = label; // label
@@ -333,17 +334,22 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
333334
Tensor boxes_slice, scores_slice;
334335
int n = score_size == 3 ? batch_size : boxes->lod().back().size() - 1;
335336
for (int i = 0; i < n; ++i) {
337+
std::map<int, std::vector<int>> indices;
336338
if (score_size == 3) {
337339
scores_slice = scores->Slice(i, i + 1);
338340
scores_slice.Resize({score_dims[1], score_dims[2]});
339341
boxes_slice = boxes->Slice(i, i + 1);
340342
boxes_slice.Resize({score_dims[2], box_dim});
341343
} else {
342344
auto boxes_lod = boxes->lod().back();
345+
if (boxes_lod[i] == boxes_lod[i + 1]) {
346+
all_indices.push_back(indices);
347+
batch_starts.push_back(batch_starts.back());
348+
continue;
349+
}
343350
scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
344351
boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
345352
}
346-
std::map<int, std::vector<int>> indices;
347353
MultiClassNMS(ctx, scores_slice, boxes_slice, score_size, &indices,
348354
&num_nmsed_out);
349355
all_indices.push_back(indices);
@@ -375,12 +381,14 @@ class MultiClassNMSKernel : public framework::OpKernel<T> {
375381
}
376382
} else {
377383
auto boxes_lod = boxes->lod().back();
384+
if (boxes_lod[i] == boxes_lod[i + 1]) continue;
378385
scores_slice = scores->Slice(boxes_lod[i], boxes_lod[i + 1]);
379386
boxes_slice = boxes->Slice(boxes_lod[i], boxes_lod[i + 1]);
380387
if (return_index) {
381388
offset = boxes_lod[i] * score_dims[1];
382389
}
383390
}
391+
384392
int64_t s = batch_starts[i];
385393
int64_t e = batch_starts[i + 1];
386394
if (e > s) {

python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numpy as np
1818
import copy
1919
from op_test import OpTest
20+
import paddle
2021
import paddle.fluid as fluid
2122
from paddle.fluid import Program, program_guard
2223

@@ -171,6 +172,9 @@ def lod_multiclass_nms(boxes, scores, background, score_threshold,
171172
lod = []
172173
head = 0
173174
for n in range(len(box_lod[0])):
175+
if box_lod[0][n] == 0:
176+
lod.append(0)
177+
continue
174178
box = boxes[head:head + box_lod[0][n]]
175179
score = scores[head:head + box_lod[0][n]]
176180
offset = head
@@ -357,6 +361,53 @@ def test_check_output(self):
357361
self.check_output()
358362

359363

364+
class TestMulticlassNMSNoBox(TestMulticlassNMSLoDInput):
365+
def setUp(self):
366+
self.set_argument()
367+
M = 1200
368+
C = 21
369+
BOX_SIZE = 4
370+
box_lod = [[0, 1200, 0]]
371+
background = 0
372+
nms_threshold = 0.3
373+
nms_top_k = 400
374+
keep_top_k = 200
375+
score_threshold = self.score_threshold
376+
normalized = False
377+
378+
scores = np.random.random((M, C)).astype('float32')
379+
380+
scores = np.apply_along_axis(softmax, 1, scores)
381+
382+
boxes = np.random.random((M, C, BOX_SIZE)).astype('float32')
383+
boxes[:, :, 0] = boxes[:, :, 0] * 10
384+
boxes[:, :, 1] = boxes[:, :, 1] * 10
385+
boxes[:, :, 2] = boxes[:, :, 2] * 10 + 10
386+
boxes[:, :, 3] = boxes[:, :, 3] * 10 + 10
387+
388+
det_outs, lod = lod_multiclass_nms(
389+
boxes, scores, background, score_threshold, nms_threshold,
390+
nms_top_k, keep_top_k, box_lod, normalized)
391+
det_outs = np.array(det_outs).astype('float32')
392+
nmsed_outs = det_outs[:, :-1].astype('float32') if len(
393+
det_outs) else det_outs
394+
self.op_type = 'multiclass_nms'
395+
self.inputs = {
396+
'BBoxes': (boxes, box_lod),
397+
'Scores': (scores, box_lod),
398+
}
399+
self.outputs = {'Out': (nmsed_outs, [lod])}
400+
self.attrs = {
401+
'background_label': 0,
402+
'nms_threshold': nms_threshold,
403+
'nms_top_k': nms_top_k,
404+
'keep_top_k': keep_top_k,
405+
'score_threshold': score_threshold,
406+
'nms_eta': 1.0,
407+
'normalized': normalized,
408+
}
409+
410+
360411
class TestIOU(unittest.TestCase):
361412
def test_iou(self):
362413
box1 = np.array([4.0, 3.0, 7.0, 5.0]).astype('float32')
@@ -521,4 +572,5 @@ def test_scores_Variable():
521572

522573

523574
if __name__ == '__main__':
575+
paddle.enable_static()
524576
unittest.main()

0 commit comments

Comments
 (0)