diff --git a/docs/en/04-supported-codebases/mmdet.md b/docs/en/04-supported-codebases/mmdet.md index 16bbacb299..3e7815a413 100644 --- a/docs/en/04-supported-codebases/mmdet.md +++ b/docs/en/04-supported-codebases/mmdet.md @@ -220,6 +220,7 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | | [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y | | [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N | +| [SparseInst](https://github.com/open-mmlab/mmdetection/blob/main/projects/SparseInst) | Instance Segmentation | Y | Y | N | N | N | | [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N | | [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N | | [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N | diff --git a/docs/zh_cn/04-supported-codebases/mmdet.md b/docs/zh_cn/04-supported-codebases/mmdet.md index c131f76698..7b1f431158 100644 --- a/docs/zh_cn/04-supported-codebases/mmdet.md +++ b/docs/zh_cn/04-supported-codebases/mmdet.md @@ -223,6 +223,7 @@ cv2.imwrite('output_detection.png', img) | [SOLO](https://github.com/open-mmlab/mmdetection/tree/main/configs/solo) | Instance Segmentation | Y | N | N | N | Y | | [SOLOv2](https://github.com/open-mmlab/mmdetection/tree/main/configs/solov2) | Instance Segmentation | Y | N | N | N | Y | | [CondInst](https://github.com/open-mmlab/mmdetection/tree/main/configs/condinst) | Instance Segmentation | Y | Y | N | N | N | +| [SparseInst](https://github.com/open-mmlab/mmdetection/blob/main/projects/SparseInst) | Instance Segmentation | Y | Y | N | N | N | | [Panoptic FPN](https://github.com/open-mmlab/mmdetection/tree/main/configs/panoptic_fpn) | Panoptic Segmentation | Y | Y | N | N | N | | [MaskFormer](https://github.com/open-mmlab/mmdetection/tree/main/configs/maskformer) | Panoptic Segmentation | Y | Y | N | N | N | | [Mask2Former](https://github.com/open-mmlab/mmdetection/tree/main/configs/mask2former)[\*](#mask2former) | Panoptic Segmentation | Y | Y | N | N | N | diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index c6a958e5eb..dab5b074b6 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -241,7 +241,7 @@ def postprocessing_results(self, masks = batch_masks[i] img_h, img_w = img_metas[i]['img_shape'][:2] ori_h, ori_w = img_metas[i]['ori_shape'][:2] - if model_type in ['RTMDet', 'CondInst']: + if model_type in ['RTMDet', 'CondInst', 'SparseInst']: export_postprocess_mask = True else: export_postprocess_mask = False diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 062bc7de52..3bee17f449 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -11,5 +11,6 @@ from . import rtmdet_ins_head # noqa: F401,F403 from . import solo_head # noqa: F401,F403 from . import solov2_head # noqa: F401,F403 +from . import sparseinst_head # noqa: F401,F403 from . import yolo_head # noqa: F401,F403 from . import yolox_head # noqa: F401,F403 diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py new file mode 100644 index 0000000000..8e4c3e487d --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/sparseinst_head.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +import torch.nn.functional as F +from torch import Tensor + +from mmdeploy.core import FUNCTION_REWRITER + + +@torch.jit.script +def rescoring_mask(scores, mask_pred, masks): + mask_pred_ = mask_pred.float() + return scores * ((masks * mask_pred_).sum([2, 3]) / + (mask_pred_.sum([2, 3]) + 1e-6)) + + +@FUNCTION_REWRITER.register_rewriter( + 'projects.SparseInst.sparseinst.SparseInst.predict') +def sparseinst__predict( + self, + batch_inputs: Tensor, + batch_data_samples: List[dict], + rescale: bool = False, +): + """Rewrite `predict` of `SparseInst` for default backend.""" + max_shape = batch_inputs.shape[-2:] + x = self.extract_feat(batch_inputs) + output = self.decoder(x) + + pred_scores = output['pred_logits'].sigmoid() + pred_masks = output['pred_masks'].sigmoid() + pred_objectness = output['pred_scores'].sigmoid() + pred_scores = torch.sqrt(pred_scores * pred_objectness) + + # max/argmax + scores, labels = pred_scores.max(dim=-1) + # cls threshold + keep = scores > self.cls_threshold + scores = scores.where(keep, scores.new_zeros(1)) + labels = labels.where(keep, labels.new_zeros(1)) + keep = keep.unsqueeze(-1).unsqueeze(-1).expand_as(pred_masks) + pred_masks = pred_masks.where(keep, pred_masks.new_zeros(1)) + + img_meta = batch_data_samples[0].metainfo + # rescoring mask using maskness + scores = rescoring_mask(scores, pred_masks > self.mask_threshold, + pred_masks) + h, w = img_meta['img_shape'][:2] + pred_masks = F.interpolate( + pred_masks, size=max_shape, mode='bilinear', + align_corners=False)[:, :, :h, :w] + + bboxes = torch.zeros(scores.shape[0], scores.shape[1], 4) + dets = torch.cat([bboxes, scores.unsqueeze(-1)], dim=-1) + masks = (pred_masks > self.mask_threshold).float() + + return dets, labels, masks diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 232e50ac5e..530f403a7b 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -2582,3 +2582,93 @@ def forward(self, x, param_preds, points, strides): deploy_cfg=deploy_cfg) assert rewrite_outputs is not None + + +def get_sparseinst(): + """SparseInst Config.""" + test_cfg = Config(dict(score_thr=0.4, mask_thr_binary=0.45)) + data_preprocessor = dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + pad_size_divisor=32) + backbone = Config( + dict( + type='ResNet', + depth=50, + out_indices=(1, 2, 3), + frozen_stages=0, + norm_cfg=dict(type='BN', requires_grad=False), + init_cfg=dict( + type='Pretrained', checkpoint='torchvision://resnet50'))) + + from projects.SparseInst.sparseinst import SparseInst + model = SparseInst( + data_preprocessor=data_preprocessor, + backbone=backbone, + encoder=dict( + type='InstanceContextEncoder', in_channels=[512, 1024, 2048]), + decoder=dict( + type='BaseIAMDecoder', in_channels=256 + 2, num_classes=80), + criterion=dict( + type='SparseInstCriterion', + num_classes=80, + assigner=dict(type='SparseInstMatcher', alpha=0.8, beta=0.2)), + test_cfg=test_cfg, + init_cfg=dict( + type='Normal', + layer='Conv2d', + std=0.01, + override=dict( + type='Normal', name='conv_cls', std=0.01, bias_prob=0.01))) + + model.requires_grad_(False) + return model + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_sparseinst_predict(backend_type): + """Test predict rewrite of sparseinst.""" + check_backend(backend_type) + sparseinst = get_sparseinst() + sparseinst.cpu().eval() + + output_names = ['dets', 'labels', 'masks'] + deploy_cfg = Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + post_processing=dict( + score_threshold=0.05, + confidence_threshold=0.005, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + export_postprocess_mask=False)))) + + img = torch.randn(1, 3, 320, 320) + from mmdet.structures import DetDataSample + data_sample = DetDataSample(metainfo=dict(img_shape=(320, 320, 3))) + + # to get outputs of onnx model after rewrite + wrapped_model = WrapModel( + sparseinst, 'predict', batch_data_samples=[data_sample]) + rewrite_inputs = {'batch_inputs': img} + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + + if is_backend_output: + assert rewrite_outputs[0].shape[-1] == 5 + assert rewrite_outputs[1] is not None + assert rewrite_outputs[2] is not None + else: + assert rewrite_outputs is not None