Skip to content

Commit d48187c

Browse files
[Feature] Support RTMDet and RTMPose ncnn deployment (#1857)
* support rtmpose ncnn * fix docformatter * fix docformatter * fix classname from tauj to dev-1.x branch * rename file * fix comments * remove unused rewriter * fix norm * fix lint * fix rtmcc_block * fix norm * add ut * fix origin_func * fix norm * fix rtmdet_head * add ut * false run_with_backend for ncnn * fix lint
1 parent 423e27a commit d48187c

File tree

12 files changed

+440
-2
lines changed

12 files changed

+440
-2
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Deploy config: object detection with the ncnn backend, FP16, static
# 320x320 input. Inherits the static-task base and the ncnn backend base.
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']

# Run ncnn inference in FP16 precision.
backend_config = dict(precision='FP16')
# 'ncnn_end2end' wraps the model so post-processing runs inside the
# exported ncnn graph (DetectionOutput layer).
codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[320, 320])
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Deploy config: object detection with the ncnn backend (default precision),
# static 320x320 input. Inherits the static-task base and the ncnn base.
_base_ = ['../_base_/base_static.py', '../../_base_/backends/ncnn.py']

# 'ncnn_end2end' wraps the model so post-processing runs inside the
# exported ncnn graph (DetectionOutput layer).
codebase_config = dict(model_type='ncnn_end2end')
onnx_config = dict(output_names=['detection_output'], input_shape=[320, 320])
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Deploy config: pose detection (SimCC head) with the ncnn backend, FP16,
# static 192x256 input.
_base_ = ['./pose-detection_static.py', '../_base_/backends/ncnn.py']

# Run ncnn inference in FP16 precision.
backend_config = dict(precision='FP16')
# SimCC heads emit separate x/y classification vectors.
onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])

csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,8 +2200,6 @@ int main(int argc, char** argv) {
22002200
}
22012201
fprintf(pp, " 4=%d", keepdims);
22022202
fprintf(pp, " 5=1");
2203-
// Force set Reduction for FP32, FP16 may exceed for some models.
2204-
fprintf(pp, " 31=15");
22052203
} else if (op == "Reorg") {
22062204
int stride = get_node_attr_i(node, "stride", 1);
22072205
fprintf(pp, " 0=%d", stride);

mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from mmdeploy.codebase.mmdet import get_post_processing_params
1010
from mmdeploy.core import FUNCTION_REWRITER, mark
1111
from mmdeploy.mmcv.ops import multiclass_nms
12+
from mmdeploy.utils import Backend
1213

1314

1415
@FUNCTION_REWRITER.register_rewriter(
@@ -105,3 +106,120 @@ def __mark_pred_maps(cls_scores, bbox_preds):
105106
score_threshold=score_threshold,
106107
pre_top_k=pre_top_k,
107108
keep_top_k=keep_top_k)
109+
110+
111+
@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.dense_heads.rtmdet_head.'
    'RTMDetHead.predict_by_feat',
    backend=Backend.NCNN.value)
