55import mmdeploy
66from mmdeploy .core import FUNCTION_REWRITER , mark
77from mmdeploy .mmcv .ops import ONNXNMSop , TRTBatchedNMSop
8+ from mmdeploy .utils import is_dynamic_batch
89
910
1011def select_nms_index (scores : torch .Tensor ,
@@ -82,28 +83,10 @@ def _multiclass_nms(boxes: Tensor,
8283 keep_top_k : int = - 1 ):
8384 """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
8485
85- This function helps exporting to onnx with batch and multiclass NMS op.
86- It only supports class-agnostic detection results. That is, the scores
87- is of shape (N, num_bboxes, num_classes) and the boxes is of shape
88- (N, num_boxes, 4).
89-
90- Args:
91- boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
92- scores (Tensor): The detection scores of shape
93- [N, num_boxes, num_classes].
94- max_output_boxes_per_class (int): Maximum number of output
95- boxes per class of nms. Defaults to 1000.
96- iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
97- score_threshold (float): score threshold of nms.
98- Defaults to 0.05.
99- pre_top_k (int): Number of top K boxes to keep before nms.
100- Defaults to -1.
101- keep_top_k (int): Number of top K boxes to keep after nms.
102- Defaults to -1.
103-
104- Returns:
105- tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
106- and `labels` of shape [N, num_det].
86+ This function helps exporting to onnx with batch and multiclass NMS op. It
87+ only supports class-agnostic detection results. That is, the scores is of
88+ shape (N, num_bboxes, num_classes) and the boxes is of shape (N, num_boxes,
89+ 4).
10790 """
10891 max_output_boxes_per_class = torch .LongTensor ([max_output_boxes_per_class ])
10992 iou_threshold = torch .tensor ([iou_threshold ], dtype = torch .float32 )
@@ -129,6 +112,116 @@ def _multiclass_nms(boxes: Tensor,
129112 return dets , labels
130113
131114
115+ def _multiclass_nms_single (boxes : Tensor ,
116+ scores : Tensor ,
117+ max_output_boxes_per_class : int = 1000 ,
118+ iou_threshold : float = 0.5 ,
119+ score_threshold : float = 0.05 ,
120+ pre_top_k : int = - 1 ,
121+ keep_top_k : int = - 1 ):
122+ """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
123+
124+ Single batch nms could be optimized.
125+ """
126+ max_output_boxes_per_class = torch .LongTensor ([max_output_boxes_per_class ])
127+ iou_threshold = torch .tensor ([iou_threshold ], dtype = torch .float32 )
128+ score_threshold = torch .tensor ([score_threshold ], dtype = torch .float32 )
129+
130+ # pre topk
131+ if pre_top_k > 0 :
132+ max_scores , _ = scores .max (- 1 )
133+ _ , topk_inds = max_scores .squeeze (0 ).topk (pre_top_k )
134+ boxes = boxes [:, topk_inds , :]
135+ scores = scores [:, topk_inds , :]
136+
137+ scores = scores .permute (0 , 2 , 1 )
138+ selected_indices = ONNXNMSop .apply (boxes , scores ,
139+ max_output_boxes_per_class ,
140+ iou_threshold , score_threshold )
141+
142+ cls_inds = selected_indices [:, 1 ]
143+ box_inds = selected_indices [:, 2 ]
144+
145+ scores = scores [:, cls_inds , box_inds ].unsqueeze (2 )
146+ boxes = boxes [:, box_inds , ...]
147+ dets = torch .cat ([boxes , scores ], dim = 2 )
148+ labels = cls_inds .unsqueeze (0 )
149+
150+ # pad
151+ dets = torch .cat ((dets , dets .new_zeros ((1 , 1 , 5 ))), 1 )
152+ labels = torch .cat ((labels , labels .new_zeros ((1 , 1 ))), 1 )
153+
154+ # topk or sort
155+ is_use_topk = keep_top_k > 0 and \
156+ (torch .onnx .is_in_onnx_export () or keep_top_k < dets .shape [1 ])
157+ if is_use_topk :
158+ _ , topk_inds = dets [:, :, - 1 ].topk (keep_top_k , dim = 1 )
159+ else :
160+ _ , topk_inds = dets [:, :, - 1 ].sort (dim = 1 , descending = True )
161+ topk_inds = topk_inds .squeeze (0 )
162+ dets = dets [:, topk_inds , ...]
163+ labels = labels [:, topk_inds , ...]
164+
165+ return dets , labels
166+
167+
168+ @FUNCTION_REWRITER .register_rewriter (
169+ func_name = 'mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms' )
170+ def multiclass_nms__default (ctx ,
171+ boxes : Tensor ,
172+ scores : Tensor ,
173+ max_output_boxes_per_class : int = 1000 ,
174+ iou_threshold : float = 0.5 ,
175+ score_threshold : float = 0.05 ,
176+ pre_top_k : int = - 1 ,
177+ keep_top_k : int = - 1 ):
178+ """Create a dummy onnx::NonMaxSuppression op while exporting to ONNX.
179+
180+ This function helps exporting to onnx with batch and multiclass NMS op.
181+ It only supports class-agnostic detection results. That is, the scores
182+ is of shape (N, num_bboxes, num_classes) and the boxes is of shape
183+ (N, num_boxes, 4).
184+
185+ Args:
186+ boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4].
187+ scores (Tensor): The detection scores of shape
188+ [N, num_boxes, num_classes].
189+ max_output_boxes_per_class (int): Maximum number of output
190+ boxes per class of nms. Defaults to 1000.
191+ iou_threshold (float): IOU threshold of nms. Defaults to 0.5.
192+ score_threshold (float): score threshold of nms.
193+ Defaults to 0.05.
194+ pre_top_k (int): Number of top K boxes to keep before nms.
195+ Defaults to -1.
196+ keep_top_k (int): Number of top K boxes to keep after nms.
197+ Defaults to -1.
198+
199+ Returns:
200+ tuple[Tensor, Tensor]: (dets, labels), `dets` of shape [N, num_det, 5]
201+ and `labels` of shape [N, num_det].
202+ """
203+ deploy_cfg = ctx .cfg
204+ batch_size = boxes .size (0 )
205+ if not is_dynamic_batch (deploy_cfg ) and batch_size != 1 :
206+ return _multiclass_nms_single (
207+ boxes ,
208+ scores ,
209+ max_output_boxes_per_class = max_output_boxes_per_class ,
210+ iou_threshold = iou_threshold ,
211+ score_threshold = score_threshold ,
212+ pre_top_k = pre_top_k ,
213+ keep_top_k = keep_top_k )
214+ else :
215+ return _multiclass_nms (
216+ boxes ,
217+ scores ,
218+ max_output_boxes_per_class = max_output_boxes_per_class ,
219+ iou_threshold = iou_threshold ,
220+ score_threshold = score_threshold ,
221+ pre_top_k = pre_top_k ,
222+ keep_top_k = keep_top_k )
223+
224+
132225@FUNCTION_REWRITER .register_rewriter (
133226 func_name = 'mmdeploy.codebase.mmdet.core.post_processing._multiclass_nms' ,
134227 backend = 'tensorrt' )
0 commit comments