Skip to content

Commit e19f6fa

Browse files
Support deploy of YoloX-Pose (#2184)
* dev_mmpose * tide * fix lint * del redundant task and model * fix * test ut * test ut * upload configs * fix * remove debug * fix lint * use mmcv.ops.nms * fix lint * remove loop * debug * test modified ut * fix lint * fix return type * fix * fix rescale * fix * fix pack_result * update batch inference * fix nms and pytorch show_box * fix lint * modify ut * add docstring * modify nms * fix * add openvino config * update docs * fix test_mmpose --------- Co-authored-by: RunningLeon <[email protected]>
1 parent a664f06 commit e19f6fa

File tree

15 files changed

+455
-21
lines changed

15 files changed

+455
-21
lines changed

.github/workflows/build.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ jobs:
5252
run: |
5353
git clone -b dev --depth 1 https://github.com/open-mmlab/mmyolo.git /home/runner/work/mmyolo
5454
python -m pip install -v -e /home/runner/work/mmyolo
55+
- name: Install mmpose
56+
run: |
57+
git clone --depth 1 https://github.com/open-mmlab/mmpose.git /home/runner/work/mmpose
58+
python -m pip install -v -e /home/runner/work/mmpose
5559
- name: Build and install
5660
run: |
5761
rm -rf .eggs && python -m pip install -e .
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Deploy config that extends the static pose-detection base config with the
# ONNX Runtime backend settings (both pulled in through `_base_`).
_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']
2+
3+
# Name the two exported outputs and declare axis 0 of the input and of both
# outputs as a dynamic axis called 'batch'.
onnx_config = dict(
4+
output_names=['dets', 'keypoints'],
5+
dynamic_axes={
6+
'input': {
7+
0: 'batch',
8+
},
9+
'dets': {
10+
0: 'batch',
11+
},
12+
'keypoints': {
13+
0: 'batch'
14+
}
15+
})
16+
17+
# Post-processing settings handed to the codebase. The consumer of these keys
# is not visible here; presumably they are NMS score/IoU thresholds and
# top-k limits -- confirm against the rewritten head before tuning.
codebase_config = dict(
18+
post_processing=dict(
19+
score_threshold=0.05,
20+
iou_threshold=0.5,
21+
max_output_boxes_per_class=200,
22+
pre_top_k=5000,
23+
keep_top_k=100,
24+
background_label_id=-1,
25+
))
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Deploy config that extends the static pose-detection base config with the
# OpenVINO backend settings (both pulled in through `_base_`).
_base_ = ['./pose-detection_static.py', '../_base_/backends/openvino.py']
2+
3+
# Name the two exported outputs and declare axis 0 of the input and of both
# outputs as a dynamic axis called 'batch'.
onnx_config = dict(
4+
output_names=['dets', 'keypoints'],
5+
dynamic_axes={
6+
'input': {
7+
0: 'batch',
8+
},
9+
'dets': {
10+
0: 'batch',
11+
},
12+
'keypoints': {
13+
0: 'batch'
14+
}
15+
})
16+
# Single model input whose `opt_shapes` entry pins 'input' to
# [1, 3, 640, 640].
backend_config = dict(
17+
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 640, 640]))])
18+
19+
# Post-processing settings handed to the codebase. The consumer of these keys
# is not visible here; presumably they are NMS score/IoU thresholds and
# top-k limits -- confirm against the rewritten head before tuning.
codebase_config = dict(
20+
post_processing=dict(
21+
score_threshold=0.05,
22+
iou_threshold=0.5,
23+
max_output_boxes_per_class=200,
24+
pre_top_k=5000,
25+
keep_top_k=100,
26+
background_label_id=-1,
27+
))
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Deploy config that extends the static pose-detection base config with the
# TensorRT backend settings (both pulled in through `_base_`).
_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt.py']
2+
3+
# Name the two exported outputs and declare axis 0 of the input and of both
# outputs as a dynamic axis called 'batch'.
onnx_config = dict(
4+
output_names=['dets', 'keypoints'],
5+
dynamic_axes={
6+
'input': {
7+
0: 'batch',
8+
},
9+
'dets': {
10+
0: 'batch',
11+
},
12+
'keypoints': {
13+
0: 'batch'
14+
}
15+
})
16+
# min/opt/max shapes are identical, so the declared shape range for 'input'
# collapses to the single shape [1, 3, 640, 640].
# NOTE(review): 1 << 30 is 1073741824; presumably a byte count (1 GiB) --
# confirm the unit expected by the backend.
backend_config = dict(
17+
common_config=dict(max_workspace_size=1 << 30),
18+
model_inputs=[
19+
dict(
20+
input_shapes=dict(
21+
input=dict(
22+
min_shape=[1, 3, 640, 640],
23+
opt_shape=[1, 3, 640, 640],
24+
max_shape=[1, 3, 640, 640])))
25+
])
26+
27+
# Post-processing settings handed to the codebase. The consumer of these keys
# is not visible here; presumably they are NMS score/IoU thresholds and
# top-k limits -- confirm against the rewritten head before tuning.
codebase_config = dict(
28+
post_processing=dict(
29+
score_threshold=0.05,
30+
iou_threshold=0.5,
31+
max_output_boxes_per_class=200,
32+
pre_top_k=5000,
33+
keep_top_k=100,
34+
background_label_id=-1,
35+
))

csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ __launch_bounds__(nthds_per_cta) __global__
5050
bboxOffset) *
5151
5;
5252
if (nmsedIndex != nullptr) {
53-
nmsedIndex[i] = bboxId / 5;
53+
nmsedIndex[i] = bboxId / 5 - bboxOffset;
5454
}
5555
// clipped bbox xmin
5656
nmsedDets[i * 6] =
@@ -74,7 +74,7 @@ __launch_bounds__(nthds_per_cta) __global__
7474
bboxOffset) *
7575
4;
7676
if (nmsedIndex != nullptr) {
77-
nmsedIndex[i] = bboxId / 4;
77+
nmsedIndex[i] = bboxId / 4 - bboxOffset;
7878
}
7979
// clipped bbox xmin
8080
nmsedDets[i * 5] =

docs/en/04-supported-codebases/mmpose.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,4 @@ TODO
160160
| [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
161161
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
162162
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
163+
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |

docs/zh_cn/04-supported-codebases/mmpose.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,4 @@ task_processor.visualize(
164164
| [Hourglass](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#hourglass-eccv-2016) | PoseDetection | Y | Y | Y | N | Y |
165165
| [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
166166
| [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
167+
| [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox-pose) | PoseDetection | Y | Y | N | N | Y |

mmdeploy/codebase/mmpose/deploy/pose_detection.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ class MMPose(MMCodebase):
120120
@classmethod
121121
def register_deploy_modules(cls):
122122
"""register rewritings."""
123+
import mmdeploy.codebase.mmdet.models
124+
import mmdeploy.codebase.mmdet.ops
125+
import mmdeploy.codebase.mmdet.structures
123126
import mmdeploy.codebase.mmpose.models # noqa: F401
124127

125128
@classmethod
@@ -202,9 +205,11 @@ def create_input(self,
202205
raise AssertionError('imgs must be strings or numpy arrays')
203206
elif isinstance(imgs, (np.ndarray, str)):
204207
imgs = [imgs]
208+
img_path = [imgs]
205209
else:
206210
raise AssertionError('imgs must be strings or numpy arrays')
207211
if isinstance(imgs, (list, tuple)) and isinstance(imgs[0], str):
212+
img_path = imgs
208213
img_data = [mmcv.imread(img) for img in imgs]
209214
imgs = img_data
210215
person_results = []
@@ -220,7 +225,7 @@ def create_input(self,
220225
TRANSFORMS.build(c) for c in cfg.test_dataloader.dataset.pipeline
221226
]
222227
test_pipeline = Compose(test_pipeline)
223-
if input_shape is not None:
228+
if input_shape is not None and hasattr(cfg, 'codec'):
224229
if isinstance(cfg.codec, dict):
225230
codec = cfg.codec
226231
elif isinstance(cfg.codec, list):
@@ -243,9 +248,15 @@ def create_input(self,
243248
bbox_score = np.array([bbox[4] if len(bbox) == 5 else 1
244249
]) # shape (1,)
245250
data = {
246-
'img': imgs[i],
247-
'bbox_score': bbox_score,
248-
'bbox': bbox[None], # shape (1, 4)
251+
'img':
252+
imgs[i],
253+
'bbox_score':
254+
bbox_score,
255+
'bbox': [] if hasattr(cfg.model, 'bbox_head')
256+
and cfg.model.bbox_head.type == 'YOLOXPoseHead' else
257+
bbox[None],
258+
'img_path':
259+
img_path[i]
249260
}
250261
data.update(meta_data)
251262
data = test_pipeline(data)
@@ -288,11 +299,17 @@ def visualize(self,
288299

289300
if isinstance(image, str):
290301
image = mmcv.imread(image, channel_order='rgb')
302+
draw_bbox = result.pred_instances.bboxes is not None
303+
if draw_bbox and isinstance(result.pred_instances.bboxes,
304+
torch.Tensor):
305+
result.pred_instances.bboxes = result.pred_instances.bboxes.cpu(
306+
).numpy()
291307
visualizer.add_datasample(
292308
name,
293309
image,
294310
data_sample=result,
295311
draw_gt=False,
312+
draw_bbox=draw_bbox,
296313
show=show_result,
297314
out_file=output_file)
298315

mmdeploy/codebase/mmpose/deploy/pose_detection_model.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def __init__(self,
5454
device=device,
5555
**kwargs)
5656
# create head for decoding heatmap
57-
self.head = builder.build_head(model_cfg.model.head)
57+
self.head = builder.build_head(model_cfg.model.head) if hasattr(
58+
model_cfg.model, 'head') else None
5859

5960
def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
6061
device: str, **kwargs):
@@ -97,6 +98,9 @@ def forward(self,
9798
inputs = inputs.contiguous().to(self.device)
9899
batch_outputs = self.wrapper({self.input_name: inputs})
99100
batch_outputs = self.wrapper.output_to_list(batch_outputs)
101+
if self.model_cfg.model.type == 'YOLODetector':
102+
return self.pack_yolox_pose_result(batch_outputs, data_samples)
103+
100104
codec = self.model_cfg.codec
101105
if isinstance(codec, (list, tuple)):
102106
codec = codec[-1]
@@ -158,6 +162,48 @@ def pack_result(self,
158162

159163
return data_samples
160164

165+
def pack_yolox_pose_result(self, preds: List[torch.Tensor],
                           data_samples: List[BaseDataElement]):
    """Pack YOLOX-Pose prediction results into mmpose format.

    Args:
        preds (List[Tensor]): Two batched tensors ``[dets, keypoints]``.
            ``dets`` is indexed as ``[batch, instance, channel]`` where
            channels ``0:4`` hold the box and channel ``4`` the box score.
            ``keypoints`` is indexed as
            ``[batch, instance, keypoint, channel]`` where channels ``0:2``
            hold the coordinates and channel ``2`` the keypoint score.
        data_samples (List[BaseDataElement]): A list of meta info for
            image(s). Each sample must expose ``scale_factor`` with two
            values, used to map predictions back to the original image.

    Returns:
        List[BaseDataElement]: The same ``data_samples`` objects, each
        updated in place with a ``pred_instances`` field holding
        ``bboxes`` and ``keypoints`` (``np.ndarray``), ``bbox_scores``,
        ``keypoint_scores`` and ``labels`` (``torch.Tensor``).
    """
    assert preds[0].shape[0] == len(data_samples)
    batched_dets, batched_kpts = preds
    for data_sample_idx, data_sample in enumerate(data_samples):
        dets = batched_dets[data_sample_idx]
        kpts = batched_kpts[data_sample_idx]

        # Drop entries with zero or negative box score. Boolean-mask
        # indexing returns copies, so the in-place rescaling below does
        # not write back into the backend output tensors.
        inds = dets[:, 4] > 0.0
        bboxes = dets[inds, :4]
        bbox_scores = dets[inds, 4]
        keypoints = kpts[inds, :, :2]
        keypoint_scores = kpts[inds, :, 2]

        # Rescale predictions back to the original image. Convert the
        # scale factor to a tensor exactly once; re-wrapping a tensor with
        # ``new_tensor`` is redundant and triggers a PyTorch UserWarning.
        scale_factor = keypoints.new_tensor(data_sample.scale_factor)
        keypoints /= scale_factor.reshape(1, 1, 2)
        bboxes /= scale_factor.repeat(1, 2)

        pred_instances = InstanceData()
        pred_instances.bboxes = bboxes.cpu().numpy()
        pred_instances.bbox_scores = bbox_scores
        # the precision test requires keypoints to be np.ndarray
        pred_instances.keypoints = keypoints.cpu().numpy()
        pred_instances.keypoint_scores = keypoint_scores
        # Every kept instance gets label 0. The field was previously
        # misspelled ``lebels``, so consumers looking up ``labels`` never
        # found it.
        pred_instances.labels = torch.zeros(bboxes.shape[0])

        data_sample.pred_instances = pred_instances
    return data_samples
206+
161207

162208
@__BACKEND_MODEL.register_module('sdk')
163209
class SDKEnd2EndModel(End2EndModel):
@@ -236,8 +282,13 @@ def build_pose_detection_model(
236282
if isinstance(data_preprocessor, dict):
237283
dp = data_preprocessor.copy()
238284
dp_type = dp.pop('type')
239-
assert dp_type == 'PoseDataPreprocessor'
240-
data_preprocessor = PoseDataPreprocessor(**dp)
285+
if dp_type == 'mmdet.DetDataPreprocessor':
286+
from mmdet.models.data_preprocessors import DetDataPreprocessor
287+
data_preprocessor = DetDataPreprocessor(**dp)
288+
else:
289+
assert dp_type == 'PoseDataPreprocessor'
290+
data_preprocessor = PoseDataPreprocessor(**dp)
291+
241292
backend_pose_model = __BACKEND_MODEL.build(
242293
dict(
243294
type=model_type,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from . import mspn_head
2+
from . import mspn_head, yolox_pose_head # noqa: F401,F403
33

4-
__all__ = ['mspn_head']
4+
__all__ = ['mspn_head', 'yolox_pose_head']

0 commit comments

Comments
 (0)