Skip to content

Commit e74901f

Browse files
authored
CodeCamp2023-671 (#2422)
* add nms ops * add some file * new file * some change * Update nms_match.cpp * Update nms_match.cpp * Update __init__.py * Delete test_onnx_match.onnx * Delete tests/test_ops/test_onnx_match.onnx * Update test_nms_match_small.py * Update test_nms_match_small.py * Update nms_match.cpp remove allocate * Update nms_match.py remove some test print * Update test_nms_match_small.py * Update nms_match.cpp * Update nms_match.py * Update test_nms_match_small.py * fix the lint fix the lint * Update test_nms_match_small.py * Update test_nms_match_small.py * Update nms_match.cpp * Update test_nms_match_small.py * Update test_nms_match_small.py add input_names * Update onnxruntime.md * Update onnxruntime.md * Update test_nms_match_small.py * Update onnxruntime.md * Update onnxruntime.md * Update test_nms_match_small.py Add UT in nmsmatch * Update test_nms_match_small.py * Update test_nms_match_small.py
1 parent 59449cc commit e74901f

File tree

8 files changed

+597
-4
lines changed

8 files changed

+597
-4
lines changed
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright (c) OpenMMLab. All rights reserved
2+
#include "nms_match.h"
3+
4+
#include <assert.h>
5+
6+
#include <algorithm>
7+
#include <cassert>
8+
#include <cmath>
9+
#include <iostream>
10+
#include <iterator>
11+
#include <numeric>
12+
#include <vector>
13+
14+
#include "ort_utils.h"
15+
16+
namespace mmdeploy {
17+
// Axis-aligned box given by its two corners (x1, y1) and (x2, y2).
struct Box {
  float x1, y1, x2, y2;
};

// Intersection-over-union of two boxes.
//
// The intersection width/height are clamped at zero, so boxes that do not
// overlap yield 0. The box areas themselves are NOT clamped: they are the
// plain products (x2 - x1) * (y2 - y1).
//
// @param box1 first box
// @param box2 second box
// @return inter / (area1 + area2 - inter + eps)
float nms_match_iou(const Box& box1, const Box& box2) {
  // Keeps the denominator non-zero when both boxes have zero area. It is a
  // double on purpose: the division below is carried out in double precision
  // and only the final result is narrowed back to float.
  constexpr double kEps = 1e-10;

  const float inter_x1 = std::max(box1.x1, box2.x1);
  const float inter_y1 = std::max(box1.y1, box2.y1);
  const float inter_x2 = std::min(box1.x2, box2.x2);
  const float inter_y2 = std::min(box1.y2, box2.y2);

  const float w = std::max(0.0f, inter_x2 - inter_x1);
  const float h = std::max(0.0f, inter_y2 - inter_y1);

  const float area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
  const float area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
  const float inter = w * h;

  const double ovr = static_cast<double>(inter) / (static_cast<double>(area1 + area2 - inter) + kEps);
  return static_cast<float>(ovr);
}
38+
// Builds the kernel: wraps the ORT API, remembers the kernel info and
// default-constructs the allocator member.
NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info)
    : ort_(api), info_(info), allocator_() {}
43+
44+
// Runs NMS-match for every (batch, class) pair and writes the matches to
// output 0.
//
// Inputs read from `context`:
//   0: boxes, indexed as [batch][box][4] with the 4 values (x1, y1, x2, y2)
//   1: scores, indexed as [batch][class][box]
//   2: IoU threshold (first element is used)
//   3: score threshold (first element is used)
//
// For each class the boxes are visited in descending score order. A box that
// is still selected suppresses every later, still selected box whose IoU with
// it is >= the IoU threshold. Each such pair is reported as one row
// (batch index, class index, kept box index, suppressed box index), but only
// when the score of the kept box is strictly greater than the score
// threshold. Suppression itself happens regardless of the score threshold.
//
// Output 0 is an int64 tensor of shape (number of rows, 4).
void NMSMatchKernel::Compute(OrtKernelContext* context) {
  const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
  const float* boxes_data = ort_.GetTensorData<float>(boxes);
  const OrtValue* scores = ort_.KernelContext_GetInput(context, 1);
  const float* scores_data = ort_.GetTensorData<float>(scores);
  const OrtValue* iou_threshold = ort_.KernelContext_GetInput(context, 2);
  const float iou_threshold_data = ort_.GetTensorData<float>(iou_threshold)[0];
  const OrtValue* score_threshold = ort_.KernelContext_GetInput(context, 3);
  const float score_threshold_data = ort_.GetTensorData<float>(score_threshold)[0];

  OrtTensorDimensions boxes_dim(ort_, boxes);
  OrtTensorDimensions scores_dim(ort_, scores);
  const int64_t nbatch = boxes_dim[0];
  const int64_t nboxes = boxes_dim[1];
  const int64_t nclass = scores_dim[1];
  assert(boxes_dim[2] == 4);  // (x1, y1, x2, y2)

  // Reads the four coordinates of box `idx` from a per-batch box pointer.
  const auto load_box = [](const float* batch_boxes, int64_t idx) -> Box {
    const float* p = batch_boxes + idx * 4;
    return Box{p[0], p[1], p[2], p[3]};
  };

  // Scratch buffers, reused across all (batch, class) pairs. Owning them in
  // vectors means nothing leaks if a push_back below throws.
  // `select` uses char rather than bool to avoid the std::vector<bool> proxy.
  std::vector<char> select(static_cast<size_t>(nboxes));
  std::vector<int64_t> order(static_cast<size_t>(nboxes));
  std::vector<int64_t> res_order;

  for (int64_t k = 0; k < nbatch; ++k) {
    const float* batch_boxes = boxes_data + k * nboxes * 4;
    for (int64_t g = 0; g < nclass; ++g) {
      // Scores of class g in batch k: offset k * nboxes * nclass + g * nboxes.
      const float* class_scores = scores_data + k * nboxes * nclass + g * nboxes;

      std::fill(select.begin(), select.end(), 1);

      // Box indices sorted by descending score.
      std::iota(order.begin(), order.end(), 0);
      std::sort(order.begin(), order.end(), [class_scores](int64_t id1, int64_t id2) {
        return class_scores[id1] > class_scores[id2];
      });

      for (int64_t _i = 0; _i < nboxes; ++_i) {
        const int64_t i = order[_i];
        if (!select[i]) continue;

        // Loop-invariant for the inner scan: build the kept box once.
        const Box kept = load_box(batch_boxes, i);
        const bool report = class_scores[i] > score_threshold_data;

        for (int64_t _j = _i + 1; _j < nboxes; ++_j) {
          const int64_t j = order[_j];
          if (!select[j]) continue;

          const Box candidate = load_box(batch_boxes, j);
          if (nms_match_iou(kept, candidate) >= iou_threshold_data) {
            select[j] = 0;
            if (report) {
              res_order.push_back(k);
              res_order.push_back(g);
              res_order.push_back(i);
              res_order.push_back(j);
            }
          }
        }
      }
    }
  }

  std::vector<int64_t> inds_dims({static_cast<int64_t>(res_order.size()) / 4, 4});

  OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
  int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);

  // std::copy is well-defined for an empty result and needs no extra header
  // beyond <algorithm>.
  std::copy(res_order.begin(), res_order.end(), res_data);
}
128+
REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp);
129+
} // namespace mmdeploy
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Copyright (c) OpenMMLab. All rights reserved.
2+
#ifndef ONNXRUNTIME_NMS_MATCH_H
3+
#define ONNXRUNTIME_NMS_MATCH_H
4+
5+
#include <assert.h>
6+
#include <onnxruntime_cxx_api.h>
7+
8+
#include <cmath>
9+
#include <mutex>
10+
#include <string>
11+
#include <vector>
12+
13+
namespace mmdeploy {
14+
struct NMSMatchKernel {
15+
NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info);
16+
17+
void Compute(OrtKernelContext* context);
18+
19+
private:
20+
Ort::CustomOpApi ort_;
21+
const OrtKernelInfo* info_;
22+
Ort::AllocatorWithDefaultOptions allocator_;
23+
};
24+
25+
// Custom-op descriptor that exposes NMSMatchKernel to ONNX Runtime under the
// name "NMSMatch".
struct NMSMatchOp : Ort::CustomOpBase<NMSMatchOp, NMSMatchKernel> {
  // Creates a new kernel instance; the returned pointer is a heap-allocated
  // NMSMatchKernel.
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new NMSMatchKernel(api, info);
  }

  // Operator name used for registration.
  const char* GetName() const {
    return "NMSMatch";
  }

  // The op takes four inputs, every one of them a float tensor.
  size_t GetInputTypeCount() const {
    return 4;
  }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }

  // The op produces a single int64 output tensor.
  size_t GetOutputTypeCount() const {
    return 1;
  }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
  }

  // force cpu
  const char* GetExecutionProviderType() const {
    return "CPUExecutionProvider";
  }
};
44+
} // namespace mmdeploy
45+
46+
#endif // ONNXRUNTIME_NMS_MATCH_H

