Skip to content

Commit 2bdc356

Browse files
authored
RTMDet Inst Seg Explain Mode for 2.2 (#4083)
* Explain mode for RTMDet Inst Seg * Update changelog * reformat changelog
1 parent 7bb36ef commit 2bdc356

File tree

6 files changed

+25
-8
lines changed

6 files changed

+25
-8
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ All notable changes to this project will be documented in this file.
110110
(<https://github.com/openvinotoolkit/training_extensions/pull/4067>)
111111
- Fix wrong model name in converter & template
112112
(<https://github.com/openvinotoolkit/training_extensions/pull/4082>)
113+
- Fix RTMDet Inst Explain Mode
114+
(<https://github.com/openvinotoolkit/training_extensions/pull/4083>)
113115

114116
## \[v2.1.0\]
115117

src/otx/algo/detection/base_models/single_stage_detector.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@
1111

1212
from typing import TYPE_CHECKING
1313

14+
import torch
15+
16+
from otx.algo.instance_segmentation.heads.rtmdet_inst_head import RTMDetInstSepBNHead
1417
from otx.algo.modules.base_module import BaseModule
1518
from otx.algo.utils.mmengine_utils import InstanceData
1619
from otx.core.data.entity.detection import DetBatchDataEntity
1720

1821
if TYPE_CHECKING:
19-
import torch
2022
from torch import Tensor, nn
2123

2224

@@ -210,6 +212,19 @@ def export(
210212
backbone_feat = self.extract_feat(batch_inputs)
211213
bbox_head_feat = self.bbox_head.forward(backbone_feat)
212214
feature_vector = self.feature_vector_fn(backbone_feat)
215+
216+
if isinstance(self.bbox_head, RTMDetInstSepBNHead):
217+
# create dummy saliency map as its implemented in ModelAPI
218+
saliency_map = torch.zeros(1)
219+
bboxes, labels, masks = self.bbox_head.export(backbone_feat, batch_img_metas, rescale=rescale) # type: ignore[misc]
220+
return {
221+
"bboxes": bboxes,
222+
"labels": labels,
223+
"masks": masks,
224+
"feature_vector": feature_vector,
225+
"saliency_map": saliency_map,
226+
}
227+
213228
saliency_map = self.explain_fn(bbox_head_feat[0])
214229
bboxes, labels = self.bbox_head.export(backbone_feat, batch_img_metas, rescale=rescale)
215230
return {

src/otx/algo/instance_segmentation/heads/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .fcn_mask_head import FCNMaskHead
88
from .roi_head_tv import TVRoIHeads
99
from .rpn_head import RPNHead
10-
from .rtmdet_ins_head import RTMDetInsSepBNHead
10+
from .rtmdet_inst_head import RTMDetInstSepBNHead
1111

1212
__all__ = [
1313
"Shared2FCBBoxHead",
@@ -16,5 +16,5 @@
1616
"FCNMaskHead",
1717
"TVRoIHeads",
1818
"RPNHead",
19-
"RTMDetInsSepBNHead",
19+
"RTMDetInstSepBNHead",
2020
]

src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py renamed to src/otx/algo/instance_segmentation/heads/rtmdet_inst_head.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
# mypy: disable-error-code="call-overload, index, override, attr-defined, misc"
4747

4848

49-
class RTMDetInsHead(RTMDetHead):
49+
class RTMDetInstHead(RTMDetHead):
5050
"""Detection Head of RTMDet-Ins.
5151
5252
Args:
@@ -764,7 +764,7 @@ def forward(self, features: tuple[Tensor, ...]) -> Tensor:
764764
return self.projection(mask_features)
765765

766766

767-
class RTMDetInsSepBNHead(RTMDetInsHead):
767+
class RTMDetInstSepBNHead(RTMDetInstHead):
768768
"""Detection Head of RTMDet-Ins with sep-bn layers.
769769
770770
Args:

src/otx/algo/instance_segmentation/rtmdet_inst.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from otx.algo.common.utils.samplers import PseudoSampler
1919
from otx.algo.detection.base_models import SingleStageDetector
2020
from otx.algo.detection.necks import CSPNeXtPAFPN
21-
from otx.algo.instance_segmentation.heads import RTMDetInsSepBNHead
21+
from otx.algo.instance_segmentation.heads import RTMDetInstSepBNHead
2222
from otx.algo.instance_segmentation.losses import DiceLoss
2323
from otx.algo.modules.norm import build_norm_layer
2424
from otx.core.config.data import TileConfig
@@ -155,7 +155,7 @@ def _build_model(self, num_classes: int) -> SingleStageDetector:
155155
activation=partial(nn.SiLU, inplace=True),
156156
)
157157

158-
bbox_head = RTMDetInsSepBNHead(
158+
bbox_head = RTMDetInstSepBNHead(
159159
num_classes=num_classes,
160160
in_channels=96,
161161
stacked_convs=2,

tests/integration/api/test_xai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_predict_with_explain(
110110
if "dino" in model_name:
111111
pytest.skip("DINO is not supported.")
112112

113-
if any(keyword in recipe for keyword in ["rtmdet_inst_tiny", "maskdino", "maskrcnn_r50_tv"]):
113+
if any(keyword in recipe for keyword in ["maskdino", "maskrcnn_r50_tv"]):
114114
# TODO(Eugene): inst-seg models not fully support yet.
115115
pytest.skip(f"There's issue with inst-seg: {recipe}. Skip for now.")
116116

0 commit comments

Comments
 (0)