Skip to content

Commit 4cf0f92

Browse files
authored
export get_simcc_maximum for simcc (#2449)
* update * update for simcc csrc * fix docker ci * update simcc_label
1 parent 1132e82 commit 4cf0f92

File tree

10 files changed

+109
-17
lines changed

10 files changed

+109
-17
lines changed

.github/workflows/docker.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ jobs:
5353
export TAG=$TAG_PREFIX
5454
echo "TAG=${TAG}" >> $GITHUB_ENV
5555
echo $TAG
56-
docker ./docker/Release/ -t ${TAG} --no-cache
56+
docker build ./docker/Release/ -t ${TAG} --no-cache
5757
docker push $TAG
5858
- name: Push docker image with released tag
5959
if: startsWith(github.ref, 'refs/tags/') == true

.github/workflows/publish.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@ jobs:
2929
echo $MMDEPLOY_VERSION
3030
echo "MMDEPLOY_VERSION=$MMDEPLOY_VERSION" >> $GITHUB_ENV
3131
echo "OUTPUT_DIR=$PREBUILD_DIR/$MMDEPLOY_VERSION" >> $GITHUB_ENV
32-
pip install twine
32+
python3 -m pip install twine --user
3333
- name: Upload mmdeploy
3434
continue-on-error: true
3535
run: |
3636
cd $OUTPUT_DIR/mmdeploy
3737
ls -sha *.whl
38-
twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
38+
python3 -m twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
3939
- name: Upload mmdeploy_runtime
4040
continue-on-error: true
4141
run: |
4242
cd $OUTPUT_DIR/mmdeploy_runtime
4343
ls -sha *.whl
44-
twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
44+
python3 -m twine upload *.whl -u __token__ -p ${{ secrets.pypi_password }}
4545
- name: Check assets
4646
run: |
4747
ls -sha $OUTPUT_DIR/sdk

configs/mmpose/pose-detection_simcc_onnxruntime_dynamic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@
1414
0: 'batch'
1515
}
1616
})
17+
18+
codebase_config = dict(
19+
export_postprocess=False # do not export get_simcc_maximum
20+
)

