
Commit 4c376d9

Boomerl and RunningLeon authored

CodeCamp2023-555 (#2469)

* support condinst from mmdet
* remove
* update
* update
* support batch inference
* add condinst head unit testing
* fix lint error
* remove
* fix bug in postprocess
* remove
* update

---------

Co-authored-by: RunningLeon <[email protected]>
1 parent e74901f commit 4c376d9

File tree

8 files changed (+240, −4 lines)

csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp

Lines changed: 1 addition & 1 deletion
@@ -149,7 +149,7 @@ class ResizeInstanceMask : public ResizeBBox {
       int resize_width = int(mask_width / scale_factor_[1] + 0.5);
       // skip resize if scale_factor is 1.0
       if (resize_height != mask_height || resize_width != mask_width) {
-        cv::resize(mask_mat, mask_mat, cv::Size(resize_height, resize_width), cv::INTER_LINEAR);
+        cv::resize(mask_mat, mask_mat, cv::Size(resize_width, resize_height), cv::INTER_LINEAR);
       }
       // crop masks
       mask_mat = mask_mat(cv::Range(0, img_h), cv::Range(0, img_w)).clone();
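
Note on the one-line fix above: OpenCV expects the target size as (width, height), so the old call transposed the dimensions whenever the mask was not square. A minimal Python sketch of the same pitfall (the array shape and target size are made-up for illustration):

import cv2
import numpy as np

mask = np.zeros((200, 100), dtype=np.float32)   # numpy shape is (height, width)
resize_height, resize_width = 400, 300

# cv2.resize / cv::resize take the target size as (width, height)
resized = cv2.resize(mask, (resize_width, resize_height),
                     interpolation=cv2.INTER_LINEAR)
assert resized.shape == (resize_height, resize_width)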

mmdeploy/codebase/mmdet/deploy/object_detection.py

Lines changed: 5 additions & 0 deletions
@@ -320,6 +320,11 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
             type = 'ResizeInstanceMask'  # for instance-seg
             # resize and crop mask to origin image
             params['is_resize_mask'] = True
+        if 'mask_thr' in params:
+            type = 'ResizeInstanceMask'  # for instance-seg
+            # resize and crop mask to origin image
+            params['mask_thr_binary'] = params['mask_thr']
+            params['is_resize_mask'] = True
 
         if get_backend(self.deploy_cfg) == Backend.RKNN:
             if 'YOLO' in self.model_cfg.model.type or \
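
For context, the CondInst mask head thresholds masks with test_cfg.mask_thr (see the new condinst_head.py below), while the SDK's ResizeInstanceMask stage reads mask_thr_binary, so the added branch copies the value under the expected key. A rough sketch of the intended mapping (the numbers are placeholder assumptions, not values from this commit):

# hypothetical test_cfg-derived params for a CondInst model
params = {'score_thr': 0.05, 'mask_thr': 0.5}

if 'mask_thr' in params:
    postprocess_type = 'ResizeInstanceMask'  # instance segmentation postprocess
    params['mask_thr_binary'] = params['mask_thr']
    params['is_resize_mask'] = True

# params now also carries mask_thr_binary=0.5 for the SDK postprocess stage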

mmdeploy/codebase/mmdet/deploy/object_detection_model.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def postprocessing_results(self,
             masks = batch_masks[i]
             img_h, img_w = img_metas[i]['img_shape'][:2]
             ori_h, ori_w = img_metas[i]['ori_shape'][:2]
-            if model_type == 'RTMDet':
+            if model_type in ['RTMDet', 'CondInst']:
                 export_postprocess_mask = True
             else:
                 export_postprocess_mask = False

mmdeploy/codebase/mmdet/models/dense_heads/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from . import base_dense_head  # noqa: F401,F403
 from . import centernet_head  # noqa: F401,F403
+from . import condinst_head  # noqa: F401,F403
 from . import detr_head  # noqa: F401,F403
 from . import fovea_head  # noqa: F401,F403
 from . import gfl_head  # noqa: F401,F403
mmdeploy/codebase/mmdet/models/dense_heads/condinst_head.py

Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional

import torch
from mmdet.models.utils import aligned_bilinear
from mmengine.config import ConfigDict
from torch import Tensor

from mmdeploy.codebase.mmdet.deploy import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.mmcv.ops.nms import multiclass_nms


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.CondInstBboxHead.predict_by_feat')
def condinst_bbox_head__predict_by_feat(
        self,
        cls_scores: List[Tensor],
        bbox_preds: List[Tensor],
        score_factors: Optional[List[Tensor]] = None,
        param_preds: Optional[List[Tensor]] = None,
        batch_img_metas: Optional[List[dict]] = None,
        cfg: Optional[ConfigDict] = None,
        rescale: bool = False,
        with_nms: bool = True,
):
    ctx = FUNCTION_REWRITER.get_context()
    deploy_cfg = ctx.cfg

    assert len(cls_scores) == len(bbox_preds)
    device = bbox_preds[0].device
    cfg = self.test_cfg if cfg is None else cfg
    batch_size = bbox_preds[0].shape[0]
    featmap_sizes = [cls_score.shape[-2:] for cls_score in cls_scores]

    all_level_points_strides = self.prior_generator.grid_priors(
        featmap_sizes, device=device, with_stride=True)
    all_level_points = [i[:, :2] for i in all_level_points_strides]
    all_level_strides = [i[:, 2] for i in all_level_points_strides]

    flatten_cls_scores = [
        cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                              self.cls_out_channels)
        for cls_score in cls_scores
    ]
    flatten_bbox_preds = [
        bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4)
        for bbox_pred in bbox_preds
    ]
    flatten_score_factors = [
        score_factor.permute(0, 2, 3, 1).reshape(batch_size, -1, 1)
        for score_factor in score_factors
    ]
    flatten_param_preds = [
        param_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, self.num_params)
        for param_pred in param_preds
    ]
    flatten_cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
    flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
    flatten_score_factors = torch.cat(flatten_score_factors, dim=1).sigmoid()
    flatten_param_preds = torch.cat(flatten_param_preds, dim=1)

    points = torch.cat(all_level_points)
    strides = torch.cat(all_level_strides)
    tl_x = points[..., 0] - flatten_bbox_preds[..., 0]
    tl_y = points[..., 1] - flatten_bbox_preds[..., 1]
    br_x = points[..., 0] + flatten_bbox_preds[..., 2]
    br_y = points[..., 1] + flatten_bbox_preds[..., 3]

    bboxes = torch.stack([tl_x, tl_y, br_x, br_y], -1)
    scores = flatten_cls_scores
    score_factors = flatten_score_factors
    param_preds = flatten_param_preds
    scores = scores * score_factors

    # get post processing config
    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

    dets, labels, inds = multiclass_nms(
        bboxes,
        scores,
        max_output_boxes_per_class,
        iou_threshold,
        score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k,
        output_index=True,
    )

    batch_inds = torch.arange(batch_size, device=bboxes.device).view(-1, 1)
    points = points.unsqueeze(0).repeat(batch_size, 1, 1)
    strides = strides.unsqueeze(0).repeat(batch_size, 1)
    param_preds = param_preds[batch_inds, inds, :]
    points = points[batch_inds, inds, :]
    strides = strides[batch_inds, inds]
    results = dict(
        dets=dets,
        labels=labels,
        param_preds=param_preds,
        points=points,
        strides=strides)
    return results


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.CondInstMaskHead.forward')
def condinst_mask_head__forward(self, x: tuple,
                                positive_infos: Dict[str, torch.Tensor]):
    mask_feats = self.mask_feature_head(x)

    param_preds = positive_infos['param_preds']
    points = positive_infos['points']
    strides = positive_infos['strides']

    batch_size = points.shape[0]
    num_insts = points.shape[1]
    hw = mask_feats.size()[-2:]
    mask_feats = mask_feats.unsqueeze(1).repeat(1, num_insts, 1, 1, 1)

    points = points.reshape(-1, 1, 2).unsqueeze(0)
    locations = self.prior_generator.single_level_grid_priors(
        hw, level_idx=0, device=mask_feats.device)
    locations = locations.unsqueeze(0).repeat(batch_size, 1,
                                              1).reshape(batch_size, 1, -1, 2)
    centers = points.reshape(batch_size, -1, 1, 2)
    rel_coordinates = (centers - locations).permute(0, 1, 3, 2).float()
    rel_coordinates /= (strides[:, :, None, None] * self.size_of_interest)
    rel_coords = rel_coordinates.reshape(batch_size, -1, 2, hw[0], hw[1])
    mask_head_inputs = torch.cat([rel_coords, mask_feats], dim=2)

    weights, biases = _parse_dynamic_params(self, param_preds)
    mask_preds = _dynamic_conv_forward(mask_head_inputs, weights, biases)
    mask_preds = mask_preds.reshape(batch_size, num_insts, hw[0], hw[1])
    mask_preds = aligned_bilinear(
        mask_preds, int(self.mask_feat_stride / self.mask_out_stride))
    return (mask_preds, )


@FUNCTION_REWRITER.register_rewriter(
    'mmdet.models.dense_heads.CondInstMaskHead.predict_by_feat')
def condinst_mask_head__predict_by_feat(self,
                                        mask_preds: Tensor,
                                        results_list: Dict[str, torch.Tensor],
                                        batch_img_metas: List[dict],
                                        rescale: bool = True,
                                        **kwargs):
    cfg = self.test_cfg

    dets = results_list['dets']
    labels = results_list['labels']
    img_hw = batch_img_metas[0]['img_shape'][:2]

    mask_preds = mask_preds.sigmoid()
    mask_preds = aligned_bilinear(mask_preds, self.mask_out_stride)
    mask_preds = mask_preds[:, :, :img_hw[0], :img_hw[1]]
    masks = (mask_preds > cfg.mask_thr).float()

    return dets, labels, masks


def _parse_dynamic_params(self, params: Tensor):
    """parse the dynamic params for dynamic conv."""
    batch_size = params.shape[0]
    num_insts = params.shape[1]
    params = params.permute(1, 0, 2)
    params_splits = list(
        torch.split_with_sizes(
            params, self.weight_nums + self.bias_nums, dim=2))

    weight_splits = params_splits[:self.num_layers]
    bias_splits = params_splits[self.num_layers:]

    for idx in range(self.num_layers):
        if idx < self.num_layers - 1:
            weight_splits[idx] = weight_splits[idx].reshape(
                batch_size, num_insts, self.in_channels, -1)
        else:
            weight_splits[idx] = weight_splits[idx].reshape(
                batch_size, num_insts, 1, -1)

    return weight_splits, bias_splits


def _dynamic_conv_forward(features: Tensor, weights: List[Tensor],
                          biases: List[Tensor]):
    """dynamic forward, each layer follow a relu."""
    n_layers = len(weights)
    x = features.flatten(0, 1).flatten(2)
    for i, (w, b) in enumerate(zip(weights, biases)):
        # replace dynamic conv with bmm
        w = w.flatten(0, 1)
        b = b.flatten(0, 1).unsqueeze(2)
        x = torch.bmm(w, x)
        x = x + b
        if i < n_layers - 1:
            x = x.clamp_(min=0)
    return x
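
The _dynamic_conv_forward helper above is the deployment-friendly form of CondInst's controller-generated mask FCN: each instance's dynamic 1x1 convolutions are rewritten as batched matrix multiplies, which trace to plain MatMul/Add ops. A standalone sketch of why the two forms agree (all shapes and channel counts below are made-up assumptions):

import torch
import torch.nn.functional as F

num_insts, in_ch, out_ch, h, w = 3, 8, 4, 16, 16
feats = torch.randn(num_insts, in_ch, h, w)       # per-instance mask-head inputs
weight = torch.randn(num_insts, out_ch, in_ch)    # per-instance 1x1 conv weights
bias = torch.randn(num_insts, out_ch)

# reference: a separate 1x1 convolution per instance
ref = torch.cat([
    F.conv2d(feats[i:i + 1], weight[i].view(out_ch, in_ch, 1, 1), bias[i])
    for i in range(num_insts)
])

# bmm formulation: flatten the spatial dims and batch over instances
x = feats.flatten(2)                              # (num_insts, in_ch, h*w)
out = torch.bmm(weight, x) + bias.unsqueeze(2)    # (num_insts, out_ch, h*w)
out = out.view(num_insts, out_ch, h, w)

assert torch.allclose(ref, out, atol=1e-5)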

mmdeploy/codebase/mmdet/models/detectors/single_stage_instance_seg.py

Lines changed: 17 additions & 2 deletions
@@ -12,15 +12,30 @@
     'instance_segmentor_forward',
     inputs=['input'],
     outputs=['dets', 'labels', 'masks'])
-def __forward_impl_instance_seg(self, batch_inputs, data_samples, **kwargs):
+def __forward_impl_instance_seg(self,
+                                batch_inputs,
+                                data_samples,
+                                rescale=True,
+                                **kwargs):
     """Rewrite and adding mark for `forward`.
 
     Encapsulate this function for rewriting `forward` of BaseDetector.
     1. Add mark for BaseDetector.
     2. Support both dynamic and static export to onnx.
     """
     x = self.extract_feat(batch_inputs)
-    mask_outs = self.mask_head.predict(x, data_samples, rescale=False)
+    if self.with_bbox:
+        # the bbox branch does not need to be scaled to the original
+        # image scale, because the mask branch will scale both bbox
+        # and mask at the same time.
+        bbox_rescale = rescale if not self.with_mask else False
+        results_list = self.bbox_head.predict(
+            x, data_samples, rescale=bbox_rescale)
+    else:
+        results_list = None
+
+    mask_outs = self.mask_head.predict(
+        x, data_samples, rescale=rescale, results_list=results_list)
     return mask_outs
 
 
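
The added branch follows the reasoning stated in its comment: when a mask head is present, boxes and masks are rescaled together by the mask branch, so the bbox branch must not rescale first or the scaling would be applied twice. A trivial restatement of that decision (the flags are hypothetical):

# hypothetical flags for an exported instance-segmentation model
with_mask = True
rescale = True

# rescale boxes in the bbox branch only when no mask branch will do it later
bbox_rescale = rescale if not with_mask else False
assert bbox_rescale is False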

mmdeploy/pytorch/functions/repeat.py

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,9 @@ def tensor__repeat__tensorrt(input: torch.Tensor, *size: Union[torch.Size,
 
     origin_func = ctx.origin_func
     if input.dim() == 1 and len(size) == 1:
+        if isinstance(*size, tuple):
+            return origin_func(input.unsqueeze(0),
+                               *([1] + list(*size))).squeeze(0)
         return origin_func(input.unsqueeze(0), *([1] + list(size))).squeeze(0)
     else:
         return origin_func(input, *size)
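
The new branch above covers the case where Tensor.repeat on a 1-D tensor receives its sizes as a single tuple (for example t.repeat((n, m))) rather than as unpacked ints; without unpacking, the [1] + list(size) trick would pass the tuple itself as one of the repeat sizes and fail. A small plain-PyTorch sketch of the equivalence the rewriter relies on (outside any TensorRT export context):

import torch

t = torch.arange(3)  # 1-D tensor, shape (3,)

# sizes passed as ints: t.repeat(4) -> shape (12,)
a = t.repeat(4)
b = t.unsqueeze(0).repeat(*([1] + list((4,)))).squeeze(0)
assert torch.equal(a, b)

# sizes passed as one tuple: t.repeat((2, 3)) -> shape (2, 9)
c = t.repeat((2, 3))
d = t.unsqueeze(0).repeat(*([1] + list((2, 3)))).squeeze(0)
assert torch.equal(c, d)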

tests/regression/mmdet.yml

Lines changed: 10 additions & 0 deletions
@@ -446,3 +446,13 @@ models:
     pipelines:
       - *pipeline_ort_dynamic_fp32
       - *pipeline_trt_dynamic_fp32
+
+  - name: CondInst
+    metafile: configs/condinst/metafile.yml
+    model_configs:
+      - configs/condinst/condinst_r50_fpn_ms-poly-90k_coco_instance.py
+    pipelines:
+      - deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py
+        backend_test: *default_backend_test
+      - deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
+        backend_test: *default_backend_test