def rtmdet_head__predict_by_feat__ncnn(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        batch_img_metas: Optional[List[dict]] = None,
        cfg: Optional[ConfigDict] = None,
        rescale: bool = False,
        with_nms: bool = True):
    """Rewrite `predict_by_feat` of RTMDetHead for ncnn backend.

    1. Decode the prior to a box format for ncnn DetectionOutput layer to do
    the post-processing.
    2. Batch dimension is not supported by ncnn, but supported by pytorch.
    The negative value of axis in torch.cat is rewritten as corresponding
    positive value to avoid axis shift.
    3. 2-dimension tensor broadcast of `BinaryOps` operator is not supported
    by ncnn. This function unsqueezes 2-dimension tensors to 3-dimension
    tensors for correct `BinaryOps` calculation by ncnn.

    Args:
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box energies / deltas for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * 4, H, W).
        batch_img_metas (list[dict], Optional): Batch image meta info.
            Defaults to None.
        cfg (ConfigDict, optional): Test / postprocessing
            configuration, if None, test_cfg would be used.
            Defaults to None.
        rescale (bool): If True, return boxes in original image space.
            Unused here; kept for interface compatibility.
            Defaults to False.
        with_nms (bool): If True, do nms before return boxes.
            Unused here; NMS is always performed by the DetectionOutput
            layer. Defaults to True.

    Returns:
        output__ncnn (Tensor): outputs, shape is [N, num_det, 6].
    """
    ctx = FUNCTION_REWRITER.get_context()
    from mmdeploy.codebase.mmdet.ops import ncnn_detection_output_forward
    from mmdeploy.utils import get_root_logger
    from mmdeploy.utils.config_utils import is_dynamic_shape
    dynamic_flag = is_dynamic_shape(ctx.cfg)
    if dynamic_flag:
        logger = get_root_logger()
        logger.warning('RTMDet does not support dynamic shape with ncnn.')
    img_height = int(batch_img_metas[0]['img_shape'][0])
    img_width = int(batch_img_metas[0]['img_shape'][1])

    assert len(cls_scores) == len(bbox_preds)
    device = cls_scores[0].device
    cfg = self.test_cfg if cfg is None else cfg
    batch_size = bbox_preds[0].shape[0]
    featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, device=device, with_stride=True)
    mlvl_priors = [mlvl_prior.unsqueeze(0) for mlvl_prior in mlvl_priors]
    flatten_priors = torch.cat(mlvl_priors, dim=1)

    # Flatten per-level maps to (B, num_priors, C) / (B, num_priors, 4).
    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                              self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        for bbox_pred in bbox_preds
    ]

    cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
    # DetectionOutput reserves class index 0 for background, so prepend a
    # zero-score background channel.
    dummy_cls_scores = torch.zeros(
        batch_size, cls_scores.shape[-2], 1, device=cls_scores.device)

    batch_mlvl_scores = torch.cat([dummy_cls_scores, cls_scores], dim=2)

    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
    assert flatten_priors.shape[-1] == 4, f'rtmdet needs (B, N, 4) priors, got\
        (B, N, {flatten_priors.shape[-1]})'

    # Decode distance predictions against the prior centers and normalize by
    # the image size, as the DetectionOutput layer expects.
    tl_x = (flatten_priors[:, :, 0:1] -
            flatten_bbox_preds[:, :, 0:1]) / img_width
    tl_y = (flatten_priors[:, :, 1:2] -
            flatten_bbox_preds[:, :, 1:2]) / img_height
    br_x = (flatten_priors[:, :, 0:1] +
            flatten_bbox_preds[:, :, 2:3]) / img_width
    br_y = (flatten_priors[:, :, 1:2] +
            flatten_bbox_preds[:, :, 3:4]) / img_height
    prior_box_ncnn = torch.stack([tl_x, tl_y, br_x, br_y], -1)

    scores = batch_mlvl_scores

    # Flatten everything to the (B, 1/2, N*4) layout consumed by the ncnn op.
    batch_mlvl_bboxes = flatten_bbox_preds.reshape(batch_size, 1, -1)
    batch_mlvl_scores = scores.reshape(batch_size, 1, -1)
    batch_mlvl_priors = prior_box_ncnn.reshape(batch_size, 1, -1)
    batch_mlvl_vars = torch.ones_like(batch_mlvl_priors)
    batch_mlvl_priors = torch.cat([batch_mlvl_priors, batch_mlvl_vars], dim=1)
    deploy_cfg = ctx.cfg
    post_params = get_post_processing_params(deploy_cfg)
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    # Unit variances (identity decoding); renamed from `vars`, which
    # shadowed the builtin.
    variances = torch.tensor([1, 1, 1, 1], dtype=torch.float32)
    output__ncnn = ncnn_detection_output_forward(
        batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_priors,
        score_threshold, iou_threshold, pre_top_k, keep_top_k,
        self.num_classes + 1,
        variances.cpu().detach().numpy())
    return output__ncnn

mmdeploy/codebase/mmpose/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22

33
from . import heads # noqa: F401,F403
44
from . import pose_estimators # noqa: F401,F403
5+
from . import utils # noqa: F401,F403
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.

# Imported for its side effect: loading the module registers the
# ncnn rewriters declared in rtmcc_block.
from . import rtmcc_block

__all__ = ['rtmcc_block']
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from mmpose.models.utils import rope
6+
7+
from mmdeploy.core import FUNCTION_REWRITER
8+
9+
10+
@FUNCTION_REWRITER.register_rewriter(
    'mmpose.models.utils.rtmcc_block.ScaleNorm.forward', backend='ncnn')