csrc/mmdeploy/codebase/mmpose/simcc_label.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ class SimCCLabelDecode : public MMPose {
2626
auto& params = config["params"];
2727
flip_test_ = params.value("flip_test", flip_test_);
2828
simcc_split_ratio_ = params.value("simcc_split_ratio", simcc_split_ratio_);
29+
export_postprocess_ = params.value("export_postprocess", export_postprocess_);
30+
if (export_postprocess_) {
31+
simcc_split_ratio_ = 1.0;
32+
}
2933
if (params.contains("input_size")) {
3034
from_value(params["input_size"], input_size_);
3135
}
@@ -52,26 +56,31 @@ class SimCCLabelDecode : public MMPose {
5256

5357
Tensor keypoints({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 2}});
5458
Tensor scores({Device{"cpu"}, DataType::kFLOAT, {simcc_x.shape(0), simcc_x.shape(1), 1}});
55-
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
59+
float *keypoints_data = nullptr, *scores_data = nullptr;
60+
if (!export_postprocess_) {
61+
get_simcc_maximum(simcc_x, simcc_y, keypoints, scores);
62+
keypoints_data = keypoints.data<float>();
63+
scores_data = scores.data<float>();
64+
} else {
65+
keypoints_data = simcc_x.data<float>();
66+
scores_data = simcc_y.data<float>();
67+
}
5668

5769
std::vector<float> center;
5870
std::vector<float> scale;
5971
from_value(img_metas["center"], center);
6072
from_value(img_metas["scale"], scale);
6173
PoseDetectorOutput output;
6274

63-
float* keypoints_data = keypoints.data<float>();
64-
float* scores_data = scores.data<float>();
6575
float scale_value = 200, x = -1, y = -1, s = 0;
6676
for (int i = 0; i < simcc_x.shape(1); i++) {
67-
x = *(keypoints_data + 0) / simcc_split_ratio_;
68-
y = *(keypoints_data + 1) / simcc_split_ratio_;
77+
x = *(keypoints_data++) / simcc_split_ratio_;
78+
y = *(keypoints_data++) / simcc_split_ratio_;
79+
s = *(scores_data++);
80+
6981
x = x * scale[0] * scale_value / input_size_[0] + center[0] - scale[0] * scale_value * 0.5;
7082
y = y * scale[1] * scale_value / input_size_[1] + center[1] - scale[1] * scale_value * 0.5;
71-
s = *(scores_data + 0);
7283
output.key_points.push_back({{x, y}, s});
73-
keypoints_data += 2;
74-
scores_data += 1;
7584
}
7685
return to_value(output);
7786
}
@@ -104,6 +113,7 @@ class SimCCLabelDecode : public MMPose {
104113

105114
private:
106115
bool flip_test_{false};
116+
bool export_postprocess_{false};
107117
bool shift_heatmap_{false};
108118
float simcc_split_ratio_{2.0};
109119
std::vector<int> input_size_{192, 256};
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
3+
from .post_processing import get_simcc_maximum
4+
5+
__all__ = ['get_simcc_maximum']
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import torch
3+
4+
5+
def get_simcc_maximum(simcc_x: torch.Tensor,
6+
simcc_y: torch.Tensor) -> torch.Tensor:
7+
"""Get maximum response location and value from simcc representations.
8+
9+
Rewritten to support `torch.Tensor` inputs.
10+
11+
Args:
12+
simcc_x (torch.Tensor): x-axis SimCC in shape (N, K, Wx)
13+
simcc_y (torch.Tensor): y-axis SimCC in shape (N, K, Wy)
14+
15+
Returns:
16+
tuple:
17+
- locs (torch.Tensor): locations of maximum heatmap responses in shape
18+
(N, K, 2)
19+
- vals (torch.Tensor): values of maximum heatmap responses in shape
20+
(N, K)
21+
"""
22+
N, K, _ = simcc_x.shape
23+
simcc_x = simcc_x.flatten(0, 1)
24+
simcc_y = simcc_y.flatten(0, 1)
25+
x_locs = simcc_x.argmax(dim=1, keepdim=True)
26+
y_locs = simcc_y.argmax(dim=1, keepdim=True)
27+
locs = torch.cat((x_locs, y_locs), dim=1).to(torch.float32)
28+
max_val_x, _ = simcc_x.max(dim=1, keepdim=True)
29+
max_val_y, _ = simcc_y.max(dim=1, keepdim=True)
30+
vals, _ = torch.cat([max_val_x, max_val_y], dim=1).min(dim=1)
31+
locs = locs.reshape(N, K, 2)
32+
vals = vals.reshape(N, K)
33+
return locs, vals

mmdeploy/codebase/mmpose/deploy/pose_detection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
from mmengine.registry import Registry
1414

1515
from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
16-
from mmdeploy.utils import Codebase, Task, get_input_shape, get_root_logger
16+
from mmdeploy.utils import (Codebase, Task, get_codebase_config,
17+
get_input_shape, get_root_logger)
1718

1819

1920
def process_model_config(
@@ -362,6 +363,9 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
362363
params['post_process'] = 'megvii'
363364
params['modulate_kernel'] = self.model_cfg.kernel_sizes[-1]
364365
elif codec.type == 'SimCCLabel':
366+
export_postprocess = get_codebase_config(self.deploy_cfg).get(
367+
'export_postprocess', False)
368+
params['export_postprocess'] = export_postprocess
365369
component = 'SimCCLabelDecode'
366370
elif codec.type == 'RegressionLabel':
367371
component = 'DeepposeRegressionHeadDecode'

mmdeploy/codebase/mmpose/deploy/pose_detection_model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,20 @@ def forward(self,
101101
if self.model_cfg.model.type == 'YOLODetector':
102102
return self.pack_yolox_pose_result(batch_outputs, data_samples)
103103

104+
codebase_cfg = get_codebase_config(self.deploy_cfg)
104105
codec = self.model_cfg.codec
105106
if isinstance(codec, (list, tuple)):
106107
codec = codec[-1]
107108
if codec.type == 'SimCCLabel':
108-
batch_pred_x, batch_pred_y = batch_outputs
109-
preds = self.head.decode((batch_pred_x, batch_pred_y))
109+
export_postprocess = codebase_cfg.get('export_postprocess', False)
110+
if export_postprocess:
111+
keypoints, scores = [_.cpu().numpy() for _ in batch_outputs]
112+
preds = [
113+
InstanceData(keypoints=keypoints, keypoint_scores=scores)
114+
]
115+
else:
116+
batch_pred_x, batch_pred_y = batch_outputs
117+
preds = self.head.decode((batch_pred_x, batch_pred_y))
110118
elif codec.type in ['RegressionLabel', 'IntegralRegressionLabel']:
111119
preds = self.head.decode(batch_outputs)
112120
else:
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from . import mspn_head, yolox_pose_head # noqa: F401,F403
2+
from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403
33

4-
__all__ = ['mspn_head', 'yolox_pose_head']
4+
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head']
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from mmdeploy.codebase.mmpose.codecs import get_simcc_maximum
3+
from mmdeploy.core import FUNCTION_REWRITER
4+
from mmdeploy.utils import get_codebase_config
5+
6+
7+
@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.RTMCCHead.forward')
8+
@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.SimCCHead.forward')
9+
def simcc_head__forward(self, feats):
10+
"""Rewrite `forward` of SimCCHead for default backend.
11+
12+
Args:
13+
feats (tuple[Tensor]): Input features.
14+
Returns:
15+
tuple[Tensor, Tensor]: raw ``(simcc_x, simcc_y)``, or decoded
16+
``(keypoints (N, K, 2), scores (N, K))`` if `export_postprocess` is set.
17+
"""
18+
ctx = FUNCTION_REWRITER.get_context()
19+
simcc_x, simcc_y = ctx.origin_func(self, feats)
20+
codebase_cfg = get_codebase_config(ctx.cfg)
21+
export_postprocess = codebase_cfg.get('export_postprocess', False)
22+
if not export_postprocess:
23+
return simcc_x, simcc_y
24+
assert self.decoder.use_dark is False, \
25+
'Do not support SimCCLabel with use_dark=True'
26+
pts, scores = get_simcc_maximum(simcc_x, simcc_y)
27+
pts /= self.decoder.simcc_split_ratio
28+
return pts, scores

0 commit comments

Comments
 (0)