
Commit 1e3d06d

[Feature] Support ONNX and TensorRT exportation of RTMO models (open-mmlab#2597)
* support ONNX & TensorRT exportation of RTMO
* add configs for rtmo
* replace bbox expansion factor with parameter bbox_padding
* refine code
* refine comment
* apply model.switch_to_deploy in BaseTask.build_pytorch_model
* fix lint
* add rtmo into regression test
* add rtmo with trt backend into regression test
* add rtmo into supported model list
1 parent 6ff3c93 commit 1e3d06d

8 files changed (+184, −2 lines)
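For context, a minimal sketch of how the new ONNX Runtime deploy config would be used to export an RTMO model. The MMPose model config path comes from the regression test added below; the test image, work directory, and checkpoint name are placeholders, not part of this commit.

# Sketch only: export an RTMO model to ONNX using the deploy config added
# in this commit. Image, work_dir, and checkpoint paths are placeholders.
from mmdeploy.apis import torch2onnx

torch2onnx(
    img='demo/resources/human-pose.jpg',            # any test image
    work_dir='work_dir/rtmo',                       # output directory
    save_file='end2end.onnx',                       # exported model name
    deploy_cfg='configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py',
    model_cfg='configs/body_2d_keypoint/rtmo/body7/'
    'rtmo-s_8xb32-600e_body7-640x640.py',           # path inside the mmpose repo
    model_checkpoint='rtmo-s_body7.pth',            # placeholder checkpoint
    device='cpu')
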
configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py (new file; path as referenced in tests/regression/mmpose.yml below)

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/onnxruntime.py']

onnx_config = dict(
    output_names=['dets', 'keypoints'],
    dynamic_axes={
        'input': {
            0: 'batch',
        },
        'dets': {
            0: 'batch',
        },
        'keypoints': {
            0: 'batch'
        }
    })

codebase_config = dict(
    post_processing=dict(
        score_threshold=0.05,
        iou_threshold=0.5,
        max_output_boxes_per_class=200,
        pre_top_k=2000,
        keep_top_k=50,
        background_label_id=-1,
    ))

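A quick sanity check of the exported graph can be run directly with ONNX Runtime. The names 'input', 'dets', and 'keypoints' come from the config above; the file name, batch size, and 640x640 resolution are assumptions (the TensorRT config below uses 640x640), and real inference should also apply the model's preprocessing.

# Sketch only: feed a dummy 640x640 image through the exported RTMO model.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('work_dir/rtmo/end2end.onnx')
dummy = np.random.rand(1, 3, 640, 640).astype(np.float32)  # NCHW, batch of 1
dets, keypoints = sess.run(['dets', 'keypoints'], {'input': dummy})
print(dets.shape)       # (1, num_instances, 5): x1, y1, x2, y2, score
print(keypoints.shape)  # (1, num_instances, num_keypoints, 3): x, y, score
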
configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py (new file; path as referenced in tests/regression/mmpose.yml below)

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/tensorrt-fp16.py']

onnx_config = dict(
    output_names=['dets', 'keypoints'],
    dynamic_axes={
        'input': {
            0: 'batch',
        },
        'dets': {
            0: 'batch',
        },
        'keypoints': {
            0: 'batch'
        }
    })

backend_config = dict(
    common_config=dict(max_workspace_size=1 << 30),
    model_inputs=[
        dict(
            input_shapes=dict(
                input=dict(
                    min_shape=[1, 3, 640, 640],
                    opt_shape=[1, 3, 640, 640],
                    max_shape=[1, 3, 640, 640])))
    ])

codebase_config = dict(
    post_processing=dict(
        score_threshold=0.05,
        iou_threshold=0.5,
        max_output_boxes_per_class=200,
        pre_top_k=2000,
        keep_top_k=50,
        background_label_id=-1,
    ))

docs/en/04-supported-codebases/mmpose.md

Lines changed: 1 addition & 0 deletions
@@ -161,3 +161,4 @@ TODO
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
 | [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y |
+| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |

docs/zh_cn/04-supported-codebases/mmpose.md

Lines changed: 1 addition & 0 deletions
@@ -165,3 +165,4 @@ task_processor.visualize(
 | [SimCC](https://mmpose.readthedocs.io/en/latest/model_zoo_papers/algorithms.html#simcc-eccv-2022) | PoseDetection | Y | Y | Y | N | Y |
 | [RTMPose](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmpose) | PoseDetection | Y | Y | Y | N | Y |
 | [YoloX-Pose](https://github.com/open-mmlab/mmpose/tree/main/projects/yolox_pose) | PoseDetection | Y | Y | N | N | Y |
+| [RTMO](https://github.com/open-mmlab/mmpose/tree/dev-1.x/projects/rtmo) | PoseDetection | Y | Y | N | N | N |

mmdeploy/codebase/base/task.py

Lines changed: 5 additions & 0 deletions
@@ -126,6 +126,11 @@ def build_pytorch_model(self,
         if hasattr(model, 'backbone') and hasattr(model.backbone,
                                                   'switch_to_deploy'):
             model.backbone.switch_to_deploy()
+
+        if hasattr(model, 'switch_to_deploy') and callable(
+                model.switch_to_deploy):
+            model.switch_to_deploy()
+
         model = model.to(self.device)
         model.eval()
         return model

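The new hook calls switch_to_deploy() on the whole model when it defines one (previously only model.backbone was switched), presumably so RTMO models can re-parameterize themselves before export. A hypothetical minimal model showing the kind of method this hook would pick up; the conv-bn fusion is only an illustration, not RTMO's actual implementation.

# Hypothetical example of a model-level switch_to_deploy() that
# build_pytorch_model would now invoke before moving the model to the
# device and calling eval().
from torch import nn
from torch.nn.utils.fusion import fuse_conv_bn_eval


class ToyPoseModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.deployed = False

    def switch_to_deploy(self):
        # e.g. fuse conv + bn into a single conv for faster inference
        if self.deployed:
            return
        self.conv = fuse_conv_bn_eval(self.conv.eval(), self.bn.eval())
        self.bn = nn.Identity()
        self.deployed = True

    def forward(self, x):
        return self.bn(self.conv(x))
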
MMPose head rewriter registry (__init__.py)

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import mspn_head, simcc_head, yolox_pose_head  # noqa: F401,F403
+from . import mspn_head, rtmo_head, simcc_head, yolox_pose_head
 
-__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head']
+__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head', 'rtmo_head']

rtmo_head.py (new file; the rtmo_head rewriter module imported by the registry above)

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple

import torch
from mmpose.structures.bbox import bbox_xyxy2cs
from torch import Tensor

from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms
from mmdeploy.utils import Backend, get_backend


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmpose.models.heads.hybrid_heads.'
    'rtmo_head.RTMOHead.forward')
def predict(self,
            x: Tuple[Tensor],
            batch_data_samples: List = [],
            test_cfg: Optional[dict] = None):
    """Get predictions and transform to bbox and keypoints results.
    Args:
        x (Tuple[Tensor]): The input tensor from upstream network.
        batch_data_samples: Batch image meta info. Defaults to None.
        test_cfg: The runtime config for testing process.

    Returns:
        Tuple[Tensor]: Predict bbox and keypoint results.
        - dets (Tensor): Predict bboxes and scores, which is a 3D Tensor,
            has shape (batch_size, num_instances, 5), the last dimension 5
            arranged as (x1, y1, x2, y2, score).
        - pred_kpts (Tensor): Predict keypoints and scores, which is a 4D
            Tensor, has shape (batch_size, num_instances, num_keypoints, 3),
            the last dimension 3 arranged as (x, y, score).
    """

    # deploy context
    ctx = FUNCTION_REWRITER.get_context()
    backend = get_backend(ctx.cfg)
    deploy_cfg = ctx.cfg

    cfg = self.test_cfg if test_cfg is None else test_cfg

    # get predictions
    cls_scores, bbox_preds, _, kpt_vis, pose_vecs = self.head_module(x)[:5]
    assert len(cls_scores) == len(bbox_preds)
    num_imgs = cls_scores[0].shape[0]

    # flatten and concat predictions
    scores = self._flatten_predictions(cls_scores).sigmoid()
    flatten_bbox_preds = self._flatten_predictions(bbox_preds)
    flatten_pose_vecs = self._flatten_predictions(pose_vecs)
    flatten_kpt_vis = self._flatten_predictions(kpt_vis).sigmoid()
    bboxes = self.decode_bbox(flatten_bbox_preds, self.flatten_priors,
                              self.flatten_stride)

    if backend == Backend.TENSORRT:
        # pad for batched_nms because its output index is filled with -1
        bboxes = torch.cat(
            [bboxes,
             bboxes.new_zeros((bboxes.shape[0], 1, bboxes.shape[2]))],
            dim=1)

        scores = torch.cat(
            [scores, scores.new_zeros((scores.shape[0], 1, 1))], dim=1)

    # nms parameters
    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.get('nms_thr', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.get('pre_top_k', -1)
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    # do nms
    _, _, nms_indices = multiclass_nms(
        bboxes,
        scores,
        max_output_boxes_per_class,
        iou_threshold,
        score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k,
        output_index=True)

    batch_inds = torch.arange(num_imgs, device=scores.device).view(-1, 1)

    # filter predictions
    dets = torch.cat([bboxes, scores], dim=2)
    dets = dets[batch_inds, nms_indices, ...]
    pose_vecs = flatten_pose_vecs[batch_inds, nms_indices, ...]
    kpt_vis = flatten_kpt_vis[batch_inds, nms_indices, ...]
    grids = self.flatten_priors[nms_indices, ...]

    # decode keypoints
    bbox_cs = torch.cat(bbox_xyxy2cs(dets[..., :4], self.bbox_padding), dim=-1)
    keypoints = self.dcc.forward_test(pose_vecs, bbox_cs, grids)
    pred_kpts = torch.cat([keypoints, kpt_vis.unsqueeze(-1)], dim=-1)

    return dets, pred_kpts
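The TensorRT branch above pads one dummy box and score per image because the batched NMS op fills its unused output index slots with -1; with an all-zero row appended, those -1 indices simply gather the harmless padding row instead of a real detection. A small self-contained illustration of the indexing behaviour this relies on (shapes are made up for the example):

# Illustration only: advanced indexing with -1 selects the last row, so the
# appended zero row absorbs the -1 padding produced by batched NMS.
import torch

bboxes = torch.tensor([[[0., 0., 10., 10., 0.9],
                        [5., 5., 20., 20., 0.8]]])            # (1, 2, 5)
bboxes = torch.cat(
    [bboxes, bboxes.new_zeros((1, 1, 5))], dim=1)             # pad -> (1, 3, 5)

nms_indices = torch.tensor([[0, -1, -1]])                     # -1 marks padding
batch_inds = torch.arange(1).view(-1, 1)
selected = bboxes[batch_inds, nms_indices, ...]
print(selected)  # row 0 is the kept box; rows 1 and 2 are the zero padding row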

tests/regression/mmpose.yml

Lines changed: 14 additions & 0 deletions
@@ -150,3 +150,17 @@ models:
           input_img: *img_human_pose
           test_img: *img_human_pose
         deploy_config: configs/mmpose/pose-detection_yolox-pose_onnxruntime_dynamic.py
+
+  - name: RTMO
+    metafile: configs/body_2d_keypoint/rtmo/body7/rtmo_body7.yml
+    model_configs:
+      - configs/body_2d_keypoint/rtmo/body7/rtmo-s_8xb32-600e_body7-640x640.py
+    pipelines:
+      - convert_image:
+          input_img: *img_human_pose
+          test_img: *img_human_pose
+        deploy_config: configs/mmpose/pose-detection_rtmo_onnxruntime_dynamic.py
+      - convert_image:
+          input_img: *img_human_pose
+          test_img: *img_human_pose
+        deploy_config: configs/mmpose/pose-detection_rtmo_tensorrt-fp16_dynamic-640x640.py
