@@ -210,19 +210,44 @@ def export(
210210 self ,
211211 batch_inputs : torch .Tensor ,
212212 batch_img_metas : list [dict ],
213- ) -> tuple [torch .Tensor , ...]:
214- """Export for two stage detectors."""
215- x = self .extract_feat (batch_inputs )
213+ explain_mode : bool ,
214+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ] | dict :
215+ """Export the model for ONNX/OpenVINO.
216+
217+ Args:
218+ batch_inputs (torch.Tensor): image tensor with shape (N, C, H, W).
219+ batch_img_metas (list[dict]): image information.
220+ explain_mode (bool): whether to return feature vector.
216221
222+ Returns:
223+ tuple[torch.Tensor, torch.Tensor, torch.Tensor] | dict:
224+ - bboxes (torch.Tensor): bounding boxes.
225+ - labels (torch.Tensor): labels.
226+ - masks (torch.Tensor): masks.
227+ - feature_vector (torch.Tensor, optional): feature vector.
228+ - saliency_map (torch.Tensor, optional): saliency map.
229+ """
230+ x = self .extract_feat (batch_inputs )
217231 rpn_results_list = self .rpn_head .export (
218232 x ,
219233 batch_img_metas ,
220234 rescale = False ,
221235 )
222-
223- return self .roi_head .export (
236+ bboxes , labels , masks = self .roi_head .export (
224237 x ,
225238 rpn_results_list ,
226239 batch_img_metas ,
227240 rescale = False ,
228241 )
242+
243+ if explain_mode :
244+ feature_vector = self .feature_vector_fn (x )
245+ return {
246+ "bboxes" : bboxes ,
247+ "labels" : labels ,
248+ "masks" : masks ,
249+ "feature_vector" : feature_vector ,
250+ # create dummy tensor as model API supports saliency_map
251+ "saliency_map" : torch .zeros (1 ),
252+ }
253+ return bboxes , labels , masks
0 commit comments