@@ -54,7 +54,8 @@ def __init__(self,
5454 device = device ,
5555 ** kwargs )
5656 # create head for decoding heatmap
57- self .head = builder .build_head (model_cfg .model .head )
57+ self .head = builder .build_head (model_cfg .model .head ) if hasattr (
58+ model_cfg .model , 'head' ) else None
5859
5960 def _init_wrapper (self , backend : Backend , backend_files : Sequence [str ],
6061 device : str , ** kwargs ):
@@ -97,6 +98,9 @@ def forward(self,
9798 inputs = inputs .contiguous ().to (self .device )
9899 batch_outputs = self .wrapper ({self .input_name : inputs })
99100 batch_outputs = self .wrapper .output_to_list (batch_outputs )
101+ if self .model_cfg .model .type == 'YOLODetector' :
102+ return self .pack_yolox_pose_result (batch_outputs , data_samples )
103+
100104 codec = self .model_cfg .codec
101105 if isinstance (codec , (list , tuple )):
102106 codec = codec [- 1 ]
@@ -158,6 +162,48 @@ def pack_result(self,
158162
159163 return data_samples
160164
def pack_yolox_pose_result(self, preds: List[torch.Tensor],
                           data_samples: List[BaseDataElement]):
    """Pack yolox-pose prediction results to mmpose format.

    Args:
        preds (List[Tensor]): Prediction of bboxes and key-points, unpacked
            as ``(batched_dets, batched_kpts)``. ``batched_dets`` is
            indexed as ``[sample, instance, :4]`` for the box and
            ``[sample, instance, 4]`` for the box score.
            ``batched_kpts`` is indexed as
            ``[sample, instance, keypoint, :2]`` for the coordinates and
            ``[sample, instance, keypoint, 2]`` for the keypoint score.
        data_samples (List[BaseDataElement]): A list of meta info for
            image(s). Each sample must expose a two-element
            ``scale_factor``.

    Returns:
        data_samples (List[BaseDataElement]):
            updated data_samples with predictions.
    """
    assert preds[0].shape[0] == len(data_samples), (
        'batch size of predictions does not match number of data samples')
    batched_dets, batched_kpts = preds
    for data_sample_idx, data_sample in enumerate(data_samples):
        bboxes = batched_dets[data_sample_idx, :, :4]
        bbox_scores = batched_dets[data_sample_idx, :, 4]
        keypoints = batched_kpts[data_sample_idx, :, :, :2]
        keypoint_scores = batched_kpts[data_sample_idx, :, :, 2]

        # filter zero or negative scores
        inds = bbox_scores > 0.0
        bboxes = bboxes[inds, :]
        bbox_scores = bbox_scores[inds]
        keypoints = keypoints[inds, :]
        keypoint_scores = keypoint_scores[inds]

        pred_instances = InstanceData()
        # rescale: convert the scale factor to a tensor exactly once.
        # Re-wrapping an existing tensor with ``new_tensor`` only makes
        # redundant copies and triggers a copy-construct UserWarning.
        scale_factor = keypoints.new_tensor(data_sample.scale_factor)
        # Out-of-place division so the backend output buffers are never
        # modified, regardless of how the filtered tensors were produced.
        keypoints = keypoints / scale_factor.reshape(1, 1, 2)
        # the two-element factor is tiled to four to cover all box coords
        bboxes = bboxes / scale_factor.repeat(1, 2)
        pred_instances.bboxes = bboxes.cpu().numpy()
        pred_instances.bbox_scores = bbox_scores
        # the precision test requires keypoints to be np.ndarray
        pred_instances.keypoints = keypoints.cpu().numpy()
        pred_instances.keypoint_scores = keypoint_scores
        # every kept instance gets class label 0; the field was previously
        # misspelled as ``lebels``
        pred_instances.labels = torch.zeros(bboxes.shape[0])

        data_sample.pred_instances = pred_instances
    return data_samples
206+
161207
162208@__BACKEND_MODEL .register_module ('sdk' )
163209class SDKEnd2EndModel (End2EndModel ):
@@ -236,8 +282,13 @@ def build_pose_detection_model(
236282 if isinstance (data_preprocessor , dict ):
237283 dp = data_preprocessor .copy ()
238284 dp_type = dp .pop ('type' )
239- assert dp_type == 'PoseDataPreprocessor'
240- data_preprocessor = PoseDataPreprocessor (** dp )
285+ if dp_type == 'mmdet.DetDataPreprocessor' :
286+ from mmdet .models .data_preprocessors import DetDataPreprocessor
287+ data_preprocessor = DetDataPreprocessor (** dp )
288+ else :
289+ assert dp_type == 'PoseDataPreprocessor'
290+ data_preprocessor = PoseDataPreprocessor (** dp )
291+
241292 backend_pose_model = __BACKEND_MODEL .build (
242293 dict (
243294 type = model_type ,
0 commit comments