Skip to content

Commit ebc78a1

Browse files
authored
(release/2.5) Remove duplicate explain() method and consolidate XAI functionality into predict() (#4493)
* Refactor XAI utilities and remove deprecated explain method
1 parent acb572b commit ebc78a1

File tree

15 files changed

+315
-386
lines changed

15 files changed

+315
-386
lines changed

docs/source/guide/explanation/additional_features/xai.rst

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ It looks like a heatmap, where warm-colored areas represent the areas with main
1616
These images are taken from `D-RISE paper <https://arxiv.org/abs/2006.03204>`_.
1717

1818

19-
We can generate saliency maps for a certain model that was trained in OpenVINO™ Training Extensions, using ``otx explain`` command line. Learn more about its usage in :doc:`../../tutorials/base/explain` tutorial.
19+
We can generate saliency maps for a certain model that was trained in OpenVINO™ Training Extensions, using ``otx predict --explain True`` command line. Learn more about its usage in :doc:`../../tutorials/base/explain` tutorial.
2020

2121
*********************************
2222
XAI algorithms for classification
@@ -109,15 +109,22 @@ For instance segmentation networks the following algorithm is used to generate s
109109

110110
.. code-block:: python
111111
112-
engine.explain(
113-
checkpoint="<checkpoint-path>", # .pth or .xml weights of the model
112+
engine.predict(
113+
checkpoint="checkpoint.pth", # Use .pth when instantiating the engine with OTXEngine
114114
datamodule=OTXDataModule(), # The data module to use for predictions
115-
dump=True # Wherether to save saliency map images or not
115+
explain=True # Enable explainability features
116116
)
117+
118+
engine.predict(
119+
checkpoint="exported_model.xml", # Use .xml when instantiating the engine with OVEngine
120+
datamodule=OTXDataModule(), # The data module to use for predictions
121+
explain=True # Enable explainability features
122+
)
117123
118124
.. tab-item:: CLI
119125

120126
.. code-block:: bash
121127
122-
(otx) ...$ otx explain ... --checkpoint <checkpoint-path> # .pth or .xml weights of the model
123-
--data_root <dataset_path> # Path to data folder or single image
128+
(otx) ...$ otx predict ... --checkpoint <checkpoint-path> # .pth or .xml weights of the model
129+
--data_root <dataset_path> # Path to data folder or single image
130+
--explain True # Enable explainability features

docs/source/guide/get_started/api_tutorial.rst

Lines changed: 103 additions & 188 deletions
Large diffs are not rendered by default.

docs/source/guide/get_started/cli_commands.rst

Lines changed: 115 additions & 47 deletions
Large diffs are not rendered by default.

docs/source/guide/tutorials/base/explain.rst

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ To be specific, this tutorial uses as an example of the ATSS model trained throu
1212

1313
For visualization we use images from WGISD dataset from the :doc:`object detection tutorial <how_to_train/detection>` together with trained model.
1414

15-
1. Activate the virtual environment
15+
1. Activate the virtual environment
1616
created in the previous step.
1717

1818
.. code-block:: shell
@@ -21,7 +21,7 @@ created in the previous step.
2121
# or by this line, if you created an environment, using tox
2222
. venv/otx/bin/activate
2323
24-
2. ``otx explain`` command returns saliency maps,
24+
2. ``otx predict`` with the ``--explain True`` parameter returns saliency maps,
2525
which are heatmaps with red-colored areas indicating focus. Here's an example how to generate saliency maps from trained checkpoint:
2626

2727
.. tab-set::
@@ -30,31 +30,33 @@ which are heatmaps with red-colored areas indicating focus. Here's an example ho
3030

3131
.. code-block:: shell
3232
33-
(otx) ...$ otx explain --work_dir otx-workspace \
33+
(otx) ...$ otx predict --work_dir otx-workspace \
34+
--explain True \
3435
--explain_config.postprocess True # Resizes and applies colormap to the saliency map
3536
3637
.. tab-item:: CLI (with config)
3738

3839
.. code-block:: shell
3940
40-
(otx) ...$ otx explain --config src/otx/recipe/detection/atss_mobilenetv2.yaml \
41+
(otx) ...$ otx predict --config src/otx/recipe/detection/atss_mobilenetv2.yaml \
4142
--data_root data/wgisd \
42-
--checkpoint otx-workspace/20240312_051135/checkpoints/epoch_033.ckpt \
43+
--checkpoint otx-workspace/.latest/train/best_checkpoint.ckpt \
44+
--explain True \
4345
--explain_config.postprocess True # Resizes and applies colormap to the saliency map
4446
4547
.. tab-item:: API
4648

4749
.. code-block:: python
4850
49-
engine.explain(
50-
checkpoint="<checkpoint-path>",
51+
engine.predict(
52+
checkpoint="otx-workspace/.latest/train/best_checkpoint.ckpt",
5153
datamodule=OTXDataModule(...), # The data module to use for predictions
54+
explain=True,
5255
explain_config=ExplainConfig(postprocess=True), # Resizes and applies colormap to the saliency map
53-
dump=True # Wherether to save saliency map images or not
5456
)
5557
56-
3. The generated saliency maps will appear in ``otx-workspace/.latest/explain/saliency_maps`` folder.
57-
It will contain a pair of generated images with saliency maps for each image used for the explanation process:
58+
3. The generated saliency maps will appear in ``otx-workspace/.latest/explain/saliency_maps`` folder.
59+
It will contain a pair of generated images with saliency maps for each image used for the explanation process:
5860

5961
- saliency map - where red color means more attention of the model
6062
- overlay - where the saliency map is combined with the original image:
@@ -64,7 +66,31 @@ It will contain a pair of generated images with saliency maps for each image use
6466

6567
|
6668
67-
4. We can parametrize the explanation process by specifying
69+
4. To use the exported OpenVINO IR model for explanation, PyTorch weights should be converted to OpenVINO IR model with additional outputs ``saliency_map`` and ``feature_map``.
70+
To do that we should use ``otx export --explain True`` parameter during export.
71+
72+
.. tab-set::
73+
74+
.. tab-item:: CLI
75+
76+
.. code-block:: shell
77+
78+
(otx) ...$ otx export ... --explain True
79+
(otx) ...$ otx predict ... --checkpoint otx-workspace/20240312_052847/exported_model.xml --explain True
80+
81+
.. tab-item:: API
82+
83+
.. code-block:: python
84+
# Use .pth when instantiating the engine with OTXEngine
85+
engine = OTXEngine(model="checkpoint.pth", ...)
86+
engine.export(..., explain=True)
87+
engine.predict(..., explain=True)
88+
89+
90+
engine = OVEngine(model="exported_model.xml", ...)
91+
engine.predict(..., explain=True)
92+
93+
5. We can parametrize the explanation process by specifying
6894
the following parameters in ``ExplainConfig``:
6995

7096
- ``target_explain_group`` - for which target saliency maps will be generated:
@@ -84,21 +110,23 @@ the following parameters in ``ExplainConfig``:
84110

85111
.. code-block:: shell
86112
87-
(otx) ...$ otx explain ... --explain_config.postprocess True
113+
(otx) ...$ otx predict ... --explain True \
114+
--explain_config.postprocess True \
88115
--explain_config.target_explain_group PREDICTIONS
89116
90117
.. tab-item:: API
91118

92119
.. code-block:: python
93120
94-
engine.explain(...,
121+
engine.predict(...,
122+
explain=True,
95123
explain_config=ExplainConfig(
96124
postprocess=True,
97125
target_explain_group=TargetExplainGroup.PREDICTIONS
98126
)
99127
)
100128
101-
5. The explanation algorithm is chosen automatically
129+
6. The explanation algorithm is chosen automatically
102130
based on the used model:
103131

104132
- ``Recipro-CAM`` - for CNN classification models

src/otx/backend/native/engine.py

Lines changed: 1 addition & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -421,10 +421,7 @@ def predict(
421421
... --checkpoint <CKPT_PATH, str>
422422
```
423423
"""
424-
from otx.backend.native.models.utils.xai_utils import (
425-
process_saliency_maps_in_pred_entity,
426-
set_crop_padded_map_flag,
427-
)
424+
from otx.backend.native.models.utils.xai_utils import process_saliency_maps_in_pred_entity
428425

429426
model = self.model
430427

@@ -462,7 +459,6 @@ def predict(
462459
if explain:
463460
if explain_config is None:
464461
explain_config = ExplainConfig()
465-
explain_config = set_crop_padded_map_flag(explain_config, datamodule)
466462

467463
predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config, datamodule.label_info)
468464

@@ -548,91 +544,6 @@ def export(
548544
self.model.explain_mode = False
549545
return exported_model_path
550546

551-
def explain(
552-
self,
553-
checkpoint: PathLike | None = None,
554-
datamodule: EVAL_DATALOADERS | OTXDataModule | None = None,
555-
explain_config: ExplainConfig | None = None,
556-
**kwargs,
557-
) -> list | None:
558-
r"""Run XAI using the specified model and data (test subset).
559-
560-
Args:
561-
checkpoint (PathLike | None, optional): The path to the checkpoint file to load the model from.
562-
datamodule (EVAL_DATALOADERS | OTXDataModule | None, optional): The data module to use for predictions.
563-
explain_config (ExplainConfig | None, optional): Config used to handle saliency maps.
564-
**kwargs: Additional keyword arguments for pl.Trainer configuration.
565-
566-
Returns:
567-
list: Saliency maps.
568-
569-
Example:
570-
>>> engine.explain(
571-
... datamodule=OTXDataModule(),
572-
... checkpoint=<checkpoint/path>,
573-
... explain_config=ExplainConfig(),
574-
... )
575-
576-
CLI Usage:
577-
1. To run XAI with the torch model in work_dir, run
578-
```shell
579-
>>> otx explain \
580-
... --work_dir <WORK_DIR_PATH, str>
581-
```
582-
2. To run XAI using the specified model (torch or IR), run
583-
```shell
584-
>>> otx explain \
585-
... --work_dir <WORK_DIR_PATH, str> \
586-
... --checkpoint <CKPT_PATH, str>
587-
```
588-
3. To run XAI using the configuration, run
589-
```shell
590-
>>> otx explain \
591-
... --config <CONFIG_PATH> --data_root <DATASET_PATH, str> \
592-
... --checkpoint <CKPT_PATH, str>
593-
```
594-
"""
595-
from otx.backend.native.models.utils.xai_utils import (
596-
process_saliency_maps_in_pred_entity,
597-
set_crop_padded_map_flag,
598-
)
599-
600-
model = self.model
601-
602-
checkpoint = checkpoint if checkpoint is not None else self.checkpoint
603-
datamodule = datamodule if datamodule is not None else self.datamodule
604-
605-
if checkpoint is not None:
606-
ckpt = self._load_model_checkpoint(checkpoint, map_location="cpu")
607-
model.load_state_dict(ckpt)
608-
609-
if model.label_info != self.datamodule.label_info:
610-
msg = (
611-
"To launch a explain pipeline, the label information should be same "
612-
"between the training and testing datasets. "
613-
"Please check whether you use the same dataset: "
614-
f"model.label_info={model.label_info}, "
615-
f"datamodule.label_info={self.datamodule.label_info}"
616-
)
617-
raise ValueError(msg)
618-
619-
model.explain_mode = True
620-
621-
self._build_trainer(**kwargs)
622-
623-
predict_result = self.trainer.predict(
624-
model=model,
625-
datamodule=datamodule,
626-
)
627-
628-
if explain_config is None:
629-
explain_config = ExplainConfig()
630-
explain_config = set_crop_padded_map_flag(explain_config, datamodule)
631-
632-
predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config, datamodule.label_info)
633-
model.explain_mode = False
634-
return predict_result
635-
636547
def benchmark(
637548
self,
638549
checkpoint: PathLike | None = None,

src/otx/backend/native/models/detection/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def _customize_outputs(
211211
scores=scores,
212212
bboxes=bboxes,
213213
labels=labels,
214-
saliency_map=[saliency_map.detach().to(torch.float32) for saliency_map in outputs["saliency_map"]],
214+
saliency_map=outputs["saliency_map"],
215215
feature_vector=[
216216
feature_vector.detach().unsqueeze(0).to(torch.float32)
217217
for feature_vector in outputs["feature_vector"]

src/otx/backend/native/models/utils/xai_utils.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
if TYPE_CHECKING:
2727
from torch import LongTensor, Tensor
2828

29-
from otx.data.module import OTXDataModule
3029

3130
ProcessedSaliencyMaps = list[dict[str, np.ndarray | torch.Tensor]]
3231

@@ -37,13 +36,10 @@ def process_saliency_maps_in_pred_entity(
3736
label_info: LabelInfoTypes,
3837
) -> list[OTXPredBatch]:
3938
"""Process saliency maps in PredEntity."""
40-
41-
def _process(
42-
predict_result_per_batch: OTXPredBatch,
43-
label_info: LabelInfoTypes,
44-
) -> OTXPredBatch:
45-
if predict_result_per_batch.saliency_map is None: # skip empty saliency maps
46-
return predict_result_per_batch
39+
processed_predict_result = []
40+
for predict_result_per_batch in predict_result:
41+
if predict_result_per_batch.saliency_map is None or len(predict_result_per_batch.saliency_map) == 0:
42+
continue
4743

4844
# Extract batch data with proper type handling
4945
labels = predict_result_per_batch.labels if predict_result_per_batch.labels is not None else []
@@ -60,6 +56,7 @@ def _process(
6056
# Add additional conf threshold for saving maps with predicted classes,
6157
# since predictions can have less than 0.05 confidence
6258
conf_thr = explain_config.predicted_maps_conf_thr
59+
keep_ratio = imgs_info[0].keep_ratio # type: ignore[union-attr, index]
6360

6461
pred_labels = []
6562
for labels, scores in zip(predict_result_per_batch.labels, predict_result_per_batch.scores): # type: ignore[union-attr, arg-type]
@@ -81,11 +78,12 @@ def _process(
8178
ori_img_shapes,
8279
image_shape,
8380
paddings,
81+
keep_ratio,
8482
)
8583
predict_result_per_batch.saliency_map = processed_saliency_maps
86-
return predict_result_per_batch
84+
processed_predict_result.append(predict_result_per_batch)
8785

88-
return [_process(predict_result_per_batch, label_info) for predict_result_per_batch in predict_result]
86+
return processed_predict_result
8987

9088

9189
def process_saliency_maps(
@@ -95,6 +93,7 @@ def process_saliency_maps(
9593
ori_img_shapes: list,
9694
image_shape: tuple[int, int],
9795
paddings: list[tuple[int, int, int, int]],
96+
keep_ratio: bool,
9897
) -> ProcessedSaliencyMaps:
9998
"""Perform saliency map convertion to dict and post-processing."""
10099
if explain_config.target_explain_group == TargetExplainGroup.ALL:
@@ -107,7 +106,7 @@ def process_saliency_maps(
107106
msg = f"Target explain group {explain_config.target_explain_group} is not supported."
108107
raise ValueError(msg)
109108

110-
if explain_config.crop_padded_map:
109+
if keep_ratio:
111110
processed_saliency_maps = _crop_padded_map(processed_saliency_maps, image_shape, paddings)
112111

113112
if explain_config.postprocess:
@@ -221,12 +220,3 @@ def _convert_labels_from_hcls_format(
221220
pred_labels.append(label_info.label_to_idx[label_str])
222221

223222
return pred_labels
224-
225-
226-
def set_crop_padded_map_flag(explain_config: ExplainConfig, datamodule: OTXDataModule) -> ExplainConfig:
227-
"""If resize with keep_ratio = True was used, set crop_padded_map flag to True."""
228-
for transform in datamodule.test_subset.transforms:
229-
tranf_name = transform["class_path"].split(".")[-1]
230-
if tranf_name == "Resize" and transform["init_args"].get("keep_ratio", False):
231-
explain_config.crop_padded_map = True
232-
return explain_config

src/otx/backend/native/tools/explain/explain_algo.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
# Copyright (C) 2024 Intel Corporation
1+
# Copyright (C) 2024-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""Algorithms for calculcalating XAI branch for Explainable AI."""
55

66
from __future__ import annotations
77

8+
import warnings
89
from typing import TYPE_CHECKING, Callable
910

1011
import torch
@@ -267,6 +268,14 @@ def __init__(
267268
self._num_classes = num_classes
268269
self._num_anchors = num_anchors
269270
# Should be switched off for tiling
271+
if num_classes == 1 and use_cls_softmax:
272+
# softmax would result in all 1.0 values if there's only 1 class
273+
warnings.warn(
274+
"use_cls_softmax is automatically disabled when num_classes=1 to prevent degenerate softmax behavior",
275+
UserWarning,
276+
stacklevel=2,
277+
)
278+
use_cls_softmax = False
270279
self.use_cls_softmax = use_cls_softmax
271280

272281
def func(

0 commit comments

Comments
 (0)