Skip to content

Commit 88ab4b8

Browse files
authored
Add dummy XAI to RTDETR (export mode) & disable strong aug (#4106)
* Implement warning for unsupported explain mode in DETR model and update transform probabilities to zero in RTDETR recipes * update changelog * Update photometric distortion probability in RTDETR recipes
1 parent 0556ea6 commit 88ab4b8

File tree

5 files changed

+14
-14
lines changed

5 files changed

+14
-14
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ All notable changes to this project will be documented in this file.
118118
(<https://github.com/openvinotoolkit/training_extensions/pull/4082>)
119119
- Fix RTMDet Inst Explain Mode
120120
(<https://github.com/openvinotoolkit/training_extensions/pull/4083>)
121+
- Fix RTDETR Explain Mode
122+
(<https://github.com/openvinotoolkit/training_extensions/pull/4106>)
121123

122124
## \[v2.1.0\]
123125

src/otx/algo/detection/base_models/detection_transformer.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import warnings
89
from typing import Any
910

1011
import numpy as np
@@ -95,16 +96,22 @@ def export(
9596
explain_mode: bool = False,
9697
) -> dict[str, Any] | tuple[list[Any], list[Any], list[Any]]:
9798
"""Exports the model."""
98-
if explain_mode:
99-
msg = "Explain mode is not supported for DETR models yet."
100-
raise NotImplementedError(msg)
101-
102-
return self.postprocess(
99+
results = self.postprocess(
103100
self._forward_features(batch_inputs),
104101
[meta["img_shape"] for meta in batch_img_metas],
105102
deploy_mode=True,
106103
)
107104

105+
if explain_mode:
106+
# TODO(Eugene): Implement explain mode for DETR model.
107+
warnings.warn("Explain mode is not supported for DETR model. Return dummy values.", stacklevel=2)
108+
xai_output = {
109+
"feature_vector": torch.zeros(1, 1),
110+
"saliency_map": torch.zeros(1),
111+
}
112+
results.update(xai_output) # type: ignore[union-attr]
113+
return results
114+
108115
def postprocess(
109116
self,
110117
outputs: dict[str, Tensor],

src/otx/recipe/detection/rtdetr_101.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ overrides:
5454
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
5555
init_args:
5656
p: 0.5
57-
- class_path: torchvision.transforms.v2.RandomZoomOut
58-
init_args:
59-
fill: 0
6057
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
6158
init_args:
6259
prob: 0.5

src/otx/recipe/detection/rtdetr_18.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,6 @@ overrides:
5353
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
5454
init_args:
5555
p: 0.5
56-
- class_path: torchvision.transforms.v2.RandomZoomOut
57-
init_args:
58-
fill: 0
5956
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
6057
init_args:
6158
prob: 0.5

src/otx/recipe/detection/rtdetr_50.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,6 @@ overrides:
5454
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
5555
init_args:
5656
p: 0.5
57-
- class_path: torchvision.transforms.v2.RandomZoomOut
58-
init_args:
59-
fill: 0
6057
- class_path: otx.core.data.transform_libs.torchvision.RandomFlip
6158
init_args:
6259
prob: 0.5

0 commit comments

Comments
 (0)