def scalenorm__forward__ncnn(self, x):
    """Rewrite `ScaleNorm.forward` for ncnn backend.

    Rewrite scalenorm to avoid FP16 exceed in ncnn Android platform.
    """
    # The one-dim Frobenius norm is equal to the L2 norm.
    # Pass p=2 explicitly so torch.norm exports as the ReduceL2 onnx op
    # (avoids FP16 overflow) without depending on the separate torch.norm
    # rewriter to substitute the default p='fro'. The original call omitted
    # p despite the comment saying to set it explicitly.
    norm = torch.norm(x, p=2, dim=2, keepdim=True)
    norm = norm * self.scale
    # Rewrite for ncnn binaryop broadcast: unsqueeze both operands to 4-D
    # so the division broadcasts correctly, then squeeze back.
    norm = norm.clamp(min=self.eps)
    return (x.unsqueeze(2) / norm.unsqueeze(2)).squeeze(2) * self.g
25+
26+
27+
@FUNCTION_REWRITER.register_rewriter(
    'mmpose.models.utils.rtmcc_block.RTMCCBlock._forward', backend='ncnn')
def rtmccblock___forward_ncnn(self, inputs):
    """Rewrite `_forward` of RTMCCBlock for ncnn backend.

    Rewrite the matmul and avoid unbind for ncnn backend.
    """
    # Self-attention receives a single tensor; cross-attention receives
    # a (query, key, value) tuple.
    if self.attn_type == 'self-attn':
        x = inputs
    else:
        x, k, v = inputs

    x = self.ln(x)
    uv = self.uv(x)
    if self.attn_type == 'self-attn':
        uv = self.act_fn(uv)
        # Slice u/v/base out of the joint projection instead of using
        # unbind/split, which this rewrite avoids for ncnn.
        u = uv[..., :self.e]
        v = uv[..., self.e:2 * self.e]
        base = uv[..., 2 * self.e:2 * self.e + self.s]

        # Per-branch affine on `base`: gamma/beta row 0 produces q, row 1
        # produces k. unsqueeze/squeeze keeps the broadcast 4-D so ncnn's
        # BinaryOp computes it correctly.
        q = (base.unsqueeze(1) * self.gamma[None, None, 0:1, :] +
             self.beta[None, None, 0:1, :]).squeeze(1)
        k = (base.unsqueeze(1) * self.gamma[None, None, 1:2, :] +
             self.beta[None, None, 1:2, :]).squeeze(1)

        if self.pos_enc:
            # Rotary position encoding along the sequence dimension.
            q = rope(q, dim=1)
            k = rope(k, dim=1)
    else:
        # Cross-attention: u and q come from the query projection, k/v are
        # projected from the provided key/value tensors.
        u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)

        k = self.k_fc(k)
        v = self.v_fc(v)

        if self.pos_enc:
            q = rope(q, 1)
            k = rope(k, 1)
    # bmm with an explicit permute instead of matmul on transposed views.
    qk = torch.bmm(q, k.permute(0, 2, 1))
    if self.use_rel_bias:
        if self.attn_type == 'self-attn':
            bias = self.rel_pos_bias(q.size(1))
        else:
            bias = self.rel_pos_bias(q.size(1), k.size(1))
        qk += bias[:, :q.size(1), :k.size(1)]

    # Squared ReLU of the scaled attention scores.
    kernel = torch.square(F.relu(qk / self.sqrt_s))
    if self.dropout_rate > 0.:
        kernel = self.dropout(kernel)

    x = u * torch.bmm(kernel, v)
    x = self.o(x)

    return x
80+
81+
82+
@FUNCTION_REWRITER.register_rewriter(
    'mmpose.models.utils.rtmcc_block.Scale.forward', backend='ncnn')
def scale__forward_ncnn(self, x):
    """Rewrite `forward` of Scale for ncnn backend.

    Adapt the shape to avoid ncnn BinaryOp seg fault.
    """
    # Lift both operands to 4-D so ncnn's BinaryOp broadcast is well
    # defined, then drop the helper dimension from the product.
    expanded = x.unsqueeze(1)
    weight = self.scale[None, None, None, :]
    return (expanded * weight).squeeze(1)

mmdeploy/pytorch/functions/normalize.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22

3+
from typing import Optional, Sequence, Union
4+
35
import torch
46

