Skip to content

Commit f2d0b15

Browse files
author
q.yao
authored
simplify non batch nms (#99)
1 parent a543d41 commit f2d0b15

File tree

1 file changed

+115
-22
lines changed
  • mmdeploy/codebase/mmdet/core/post_processing

1 file changed

+115
-22
lines changed

mmdeploy/codebase/mmdet/core/post_processing/bbox_nms.py

Lines changed: 115 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import mmdeploy
66
from mmdeploy.core import FUNCTION_REWRITER, mark
77
from mmdeploy.mmcv.ops import ONNXNMSop, TRTBatchedNMSop
8+
from mmdeploy.utils import is_dynamic_batch
89

910

1011
def select_nms_index(scores: torch.Tensor,
@@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor,
8283
keep_top_k: int = -1):
8384
"""Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
8485
85-
This function helps exporting to onnx with batch and multiclass NMS op.
86-
It only supports class-agnostic detection results. That is, the scores
87-
is of shape (N, num_bboxes, num_classes) and the boxes is of shape
88-
(N, num_boxes, 4).
89-
90-
Args:
91-
boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
92-
scores (Tensor): The detection scores of shape
93-
[N, num_boxes, num_classes].
94-
max_output_boxes_per_class (int): Maximum number of output
95-
boxes per class of nms. Defaults to 1000.
96-
iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
97-
score_threshold (float): score threshold of nms.
98-
Defaults to 0.05.
99-
pre_top_k (int): Number of top K boxes to keep before nms.
100-
Defaults to -1.
101-
keep_top_k (int): Number of top K boxes to keep after nms.
102-
Defaults to -1.
103-
104-
Returns:
105-
tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
106-
and `labels` of shape [N, num_det].
86+
This function helps exporting to onnx with batch and multiclass NMS op. It
87+
only supports class-agnostic detection results. That is, the scores is of
88+
shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
89+
4).
10790
"""
10891
max_output_boxes_per_class = torch.LongTensor([max_output_boxes_per_class])
10992
iou_threshold = torch.tensor([iou_threshold], dtype=torch.float32)
@@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor,
129112
return dets, labels
130113

131114

115+
def _multiclass_nms_single(boxes: Tensor,
                           scores: Tensor,
                           max_output_boxes_per_class: int = 1000,
                           iou_threshold: float = 0.5,
                           score_threshold: float = 0.05,
                           pre_top_k: int = -1,
                           keep_top_k: int = -1):
    """Run multiclass NMS through ``ONNXNMSop`` for a single-image batch.

    This variant hard-codes a leading batch dimension of size 1 (it squeezes
    dim 0 of the per-box scores and pads with ``(1, 1, 5)`` / ``(1, 1)``
    tensors), which lets it index with flat class/box indices instead of
    going through a batched index selection.

    Args:
        boxes (Tensor): Bounding boxes; indexed as [1, num_boxes, 4].
        scores (Tensor): Detection scores; indexed as
            [1, num_boxes, num_classes].
        max_output_boxes_per_class (int): Forwarded to the NMS op as a
            one-element long tensor. Defaults to 1000.
        iou_threshold (float): Forwarded to the NMS op as a one-element
            float32 tensor. Defaults to 0.5.
        score_threshold (float): Forwarded to the NMS op as a one-element
            float32 tensor. Defaults to 0.05.
        pre_top_k (int): When positive, only the ``pre_top_k`` boxes with
            the highest per-box maximum class score are passed to NMS.
            Defaults to -1.
        keep_top_k (int): When positive, number of detections selected by
            top-k after NMS (see the selection logic below).
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: ``(dets, labels)`` where the last dimension
            of ``dets`` is 4 box coordinates followed by the score, and
            ``labels`` holds the matching class indices.
    """
    # The NMS op takes its scalar settings as one-element tensors.
    max_per_class = torch.LongTensor([max_output_boxes_per_class])
    iou_thr = torch.tensor([iou_threshold], dtype=torch.float32)
    score_thr = torch.tensor([score_threshold], dtype=torch.float32)

    # Optional pre-filter: keep the boxes whose best class score is highest.
    if pre_top_k > 0:
        best_scores = scores.max(-1)[0]
        candidate_inds = best_scores.squeeze(0).topk(pre_top_k)[1]
        boxes = boxes[:, candidate_inds, :]
        scores = scores[:, candidate_inds, :]

    # The op is fed scores laid out as [batch, num_classes, num_boxes].
    cls_scores = scores.permute(0, 2, 1)
    nms_out = ONNXNMSop.apply(boxes, cls_scores, max_per_class, iou_thr,
                              score_thr)

    # Columns 1 and 2 of the op output are read as class and box indices.
    cls_inds = nms_out[:, 1]
    box_inds = nms_out[:, 2]

    kept_scores = cls_scores[:, cls_inds, box_inds].unsqueeze(2)
    kept_boxes = boxes[:, box_inds, ...]
    dets = torch.cat([kept_boxes, kept_scores], dim=2)
    labels = cls_inds.unsqueeze(0)

    # Append one all-zero detection/label; presumably so the top-k / sort
    # below never operates on an empty tensor -- TODO confirm.
    dets = torch.cat((dets, dets.new_zeros((1, 1, 5))), 1)
    labels = torch.cat((labels, labels.new_zeros((1, 1))), 1)

    # Order detections by score. top-k is used when keep_top_k is positive
    # and we are either exporting to ONNX or have more than keep_top_k
    # candidates; otherwise everything is sorted in descending order.
    det_scores = dets[:, :, -1]
    if keep_top_k > 0 and (torch.onnx.is_in_onnx_export()
                           or keep_top_k < dets.shape[1]):
        order = det_scores.topk(keep_top_k, dim=1)[1]
    else:
        order = det_scores.sort(dim=1, descending=True)[1]
    order = order.squeeze(0)

    return dets[:, order, ...], labels[:, order, ...]
166+
167+
168+
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms')
def multiclass_nms__default(ctx,
                            boxes: Tensor,
                            scores: Tensor,
                            max_output_boxes_per_class: int = 1000,
                            iou_threshold: float = 0.5,
                            score_threshold: float = 0.05,
                            pre_top_k: int = -1,
                            keep_top_k: int = -1):
    """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.

    This function helps exporting to onnx with batch and multiclass NMS op.
    It only supports class-agnostic detection results. That is, the scores
    are of shape (N, num_boxes, num_classes) and the boxes are of shape
    (N, num_boxes, 4).

    When the deploy config does not request a dynamic batch and the input
    batch size is exactly 1, the simplified ``_multiclass_nms_single`` is
    used; every other case goes through the general ``_multiclass_nms``.

    Args:
        ctx: Rewriter context; ``ctx.cfg`` is read as the deploy config.
        boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
        scores (Tensor): The detection scores of shape
            [N, num_boxes, num_classes].
        max_output_boxes_per_class (int): Maximum number of output
            boxes per class of nms. Defaults to 1000.
        iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
        score_threshold (float): score threshold of nms.
            Defaults to 0.05.
        pre_top_k (int): Number of top K boxes to keep before nms.
            Defaults to -1.
        keep_top_k (int): Number of top K boxes to keep after nms.
            Defaults to -1.

    Returns:
        tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
            and `labels` of shape [N, num_det].
    """
    deploy_cfg = ctx.cfg
    batch_size = boxes.size(0)

    # The single-image implementation squeezes and pads along a batch
    # dimension of size 1, so it is only valid for a static batch of
    # exactly one image. Dynamic or multi-image batches must use the
    # general implementation.
    use_single = not is_dynamic_batch(deploy_cfg) and batch_size == 1
    nms_impl = _multiclass_nms_single if use_single else _multiclass_nms

    return nms_impl(
        boxes,
        scores,
        max_output_boxes_per_class=max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)
223+
224+
132225
@FUNCTION_REWRITER.register_rewriter(
133226
func_name='mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms',
134227
backend='tensorrt')

0 commit comments

Comments
 (0)