Skip to content

Commit 8559def

Browse files
author
Galina Zalesskaya
authored
Fix XAI algorithm for Detection (#2609)
* Impove saliency maps algorithm for Detection * Remove extra changes * Update unit tests * Changes for 1 class * Fix pre-commit * Update CHANGELOG
1 parent 794a814 commit 8559def

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ All notable changes to this project will be documented in this file.
1616
- Fix mmcls bug not wrapping model in DataParallel on CPUs (<https://github.com/openvinotoolkit/training_extensions/pull/2601>)
1717
- Fix h-label loss normalization issue w/ exclusive label group of singe label (<https://github.com/openvinotoolkit/training_extensions/pull/2604>)
1818
- Fix division by zero in class incremental learning for classification (<https://github.com/openvinotoolkit/training_extensions/pull/2606>)
19+
- Fix saliency maps calculation issue for detection models (<https://github.com/openvinotoolkit/training_extensions/pull/2609>)
1920

2021
## \[v1.4.3\]
2122

src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,9 @@ def func(
6060
else:
6161
cls_scores = self._get_cls_scores_from_feature_map(feature_map)
6262

63-
# Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
64-
# it would highlight one of the class maps as a background class
65-
if self.use_cls_softmax and self._num_cls_out_channels > 1:
66-
cls_scores = [torch.softmax(t, dim=1) for t in cls_scores]
67-
68-
batch_size, _, height, width = cls_scores[-1].size()
63+
middle_idx = len(cls_scores) // 2
64+
# resize to the middle feature map
65+
batch_size, _, height, width = cls_scores[middle_idx].size()
6966
saliency_maps = torch.empty(batch_size, self._num_cls_out_channels, height, width)
7067
for batch_idx in range(batch_size):
7168
cls_scores_anchorless = []
@@ -82,6 +79,11 @@ def func(
8279
)
8380
saliency_maps[batch_idx] = torch.cat(cls_scores_anchorless_resized, dim=0).mean(dim=0)
8481

82+
# Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
83+
# it would highlight one of the class maps as a background class
84+
if self.use_cls_softmax:
85+
saliency_maps[0] = torch.stack([torch.softmax(t, dim=1) for t in saliency_maps[0]])
86+
8587
if self._norm_saliency_maps:
8688
saliency_maps = saliency_maps.reshape((batch_size, self._num_cls_out_channels, -1))
8789
saliency_maps = self._normalize_map(saliency_maps)

tests/unit/algorithms/detection/test_xai_detection_validity.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,19 @@
2424

2525
class TestExplainMethods:
2626
ref_saliency_shapes = {
27-
"MobileNetV2-ATSS": (2, 4, 4),
27+
"MobileNetV2-ATSS": (2, 13, 13),
2828
"SSD": (81, 13, 13),
29-
"YOLOX": (80, 13, 13),
29+
"YOLOX": (80, 26, 26),
3030
}
3131

3232
ref_saliency_vals_det = {
33-
"MobileNetV2-ATSS": np.array([67, 216, 255, 57], dtype=np.uint8),
34-
"YOLOX": np.array([80, 28, 42, 53, 49, 68, 72, 75, 69, 57, 65, 6, 157], dtype=np.uint8),
35-
"SSD": np.array([119, 72, 118, 35, 39, 30, 31, 31, 36, 28, 44, 23, 61], dtype=np.uint8),
33+
"MobileNetV2-ATSS": np.array([34, 67, 148, 132, 172, 147, 146, 155, 167, 159], dtype=np.uint8),
34+
"YOLOX": np.array([177, 94, 147, 147, 161, 162, 164, 164, 163, 166], dtype=np.uint8),
35+
"SSD": np.array([255, 178, 212, 90, 93, 79, 79, 80, 87, 83], dtype=np.uint8),
3636
}
3737

3838
ref_saliency_vals_det_wo_postprocess = {
39-
"MobileNetV2-ATSS": -0.10465062,
39+
"MobileNetV2-ATSS": -0.014513552,
4040
"YOLOX": 0.04948914,
4141
"SSD": 0.6629989,
4242
}
@@ -80,8 +80,8 @@ def test_saliency_map_det(self, template):
8080
assert len(saliency_maps) == 2
8181
assert saliency_maps[0].ndim == 3
8282
assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name]
83-
actual_sal_vals = saliency_maps[0][0][0].astype(np.int8)
84-
ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.int8)
83+
actual_sal_vals = saliency_maps[0][0][0][:10].astype(np.int16)
84+
ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.uint8)
8585
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)
8686

8787
@e2e_pytest_unit

0 commit comments

Comments
 (0)