57
from mmdeploy.core import FUNCTION_REWRITER
@@ -39,3 +41,26 @@ def normalize__ncnn(input: torch.Tensor,
3941
input.transpose(1, dim), p=p, dim=1,
4042
eps=eps).transpose(1, dim)
4143
return output
44+
45+
46+
@FUNCTION_REWRITER.register_rewriter(func_name='torch.norm', backend='ncnn')
def norm__ncnn(input: torch.Tensor,
               p: Optional[Union[int, str]] = 'fro',
               dim: Optional[Union[int, Sequence]] = None,
               keepdim: Optional[bool] = False,
               out: Optional[torch.Tensor] = None,
               dtype: Optional[torch.dtype] = None):
    """Rewrite `torch.norm` for ncnn backend.

    Rewrite torch.norm when p is Frobenius norm to avoid FP16 exceed in ncnn
    Android platform.
    """
    ctx = FUNCTION_REWRITER.get_context()
    origin_func = ctx.origin_func
    # A Frobenius norm over a single dimension equals the L2 norm; exporting
    # it as p=2 maps to the ReduceL2 onnx op, which avoids FP16 overflow.
    # The `dim is not None` guard fixes a crash in the original code, which
    # called len(None) for a full-tensor Frobenius norm (dim=None).
    if p == 'fro' and dim is not None and (isinstance(dim, int)
                                           or len(dim) == 1):
        # Substitute Frobenius norm with L2 norm.
        return origin_func(
            input, p=2, dim=dim, keepdim=keepdim, out=out, dtype=dtype)
    else:
        return origin_func(
            input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype)

tests/test_codebase/test_mmdet/test_mmdet_models.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,3 +2121,88 @@ def test_solo_head_predict_by_feat(backend_type: Backend):
21212121
atol=1e-05)
21222122
else:
21232123
assert rewrite_outputs is not None
2124+
2125+
2126+
def get_rtmdet_head_model():
    """Build a minimal RTMDetHead (1 class, 64 channels) for rewrite tests."""

    from mmdet.models.dense_heads import RTMDetHead
    from mmdet.models.task_modules.prior_generators.point_generator import \
        MlvlPointGenerator

    head = RTMDetHead(1, 64)
    # Three feature levels with strides 8 / 4 / 2.
    head.prior_generator = MlvlPointGenerator([8, 4, 2])
    head.test_cfg = Config(
        dict(
            deploy_nms_pre=0,
            min_bbox_size=0,
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.6),
            max_per_img=100))

    head.requires_grad_(False)
    return head
2145+
2146+
2147+
def test_rtmdet_head_predict_by_feat_ncnn():
    """Test predict_by_feat rewrite of RTMDetHead for the ncnn backend."""
    backend_type = Backend.NCNN
    check_backend(backend_type)
    rtmdet_head = get_rtmdet_head_model()
    rtmdet_head.cpu().eval()
    s = 320
    batch_img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (s, s, 3),
        'img_shape': (s, s, 3)
    }]

    output_names = ['detection_output']
    deploy_cfg = Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=output_names, input_shape=None),
            codebase_config=dict(
                type='mmdet',
                model_type='ncnn_end2end',
                task='ObjectDetection',
                post_processing=dict(
                    score_threshold=0.05,
                    iou_threshold=0.45,
                    confidence_threshold=0.005,
                    max_output_boxes_per_class=200,
                    pre_top_k=-1,
                    keep_top_k=10,
                    background_label_id=-1,
                ))))

    seed_everything(1234)
    # Three feature levels matching the strides (8/4/2) configured in
    # get_rtmdet_head_model; one class per location.
    cls_scores = [
        torch.rand(1, 1, 40, 40),
        torch.rand(1, 1, 20, 20),
        torch.rand(1, 1, 10, 10)
    ]

    bbox_preds = [
        torch.rand(1, 4, 40, 40),
        torch.rand(1, 4, 20, 20),
        torch.rand(1, 4, 10, 10)
    ]

    # to get outputs of onnx model after rewrite
    wrapped_model = WrapModel(
        rtmdet_head,
        'predict_by_feat',
        batch_img_metas=batch_img_metas,
        with_nms=True)
    rewrite_inputs = {'cls_scores': cls_scores, 'bbox_preds': bbox_preds}
    rewrite_outputs, is_backend_output = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg,
        run_with_backend=False)
    # output should be of shape [1, N, 6]
    if is_backend_output:
        assert rewrite_outputs[0].shape[-1] == 6
    else:
        assert rewrite_outputs.shape[-1] == 6

0 commit comments

Comments
 (0)