Skip to content

Commit 3d157ab

Browse files
author
Evgeny Tsykunov
authored
Fixing detection saliency map for one class case (#2368)
* fix softmax * fix validity tests
1 parent 895bd36 commit 3d157ab

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

src/otx/algorithms/detection/adapters/mmdet/hooks/det_class_probability_map_hook.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def func(
6262

6363
# Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
6464
# it would highlight one of the class maps as a background class
65-
if self.use_cls_softmax:
65+
if self.use_cls_softmax and self._num_cls_out_channels > 1:
6666
cls_scores = [torch.softmax(t, dim=1) for t in cls_scores]
6767

6868
batch_size, _, height, width = cls_scores[-1].size()

tests/unit/algorithms/classification/test_xai_classification_validity.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,6 @@ def test_saliency_map_cls(self, template):
5454
assert len(saliency_maps) == 2
5555
assert saliency_maps[0].ndim == 3
5656
assert saliency_maps[0].shape == (1000, 7, 7)
57-
assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_cls[template.name]) <= 1)
57+
actual_sal_vals = saliency_maps[0][0][0].astype(np.int8)
58+
ref_sal_vals = self.ref_saliency_vals_cls[template.name].astype(np.int8)
59+
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)

tests/unit/algorithms/detection/test_xai_detection_validity.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def test_saliency_map_det(self, template):
8080
assert len(saliency_maps) == 2
8181
assert saliency_maps[0].ndim == 3
8282
assert saliency_maps[0].shape == self.ref_saliency_shapes[template.name]
83-
assert np.all(np.abs(saliency_maps[0][0][0] - self.ref_saliency_vals_det[template.name]) <= 1)
83+
actual_sal_vals = saliency_maps[0][0][0].astype(np.int8)
84+
ref_sal_vals = self.ref_saliency_vals_det[template.name].astype(np.int8)
85+
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)
8486

8587
@e2e_pytest_unit
8688
@pytest.mark.parametrize("template", templates_det, ids=templates_det_ids)

0 commit comments

Comments
 (0)