docs/en/06-custom-ops/onnxruntime.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
- [Inputs](#inputs-3)
2828
- [Outputs](#outputs-3)
2929
- [Type Constraints](#type-constraints-3)
30+
- [NMSMatch](#nmsmatch)
31+
- [Description](#description-2)
32+
- [Parameters](#parameters-2)
33+
- [Inputs](#inputs-2)
34+
- [Outputs](#outputs-2)
35+
- [Type Constraints](#type-constraints-2)
3036

3137
<!-- TOC -->
3238

@@ -174,3 +180,36 @@ Perform RoIAlignRotated on output feature, used in bbox_head of most two-stage r
174180
#### Type Constraints
175181

176182
- T:tensor(float32)
183+
184+
### NMSMatch
185+
186+
#### Description
187+
188+
Non Max Suppression with the suppression box match.
189+
190+
#### Parameters
191+
192+
| Type | Parameter | Description |
193+
| ------- | ----------- | --------------------------------- |
194+
| `float` | `iou_thr` | The IoU threshold for NMSMatch. |
195+
| `float` | `score_thr` | The score threshold for NMSMatch. |
196+
197+
#### Inputs
198+
199+
<dl>
200+
<dt><tt>inputs[0]</tt>: T</dt>
201+
<dd>Input boxes; 3-D tensor of shape (b, N, 4), where b is the batch size, N is the number of boxes and 4 means the coordinate.</dd>
202+
<dt><tt>inputs[1]</tt>: T</dt>
203+
<dd>Input scores; 3-D tensor of shape (b, c, N), where b is the batch size, c is the class size and N is the number of boxes.</dd>
204+
</dl>
205+
206+
#### Outputs
207+
208+
<dl>
209+
<dt><tt>outputs[0]</tt>: T</dt>
210+
<dd>Output indices; 2-D int64 tensor of shape (K, 4), where K is the number of matched (selected, suppressed) box pairs. Each row holds the batch id, the class id, the index of the selected box and the index of the box it suppressed.</dd>
211+
</dl>
212+
213+
#### Type Constraints
214+
215+
- T:tensor(float32)

docs/zh_cn/06-custom-ops/onnxruntime.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
- [Inputs](#inputs-3)
2828
- [Outputs](#outputs-3)
2929
- [Type Constraints](#type-constraints-3)
30+
- [NMSMatch](#nmsmatch)
31+
- [Description](#description-2)
32+
- [Parameters](#parameters-2)
33+
- [Inputs](#inputs-2)
34+
- [Outputs](#outputs-2)
35+
- [Type Constraints](#type-constraints-2)
3036

3137
<!-- TOC -->
3238

@@ -174,3 +180,36 @@ Perform RoIAlignRotated on output feature, used in bbox_head of most two-stage r
174180
#### Type Constraints
175181

176182
- T:tensor(float32)
183+
184+
### NMSMatch
185+
186+
#### Description
187+
188+
Non Max Suppression with the suppression box match.
189+
190+
#### Parameters
191+
192+
| Type | Parameter | Description |
193+
| ------- | ----------- | --------------------------------- |
194+
| `float` | `iou_thr` | The IoU threshold for NMSMatch. |
195+
| `float` | `score_thr` | The score threshold for NMSMatch. |
196+
197+
#### Inputs
198+
199+
<dl>
200+
<dt><tt>inputs[0]</tt>: T</dt>
201+
<dd>Input boxes; 3-D tensor of shape (b, N, 4), where b is the batch size, N is the number of boxes and 4 means the coordinate.</dd>
202+
<dt><tt>inputs[1]</tt>: T</dt>
203+
<dd>Input scores; 3-D tensor of shape (b, c, N), where b is the batch size, c is the class size and N is the number of boxes.</dd>
204+
</dl>
205+
206+
#### Outputs
207+
208+
<dl>
209+
<dt><tt>outputs[0]</tt>: T</dt>
210+
<dd>Output indices; 2-D int64 tensor of shape (K, 4), where K is the number of matched (selected, suppressed) box pairs. Each row holds the batch id, the class id, the index of the selected box and the index of the box it suppressed.</dd>
211+
</dl>
212+
213+
#### Type Constraints
214+
215+
- T:tensor(float32)

mmdeploy/mmcv/ops/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66
from . import roi_align # noqa: F401,F403
77
from . import roi_align_rotated # noqa: F401,F403
88
from . import transformer # noqa: F401,F403
9-
from .nms import ONNXNMSop, TRTBatchedNMSop, multiclass_nms
10-
from .nms_rotated import (ONNXNMSRotatedOp, TRTBatchedRotatedNMSop,
11-
multiclass_nms_rotated)
9+
from .nms import ONNXNMSop, TRTBatchedNMSop, multiclass_nms # noqa: F401,F403
10+
from .nms_match import ONNXNMSMatchOp, multiclass_nms_match
11+
from .nms_rotated import multiclass_nms_rotated # noqa: F401,F403
12+
from .nms_rotated import ONNXNMSRotatedOp, TRTBatchedRotatedNMSop
1213

1314
# Public names re-exported by this package.
# Every entry must be followed by a comma: adjacent string literals are
# silently concatenated into a single (non-existent) name.
__all__ = [
    'ONNXNMSop', 'TRTBatchedNMSop', 'TRTBatchedRotatedNMSop',
    'ONNXNMSRotatedOp', 'multiclass_nms_rotated', 'multiclass_nms',
    'ONNXNMSMatchOp', 'multiclass_nms_match'
]

mmdeploy/mmcv/ops/nms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mmdeploy.core import FUNCTION_REWRITER, mark
88
from mmdeploy.utils import IR, is_dynamic_batch
99
from mmdeploy.utils.constants import Backend
10+
from .nms_match import multiclass_nms_match
1011
from .nms_rotated import multiclass_nms_rotated
1112

1213

@@ -529,6 +530,15 @@ def multiclass_nms(boxes: Tensor,
529530
score_threshold=score_threshold,
530531
pre_top_k=pre_top_k,
531532
keep_top_k=keep_top_k)
533+
elif nms_type == 'nms_match':
534+
return multiclass_nms_match(
535+
boxes,
536+
scores,
537+
max_output_boxes_per_class=max_output_boxes_per_class,
538+
iou_threshold=iou_threshold,
539+
score_threshold=score_threshold,
540+
pre_top_k=pre_top_k,
541+
keep_top_k=keep_top_k)
532542
else:
533543
raise NotImplementedError(f'Unsupported nms type: {nms_type}.')
534544

0 commit comments

Comments
 (0)