Skip to content

Commit 1090fb6

Browse files
authored
support htc (#2438)
* support htc * update mmdet.yml
1 parent c4dc10d commit 1090fb6

File tree

6 files changed

+62
-3
lines changed

6 files changed

+62
-3
lines changed

docs/en/04-supported-codebases/mmdet.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter
214214
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
215215
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
216216
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
217+
| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y |
217218
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
218219
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
219220
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |

docs/zh_cn/04-supported-codebases/mmdet.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ cv2.imwrite('output_detection.png', img)
217217
| [CenterNet](https://github.com/open-mmlab/mmdetection/tree/main/configs/centernet) | Object Detection | Y | Y | N | ? | Y |
218218
| [RTMDet](https://github.com/open-mmlab/mmdetection/tree/main/configs/rtmdet) | Object Detection | Y | Y | N | ? | Y |
219219
| [Cascade Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/cascade_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
220+
| [HTC](https://github.com/open-mmlab/mmdetection/tree/main/configs/htc) | Instance Segmentation | Y | Y | N | ? | Y |
220221
| [Mask R-CNN](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask_rcnn) | Instance Segmentation | Y | Y | N | N | Y |
221222
| [Swin Transformer](https://github.com/open-mmlab/mmdetection/tree/main/configs/swin) | Instance Segmentation | Y | Y | N | N | Y |
222223
| [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y |

mmdeploy/codebase/mmdet/models/roi_heads/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from . import bbox_head # noqa: F401,F403
33
from . import cascade_roi_head # noqa: F401,F403
44
from . import fcn_mask_head # noqa: F401,F403
5+
from . import htc_roi_head # noqa: F401,F403
56
from . import single_level_roi_extractor # noqa: F401,F403
67
from . import standard_roi_head # noqa: F401,F403

mmdeploy/codebase/mmdet/models/roi_heads/cascade_roi_head.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def cascade_roi_head__predict_bbox(self,
1616
batch_img_metas: List[dict],
1717
rpn_results_list: List[Tensor],
1818
rcnn_test_cfg: ConfigType,
19-
rescale: bool = False) -> List[Tensor]:
19+
rescale: bool = False,
20+
**kwargs) -> List[Tensor]:
2021
"""Rewrite `predict_bbox` of `CascadeRoIHead` for default backend.
2122
2223
Args:
@@ -52,8 +53,7 @@ def cascade_roi_head__predict_bbox(self,
5253
ms_scores = []
5354
max_shape = batch_img_metas[0]['img_shape']
5455
for i in range(self.num_stages):
55-
bbox_results = self._bbox_forward(i, x, rois)
56-
56+
bbox_results = self._bbox_forward(i, x, rois, **kwargs)
5757
cls_score = bbox_results['cls_score']
5858
bbox_pred = bbox_results['bbox_pred']
5959
# Recover the batch dimension
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from typing import List, Tuple
3+
4+
import torch
5+
from torch import Tensor
6+
7+
from mmdeploy.core import FUNCTION_REWRITER
8+
9+
10+
@FUNCTION_REWRITER.register_rewriter(
11+
'mmdet.models.roi_heads.htc_roi_head.HybridTaskCascadeRoIHead.predict_mask'
12+
)
13+
def htc_roi_head__predict_mask(self,
14+
x: Tuple[Tensor],
15+
semantic_heat: Tensor,
16+
batch_img_metas: List[dict],
17+
results_list: List[Tensor],
18+
rescale: bool = False) -> List[Tensor]:
19+
dets, det_labels = results_list
20+
21+
batch_size = dets.size(0)
22+
det_bboxes = dets[..., :4]
23+
batch_index = torch.arange(
24+
det_bboxes.size(0),
25+
device=det_bboxes.device).float().view(-1, 1, 1).expand(
26+
det_bboxes.size(0), det_bboxes.size(1), 1)
27+
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
28+
mask_rois = mask_rois.view(-1, 5)
29+
30+
mask_results = self._mask_forward(
31+
stage=-1,
32+
x=x,
33+
rois=mask_rois,
34+
semantic_feat=semantic_heat,
35+
training=False)
36+
37+
mask_preds = mask_results['mask_preds'][0]
38+
num_det = det_bboxes.shape[1]
39+
segm_results = self.mask_head[-1].predict_by_feat(
40+
mask_preds,
41+
results_list,
42+
batch_img_metas,
43+
self.test_cfg,
44+
rescale=rescale)
45+
segm_results = segm_results.reshape(batch_size, num_det,
46+
segm_results.shape[-2],
47+
segm_results.shape[-1])
48+
return dets, det_labels, segm_results

tests/regression/mmdet.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,3 +456,11 @@ models:
456456
backend_test: *default_backend_test
457457
- deploy_config: configs/mmdet/instance-seg/instance-seg_tensorrt_dynamic-320x320-1344x1344.py
458458
backend_test: *default_backend_test
459+
460+
- name: HTC
461+
metafile: configs/htc/metafile.yml
462+
model_configs:
463+
- configs/htc/htc_r50_fpn_1x_coco.py
464+
pipelines:
465+
- *pipeline_seg_ort_dynamic_fp32
466+
- *pipeline_seg_trt_dynamic_fp32

0 commit comments

Comments
 (0)