
Commit f1baed1

Author: Evgeny Tsykunov

Per-class saliency maps for M-RCNN (#2301)

* per-class saliency maps for Mask R-CNN
* tiling support + test enablement/fix
* enable XAI detection e2e tests

1 parent: ae81031

File tree

14 files changed: +602 −245 lines

otx/algorithms/common/adapters/mmcv/hooks/recording_forward_hook.py

Lines changed: 18 additions & 3 deletions

@@ -16,8 +16,9 @@
 from __future__ import annotations
 
 from abc import ABC
-from typing import List, Sequence, Union
+from typing import List, Optional, Sequence, Union
 
+import numpy as np
 import torch
 
 from otx.algorithms.classification import MMCLS_AVAILABLE
@@ -69,10 +70,24 @@ def _recording_forward(
         self, _: torch.nn.Module, x: torch.Tensor, output: torch.Tensor
     ):  # pylint: disable=unused-argument
         tensors = self.func(output)
-        tensors = tensors.detach().cpu().numpy()
-        for tensor in tensors:
+        if isinstance(tensors, torch.Tensor):
+            tensors_np = tensors.detach().cpu().numpy()
+        elif isinstance(tensors, np.ndarray):
+            tensors_np = tensors
+        else:
+            self._torch_to_numpy_from_list(tensors)
+            tensors_np = tensors
+
+        for tensor in tensors_np:
             self._records.append(tensor)
 
+    def _torch_to_numpy_from_list(self, tensor_list: List[Optional[torch.Tensor]]):
+        for i in range(len(tensor_list)):
+            if isinstance(tensor_list[i], list):
+                self._torch_to_numpy_from_list(tensor_list[i])
+            elif isinstance(tensor_list[i], torch.Tensor):
+                tensor_list[i] = tensor_list[i].detach().cpu().numpy()
+
     def __enter__(self) -> BaseRecordingForwardHook:
         """Enter."""
         self._handle = self._module.backbone.register_forward_hook(self._recording_forward)
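
The recursion matters because `func` may now return nested per-image lists with `None` gaps (as `MaskRCNNRecordingForwardHook.func` below does) rather than a single tensor. A minimal standalone sketch of the conversion, using hypothetical data:

    import numpy as np
    import torch

    def torch_to_numpy_from_list(tensor_list):
        # Mirrors the hook's recursive in-place conversion.
        for i in range(len(tensor_list)):
            if isinstance(tensor_list[i], list):
                torch_to_numpy_from_list(tensor_list[i])
            elif isinstance(tensor_list[i], torch.Tensor):
                tensor_list[i] = tensor_list[i].detach().cpu().numpy()

    outputs = [torch.ones(2, 2), [torch.zeros(3), None]]  # nested list with a None gap
    torch_to_numpy_from_list(outputs)
    assert isinstance(outputs[0], np.ndarray)
    assert isinstance(outputs[1][0], np.ndarray) and outputs[1][1] is None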

otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py

Lines changed: 128 additions & 1 deletion

@@ -2,10 +2,13 @@
 # Copyright (C) 2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 #
-from typing import List, Tuple, Union
+import copy
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn.functional as F
+from mmdet.core import bbox2roi
 
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
     BaseRecordingForwardHook,
@@ -120,3 +123,127 @@ def forward_single(x, cls_convs, conv_cls):
                 "YOLOXHead, ATSSHead, SSDHead, VFNetHead."
             )
         return cls_scores
+
+
+class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
+    """Saliency map hook for Mask R-CNN model. Only for torch model, does not support OpenVINO IR model.
+
+    Args:
+        module (torch.nn.Module): Mask R-CNN model.
+        input_img_shape (Tuple[int]): Resolution of the model input image.
+        saliency_map_shape (Tuple[int]): Resolution of the output saliency map.
+        max_detections_per_img (int): Upper limit of the number of detections
+            from which soft mask predictions are getting aggregated.
+        normalize (bool): Flag that defines if the output saliency map will be normalized.
+            Although, partial normalization is anyway done by segmentation mask head.
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        input_img_shape: Tuple[int, int],
+        saliency_map_shape: Tuple[int, int] = (224, 224),
+        max_detections_per_img: int = 300,
+        normalize: bool = True,
+    ) -> None:
+        super().__init__(module)
+        self._neck = module.neck if module.with_neck else None
+        self._input_img_shape = input_img_shape
+        self._saliency_map_shape = saliency_map_shape
+        self._max_detections_per_img = max_detections_per_img
+        self._norm_saliency_maps = normalize
+
+    def func(
+        self,
+        feature_map: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
+        _: int = -1,
+    ) -> List[List[Optional[np.ndarray]]]:
+        """Generate saliency maps by aggregating per-class soft predictions of mask head for all detected boxes.
+
+        :param feature_map: Feature maps from backbone.
+        :return: Class-wise Saliency Maps. One saliency map per each predicted class.
+        """
+        with torch.no_grad():
+            if self._neck is not None:
+                feature_map = self._module.neck(feature_map)
+
+            det_bboxes, det_labels = self._get_detections(feature_map)
+            saliency_maps = self._get_saliency_maps_from_mask_predictions(feature_map, det_bboxes, det_labels)
+            if self._norm_saliency_maps:
+                saliency_maps = self._normalize(saliency_maps)
+        return saliency_maps
+
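
A hedged usage sketch of the new hook (`model` and `data` are assumed to come from the usual mmdet test pipeline; the input resolution is illustrative):

    import torch

    hook = MaskRCNNRecordingForwardHook(model, input_img_shape=(800, 800))
    with hook:  # __enter__ registers _recording_forward on model.backbone
        with torch.no_grad():
            model(return_loss=False, rescale=True, **data)
    # Each forward pass appends one per-image list of per-class saliency maps
    # to the hook's records; entries are uint8 arrays of saliency_map_shape,
    # or None for classes without detections.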
+    def _get_detections(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+        batch_size = x[0].shape[0]
+        img_metas = [
+            {
+                "scale_factor": [1, 1, 1, 1],  # dummy scale_factor, not used
+                "img_shape": self._input_img_shape,
+            }
+        ]
+        img_metas *= batch_size
+        proposals = self._module.rpn_head.simple_test_rpn(x, img_metas)
+        test_cfg = copy.deepcopy(self._module.roi_head.test_cfg)
+        test_cfg["max_per_img"] = self._max_detections_per_img
+        test_cfg["nms"]["iou_threshold"] = 1
+        test_cfg["nms"]["max_num"] = self._max_detections_per_img
+        det_bboxes, det_labels = self._module.roi_head.simple_test_bboxes(
+            x, img_metas, proposals, test_cfg, rescale=False
+        )
+        return det_bboxes, det_labels
+
+    def _get_saliency_maps_from_mask_predictions(
+        self, x: torch.Tensor, det_bboxes: List[torch.Tensor], det_labels: List[torch.Tensor]
+    ) -> List[List[Optional[np.ndarray]]]:
+        _bboxes = [det_bboxes[i][:, :4] for i in range(len(det_bboxes))]
+        mask_rois = bbox2roi(_bboxes)
+        mask_results = self._module.roi_head._mask_forward(x, mask_rois)
+        mask_pred = mask_results["mask_pred"]
+        num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+        mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+
+        batch_size = x[0].shape[0]
+
+        scale_x = self._input_img_shape[1] / self._saliency_map_shape[1]
+        scale_y = self._input_img_shape[0] / self._saliency_map_shape[0]
+        scale_factor = torch.FloatTensor((scale_x, scale_y, scale_x, scale_y))
+        test_cfg = self._module.roi_head.test_cfg.copy()
+        test_cfg["mask_thr_binary"] = -1
+
+        saliency_maps = []  # type: List[List[Optional[np.ndarray]]]
+        for i in range(batch_size):
+            saliency_maps.append([])
+            for j in range(self._module.roi_head.mask_head.num_classes):
+                saliency_maps[i].append(None)
+
+        for i in range(batch_size):
+            if det_bboxes[i].shape[0] == 0:
+                continue
+            else:
+                segm_result = self._module.roi_head.mask_head.get_seg_masks(
+                    mask_preds[i],
+                    _bboxes[i],
+                    det_labels[i],
+                    test_cfg,
+                    self._saliency_map_shape,
+                    scale_factor=scale_factor,
+                    rescale=True,
+                )
+                for class_id, segm_res in enumerate(segm_result):
+                    if segm_res:
+                        saliency_maps[i][class_id] = np.mean(np.array(segm_res), axis=0)
+        return saliency_maps
+
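
The aggregation step above is easy to verify in isolation: with `mask_thr_binary = -1`, `get_seg_masks` returns soft (unthresholded) per-class instance masks, and instances of the same class are averaged into one map. A toy illustration with hypothetical values:

    import numpy as np

    segm_result = [
        [np.full((4, 4), 0.8), np.full((4, 4), 0.4)],  # class 0: two soft masks
        [],                                            # class 1: no detections
    ]
    saliency_maps = [None] * len(segm_result)
    for class_id, segm_res in enumerate(segm_result):
        if segm_res:
            saliency_maps[class_id] = np.mean(np.array(segm_res), axis=0)
    assert np.allclose(saliency_maps[0], 0.6) and saliency_maps[1] is None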
+    @staticmethod
+    def _normalize(saliency_maps: List[List[Optional[np.ndarray]]]) -> List[List[Optional[np.ndarray]]]:
+        batch_size = len(saliency_maps)
+        num_classes = len(saliency_maps[0])
+        for i in range(batch_size):
+            for class_id in range(num_classes):
+                per_class_map = saliency_maps[i][class_id]
+                if per_class_map is not None:
+                    max_values = np.max(per_class_map)
+                    per_class_map = 255 * (per_class_map) / (max_values + 1e-12)
+                    per_class_map = per_class_map.astype(np.uint8)
+                    saliency_maps[i][class_id] = per_class_map
+        return saliency_maps
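
A worked example of the normalization: each per-class map is scaled so its maximum becomes 255 and cast to uint8; the epsilon guards against division by zero for all-zero maps.

    import numpy as np

    per_class_map = np.array([[0.0, 0.3], [0.6, 0.6]])
    normalized = (255 * per_class_map / (np.max(per_class_map) + 1e-12)).astype(np.uint8)
    print(normalized)  # [[  0 127] [254 254]]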

otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_detector.py

Lines changed: 3 additions & 3 deletions

@@ -10,7 +10,6 @@
 from mmdet.models.detectors.mask_rcnn import MaskRCNN
 
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
 from otx.algorithms.common.adapters.mmdeploy.utils import is_mmdeploy_enabled
@@ -99,7 +98,7 @@ def load_state_dict_pre_hook(model, model_classes, chkpt_classes, chkpt_dict, pr
 def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **kwargs):
     """Function for custom_mask_rcnn__simple_test."""
     assert self.with_bbox, "Bbox head must be implemented."
-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -108,7 +107,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None, **k
 
     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # Saliency map will be generated from predictions. Generate dummy saliency_map.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, feature_vector, saliency_map)
 
     return out

otx/algorithms/detection/adapters/mmdet/models/detectors/custom_maskrcnn_tile_optimized.py

Lines changed: 3 additions & 3 deletions

@@ -209,7 +209,6 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_im
 
 # pylint: disable=ungrouped-imports
 from otx.algorithms.common.adapters.mmcv.hooks.recording_forward_hook import (
-    ActivationMapHook,
     FeatureVectorHook,
 )
 
@@ -327,7 +326,7 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
     assert self.with_bbox, "Bbox head must be implemented."
     tile_prob = self.tile_classifier.simple_test(img)
 
-    x = backbone_out = self.backbone(img)
+    x = self.backbone(img)
     if self.with_neck:
         x = self.neck(x)
     if proposals is None:
@@ -336,7 +335,8 @@ def custom_mask_rcnn__simple_test(ctx, self, img, img_metas, proposals=None):
 
     if ctx.cfg["dump_features"]:
         feature_vector = FeatureVectorHook.func(x)
-        saliency_map = ActivationMapHook.func(backbone_out)
+        # Saliency map will be generated from predictions. Generate dummy saliency_map.
+        saliency_map = torch.empty(1, dtype=torch.uint8)
         return (*out, tile_prob, feature_vector, saliency_map)
 
     return (*out, tile_prob)

otx/algorithms/detection/adapters/mmdet/task.py

Lines changed: 16 additions & 9 deletions

@@ -61,6 +61,7 @@
 from otx.algorithms.detection.adapters.mmdet.datasets import ImageTilingDataset
 from otx.algorithms.detection.adapters.mmdet.hooks.det_class_probability_map_hook import (
     DetClassProbabilityMapHook,
+    MaskRCNNRecordingForwardHook,
 )
 from otx.algorithms.detection.adapters.mmdet.utils import (
     patch_input_preprocessing,
@@ -397,7 +398,8 @@ def hook(module, inp, outp):
         if raw_model.__class__.__name__ == "NNCFNetwork":
             raw_model = raw_model.get_nncf_wrapped_model()
         if isinstance(raw_model, TwoStageDetector):
-            saliency_hook = ActivationMapHook(feature_model)
+            height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
+            saliency_hook = MaskRCNNRecordingForwardHook(feature_model, input_img_shape=(height, width))
         else:
             saliency_hook = DetClassProbabilityMapHook(feature_model)
@@ -515,15 +517,9 @@
         explain_parameters: Optional[ExplainParameters] = None,
     ) -> Dict[str, Any]:
         """Main explain function of MMDetectionTask."""
-
         for item in dataset:
             item.subset = Subset.TESTING
 
-        explainer_hook_selector = {
-            "classwisesaliencymap": DetClassProbabilityMapHook,
-            "eigencam": EigenCamHook,
-            "activationmap": ActivationMapHook,
-        }
         self._data_cfg = ConfigDict(
             data=ConfigDict(
                 train=ConfigDict(
@@ -593,6 +589,18 @@ def hook(module, inp, outp):
         model.register_forward_pre_hook(pre_hook)
         model.register_forward_hook(hook)
 
+        if isinstance(feature_model, TwoStageDetector):
+            height, width, _ = mm_dataset[0]["img_metas"][0].data["img_shape"]
+            per_class_xai_algorithm = partial(MaskRCNNRecordingForwardHook, input_img_shape=(width, height))
+        else:
+            per_class_xai_algorithm = DetClassProbabilityMapHook  # type: ignore
+
+        explainer_hook_selector = {
+            "classwisesaliencymap": per_class_xai_algorithm,
+            "eigencam": EigenCamHook,
+            "activationmap": ActivationMapHook,
+        }
+
         explainer = explain_parameters.explainer if explain_parameters else None
         if explainer is not None:
             explainer_hook = explainer_hook_selector.get(explainer.lower(), None)
@@ -602,9 +610,8 @@ def hook(module, inp, outp):
                 raise NotImplementedError(f"Explainer algorithm {explainer} not supported!")
             logger.info(f"Explainer algorithm: {explainer}")
 
-        # Class-wise Saliency map for Single-Stage Detector, otherwise use class-ignore saliency map.
         eval_predictions = []
-        with explainer_hook(feature_model) as saliency_hook:
+        with explainer_hook(feature_model) as saliency_hook:  # type: ignore
             for data in dataloader:
                 with torch.no_grad():
                     result = model(return_loss=False, rescale=True, **data)
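
To make the selector change concrete: `functools.partial` pre-binds `input_img_shape`, so both branches expose the same one-argument constructor and the `with` block works unchanged. A sketch (hypothetical resolution; `feature_model` assumed loaded):

    from functools import partial

    per_class_xai_algorithm = partial(MaskRCNNRecordingForwardHook, input_img_shape=(736, 992))
    explainer_hook = {"classwisesaliencymap": per_class_xai_algorithm}["classwisesaliencymap"]
    with explainer_hook(feature_model) as saliency_hook:
        ...  # run inference; per-image saliency maps accumulate in the hook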

otx/algorithms/detection/adapters/openvino/model_wrappers/openvino_models.py

Lines changed: 80 additions & 0 deletions

@@ -16,6 +16,7 @@
 
 from typing import Dict
 
+import cv2
 import numpy as np
 
 try:
@@ -107,6 +108,85 @@ def postprocess(self, outputs, meta):
 
         return scores, classes, boxes, resized_masks
 
+    def get_saliency_map_from_prediction(self, outputs, meta, num_classes):
+        """Post process function for saliency map of OTX MaskRCNN model."""
+        boxes = outputs[self.output_blob_name["boxes"]]
+        if boxes.shape[0] == 1:
+            boxes = boxes.squeeze(0)
+        scores = boxes[:, 4]
+        boxes = boxes[:, :4]
+        masks = outputs[self.output_blob_name["masks"]]
+        if masks.shape[0] == 1:
+            masks = masks.squeeze(0)
+        classes = outputs[self.output_blob_name["labels"]].astype(np.uint32)
+        if classes.shape[0] == 1:
+            classes = classes.squeeze(0)
+
+        scale_x = meta["resized_shape"][1] / meta["original_shape"][1]
+        scale_y = meta["resized_shape"][0] / meta["original_shape"][0]
+        boxes[:, 0::2] /= scale_x
+        boxes[:, 1::2] /= scale_y
+
+        saliency_maps = [None for _ in range(num_classes)]
+        for box, score, cls, raw_mask in zip(boxes, scores, classes, masks):
+            resized_mask = self._resize_mask(box, raw_mask * score, *meta["original_shape"][:-1])
+            if saliency_maps[cls] is None:
+                saliency_maps[cls] = [resized_mask]
+            else:
+                saliency_maps[cls].append(resized_mask)
+
+        saliency_maps = self._average_and_normalize(saliency_maps, num_classes)
+        return saliency_maps
+
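
The box rescaling above maps boxes predicted at the resized input resolution back to the original image frame; a tiny numeric check with hypothetical shapes:

    import numpy as np

    meta = {"resized_shape": (512, 512, 3), "original_shape": (1024, 2048, 3)}
    scale_x = meta["resized_shape"][1] / meta["original_shape"][1]  # 0.25
    scale_y = meta["resized_shape"][0] / meta["original_shape"][0]  # 0.5
    boxes = np.array([[128.0, 256.0, 256.0, 512.0]])  # x0, y0, x1, y1
    boxes[:, 0::2] /= scale_x  # x coordinates back to original width
    boxes[:, 1::2] /= scale_y  # y coordinates back to original height
    print(boxes)  # [[ 512.  512. 1024. 1024.]]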
+    def _resize_mask(self, box, raw_cls_mask, im_h, im_w):
+        # Add zero border to prevent upsampling artifacts on segment borders.
+        raw_cls_mask = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0)
+        extended_box = self._expand_box(box, raw_cls_mask.shape[0] / (raw_cls_mask.shape[0] - 2.0)).astype(int)
+        w, h = np.maximum(extended_box[2:] - extended_box[:2] + 1, 1)
+        x0, y0 = np.clip(extended_box[:2], a_min=0, a_max=[im_w, im_h])
+        x1, y1 = np.clip(extended_box[2:] + 1, a_min=0, a_max=[im_w, im_h])
+
+        raw_cls_mask = cv2.resize(raw_cls_mask.astype(np.float32), (w, h))
+        # Put an object mask in an image mask.
+        im_mask = np.zeros((im_h, im_w), dtype=np.float32)
+        im_mask[y0:y1, x0:x1] = raw_cls_mask[
+            (y0 - extended_box[1]) : (y1 - extended_box[1]), (x0 - extended_box[0]) : (x1 - extended_box[0])
+        ]
+        return im_mask
+
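
Rough numeric intuition for the padding and box expansion above (sizes are hypothetical): a 28×28 soft mask gains a 1-pixel zero border, and the box is stretched by the matching 30/28 ratio so the border lands just outside the original box.

    import numpy as np

    raw_cls_mask = np.random.rand(28, 28).astype(np.float32)
    padded = np.pad(raw_cls_mask, ((1, 1), (1, 1)), "constant", constant_values=0)
    scale = padded.shape[0] / (padded.shape[0] - 2.0)
    print(padded.shape, round(scale, 3))  # (30, 30) 1.071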
+    @staticmethod
+    def _average_and_normalize(saliency_maps, num_classes):
+        for i in range(num_classes):
+            if saliency_maps[i] is not None:
+                saliency_maps[i] = np.array(saliency_maps[i]).mean(0)
+
+        for i in range(num_classes):
+            per_class_map = saliency_maps[i]
+            if per_class_map is not None:
+                max_values = np.max(per_class_map)
+                per_class_map = 255 * (per_class_map) / (max_values + 1e-12)
+                per_class_map = per_class_map.astype(np.uint8)
+                saliency_maps[i] = per_class_map
+        return saliency_maps
+
+    def get_tiling_saliency_map_from_prediction(self, detections, num_classes):
+        """Post process function for saliency map of OTX MaskRCNN model for tiling."""
+        saliency_maps = [None for _ in range(num_classes)]
+
+        # No detection case
+        if isinstance(detections, np.ndarray) and detections.size == 0:
+            return saliency_maps
+
+        classes = [int(cls) - 1 for cls in detections[1]]
+        masks = detections[3]
+        for mask, cls in zip(masks, classes):
+            if saliency_maps[cls] is None:
+                saliency_maps[cls] = [mask]
+            else:
+                saliency_maps[cls].append(mask)
+        saliency_maps = self._average_and_normalize(saliency_maps, num_classes)
+        return saliency_maps
+
     def segm_postprocess(self, *args, **kwargs):
         """Post-process for segmentation masks."""
         return self._segm_postprocess(*args, **kwargs)
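
One subtlety worth a sketch: tiled detections carry 1-based labels, hence the `int(cls) - 1` shift before indexing the per-class list. A standalone illustration with hypothetical detections (only indices 1 and 3 of the tuple are used, matching the method above):

    import numpy as np

    num_classes = 3
    detections = (None, np.array([1.0, 3.0]), None, [np.ones((4, 4)), np.zeros((4, 4))])
    classes = [int(cls) - 1 for cls in detections[1]]  # 1-based -> 0-based: [0, 2]
    saliency_maps = [None] * num_classes
    for mask, cls in zip(detections[3], classes):
        saliency_maps[cls] = [mask] if saliency_maps[cls] is None else saliency_maps[cls] + [mask]
    # classes 0 and 2 now hold one-instance lists; class 1 stays None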
