Skip to content

Commit 9f9b3a8

Browse files
lzhangzz and lvhan028 authored
[Enhancement] Support RTMDet-Ins (#1867)
* support RTMDet-Ins
* optimization
* avoid out of boundary

---------

Co-authored-by: lvhan028 <[email protected]>
1 parent 43383e8 commit 9f9b3a8

File tree

3 files changed

+90
-24
lines changed

3 files changed

+90
-24
lines changed

csrc/mmdeploy/codebase/mmdet/instance_segmentation.cpp

Lines changed: 83 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
#include "mmdeploy/core/registry.h"
44
#include "mmdeploy/core/utils/device_utils.h"
55
#include "mmdeploy/experimental/module_adapter.h"
6+
#include "mmdeploy/operation/managed.h"
7+
#include "mmdeploy/operation/vision.h"
68
#include "object_detection.h"
79
#include "opencv2/imgproc/imgproc.hpp"
810
#include "opencv_utils.h"
@@ -14,7 +16,10 @@ class ResizeInstanceMask : public ResizeBBox {
1416
explicit ResizeInstanceMask(const Value& cfg) : ResizeBBox(cfg) {
1517
if (cfg.contains("params")) {
1618
mask_thr_binary_ = cfg["params"].value("mask_thr_binary", mask_thr_binary_);
19+
is_rcnn_ = cfg["params"].contains("rcnn");
1720
}
21+
operation::Context ctx(device_, stream_);
22+
warp_affine_ = operation::Managed<operation::WarpAffine>::Create("bilinear");
1823
}
1924

2025
// TODO: remove duplication
@@ -53,15 +58,17 @@ class ResizeInstanceMask : public ResizeBBox {
5358

5459
OUTCOME_TRY(auto _dets, MakeAvailableOnDevice(dets, kHost, stream()));
5560
OUTCOME_TRY(auto _labels, MakeAvailableOnDevice(labels, kHost, stream()));
56-
OUTCOME_TRY(auto _masks, MakeAvailableOnDevice(masks, kHost, stream()));
57-
OUTCOME_TRY(stream().Wait());
61+
// Note: `masks` are kept on device to avoid data copy overhead from device to host.
62+
// refer to https://github.com/open-mmlab/mmdeploy/issues/1849
63+
// OUTCOME_TRY(auto _masks, MakeAvailableOnDevice(masks, kHost, stream()));
64+
// OUTCOME_TRY(stream().Wait());
5865

5966
OUTCOME_TRY(auto result, DispatchGetBBoxes(prep_res["img_metas"], _dets, _labels));
6067

6168
auto ori_w = prep_res["img_metas"]["ori_shape"][2].get<int>();
6269
auto ori_h = prep_res["img_metas"]["ori_shape"][1].get<int>();
6370

64-
ProcessMasks(result, _masks, ori_w, ori_h);
71+
// ProcessMasks returns Result<void>; propagate its failure (warp, device-to-host copy or
// stream wait) to the caller instead of discarding it and returning detections whose
// masks were never filled in.
OUTCOME_TRY(ProcessMasks(result, masks, _dets, ori_w, ori_h));
6572

6673
return to_value(result);
6774
} catch (const std::exception& e) {
@@ -71,14 +78,23 @@ class ResizeInstanceMask : public ResizeBBox {
7178
}
7279

7380
protected:
74-
void ProcessMasks(Detections& result, Tensor cpu_masks, int img_w, int img_h) const {
75-
auto shape = TensorShape{cpu_masks.shape(1), cpu_masks.shape(2), cpu_masks.shape(3)};
76-
cpu_masks.Reshape(shape);
77-
MMDEPLOY_DEBUG("{}, {}", cpu_masks.shape(), cpu_masks.data_type());
81+
Result<void> ProcessMasks(Detections& result, Tensor d_mask, Tensor cpu_dets, int img_w,
82+
int img_h) {
83+
d_mask.Squeeze(0);
84+
cpu_dets.Squeeze(0);
85+
86+
::mmdeploy::operation::Context ctx(device_, stream_);
87+
88+
std::vector<Tensor> warped_masks;
89+
warped_masks.reserve(result.size());
90+
91+
std::vector<Tensor> h_warped_masks;
92+
h_warped_masks.reserve(result.size());
93+
7894
for (auto& det : result) {
79-
auto mask = cpu_masks.Slice(det.index);
80-
cv::Mat mask_mat((int)mask.shape(1), (int)mask.shape(2), CV_32F, mask.data<float>());
81-
cv::Mat warped_mask;
95+
auto mask = d_mask.Slice(det.index);
96+
auto mask_height = (int)mask.shape(1);
97+
auto mask_width = (int)mask.shape(2);
8298
auto& bbox = det.bbox;
8399
// same as mmdet with skip_empty = True
84100
auto x0 = std::max(std::floor(bbox[0]) - 1, 0.f);
@@ -88,22 +104,67 @@ class ResizeInstanceMask : public ResizeBBox {
88104
auto width = static_cast<int>(x1 - x0);
89105
auto height = static_cast<int>(y1 - y0);
90106
// params align_corners = False
91-
auto fx = (float)mask_mat.cols / (bbox[2] - bbox[0]);
92-
auto fy = (float)mask_mat.rows / (bbox[3] - bbox[1]);
93-
auto tx = (x0 + .5f - bbox[0]) * fx - .5f;
94-
auto ty = (y0 + .5f - bbox[1]) * fy - .5f;
95-
96-
cv::Mat m = (cv::Mat_<float>(2, 3) << fx, 0, tx, 0, fy, ty);
97-
cv::warpAffine(mask_mat, warped_mask, m, cv::Size{width, height},
98-
cv::INTER_LINEAR | cv::WARP_INVERSE_MAP);
99-
warped_mask = warped_mask > mask_thr_binary_;
100-
101-
det.mask = Mat(height, width, PixelFormat::kGRAYSCALE, DataType::kINT8,
102-
std::shared_ptr<void>(warped_mask.data, [mat = warped_mask](void*) {}));
107+
float fx;
108+
float fy;
109+
float tx;
110+
float ty;
111+
if (is_rcnn_) { // mask r-cnn
112+
fx = (float)mask_width / (bbox[2] - bbox[0]);
113+
fy = (float)mask_height / (bbox[3] - bbox[1]);
114+
tx = (x0 + .5f - bbox[0]) * fx - .5f;
115+
ty = (y0 + .5f - bbox[1]) * fy - .5f;
116+
} else { // rtmdet-ins
117+
auto raw_bbox = cpu_dets.Slice(det.index);
118+
auto raw_bbox_data = raw_bbox.data<float>();
119+
fx = (raw_bbox_data[2] - raw_bbox_data[0]) / (bbox[2] - bbox[0]);
120+
fy = (raw_bbox_data[3] - raw_bbox_data[1]) / (bbox[3] - bbox[1]);
121+
tx = (x0 + .5f - bbox[0]) * fx - .5f + raw_bbox_data[0];
122+
ty = (y0 + .5f - bbox[1]) * fy - .5f + raw_bbox_data[1];
123+
}
124+
125+
float affine_matrix[] = {fx, 0, tx, 0, fy, ty};
126+
127+
cv::Mat_<float> m(2, 3, affine_matrix);
128+
cv::invertAffineTransform(m, m);
129+
130+
mask.Reshape({1, mask_height, mask_width, 1});
131+
132+
Tensor& warped_mask = warped_masks.emplace_back();
133+
OUTCOME_TRY(warp_affine_.Apply(mask, warped_mask, affine_matrix, height, width));
134+
135+
OUTCOME_TRY(CopyToHost(warped_mask, h_warped_masks.emplace_back()));
103136
}
137+
138+
OUTCOME_TRY(stream_.Wait());
139+
140+
for (size_t i = 0; i < h_warped_masks.size(); ++i) {
141+
result[i].mask = ThresholdMask(h_warped_masks[i]);
142+
}
143+
144+
return success();
145+
}
146+
147+
// Makes the contents of `src` reachable through `dst` as a host tensor.
//
// - `src` already on the host: `dst` is assigned from `src` directly and no copy is issued.
// - otherwise: `dst` is re-created as a host tensor with the same data type and shape as
//   `src`, and a copy of `dst.byte_size()` bytes is issued on `stream_`.
//
// This function itself never calls `stream_.Wait()`; the caller is expected to synchronize
// before reading `dst` (ProcessMasks waits once after all copies have been issued).
// Returns the error reported by `stream_.Copy` if the copy cannot be issued.
Result<void> CopyToHost(const Tensor& src, Tensor& dst) {
  const bool already_on_host = (src.device() == kHost);
  if (!already_on_host) {
    const TensorDesc host_desc{kHost, src.data_type(), src.shape()};
    dst = host_desc;
    OUTCOME_TRY(stream_.Copy(src.buffer(), dst.buffer(), dst.byte_size()));
  } else {
    dst = src;
  }
  return success();
}
156+
157+
// Converts a warped host mask into a binary mask wrapped as an mmdeploy `Mat`.
//
// `h_mask` is viewed as a cv::Mat and compared element-wise against `mask_thr_binary_`;
// the comparison result is what gets returned, tagged as kGRAYSCALE / kINT8.
// NOTE(review): assumes `h_mask` is host-resident (callers pass the output of CopyToHost
// after the stream has been waited on) -- nothing here checks it.
Mat ThresholdMask(const Tensor& h_mask) const {
  const cv::Mat scores = cpu::Tensor2CVMat(h_mask);
  cv::Mat binary = scores > mask_thr_binary_;

  // The deleter is a no-op on purpose: the lambda holds a copy of `binary`, so the pixel
  // storage lives exactly as long as the returned shared_ptr does.
  return Mat(binary.rows, binary.cols, PixelFormat::kGRAYSCALE, DataType::kINT8,
             std::shared_ptr<void>(binary.data, [keep_alive = binary](void*) {}));
}
105163

164+
private:
165+
operation::Managed<operation::WarpAffine> warp_affine_;
106166
float mask_thr_binary_{.5f};
167+
bool is_rcnn_{true};
107168
};
108169

109170
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, ResizeInstanceMask);

demo/python/object_detection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import argparse
3+
import math
34

45
import cv2
56
from mmdeploy_python import Detector
@@ -36,7 +37,10 @@ def main():
3637
if masks[index].size:
3738
mask = masks[index]
3839
blue, green, red = cv2.split(img)
39-
mask_img = blue[top:top + mask.shape[0], left:left + mask.shape[1]]
40+
41+
x0 = int(max(math.floor(bbox[0]) - 1, 0))
42+
y0 = int(max(math.floor(bbox[1]) - 1, 0))
43+
mask_img = blue[y0:y0 + mask.shape[0], x0:x0 + mask.shape[1]]
4044
cv2.bitwise_or(mask, mask_img, mask_img)
4145
img = cv2.merge([blue, green, red])
4246

mmdeploy/codebase/mmdet/deploy/object_detection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
311311
params['score_thr'] = params['rcnn']['score_thr']
312312
if 'mask_thr_binary' in params['rcnn']:
313313
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
314-
type = 'ResizeInstanceMask' # for instance-seg
314+
if 'mask_thr_binary' in params:
315+
type = 'ResizeInstanceMask' # for instance-seg
315316
if get_backend(self.deploy_cfg) == Backend.RKNN:
316317
if 'YOLO' in self.model_cfg.model.type or \
317318
'RTMDet' in self.model_cfg.model.type:

0 commit comments

Comments
 (0)