|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | from abc import ABC, abstractmethod |
18 | | -from typing import List, Sequence, Tuple, Union |
| 18 | +from typing import Sequence, Union |
19 | 19 |
|
20 | 20 | import torch |
21 | | -import torch.nn.functional as F |
22 | | -from mmcls.models.necks.gap import GlobalAveragePooling |
23 | 21 |
|
24 | | -from otx.mpa.modules.models.heads.custom_atss_head import CustomATSSHead |
25 | | -from otx.mpa.modules.models.heads.custom_ssd_head import CustomSSDHead |
26 | | -from otx.mpa.modules.models.heads.custom_vfnet_head import CustomVFNetHead |
27 | | -from otx.mpa.modules.models.heads.custom_yolox_head import CustomYOLOXHead |
| 22 | +from otx import MMCLS_AVAILABLE |
| 23 | + |
| 24 | +if MMCLS_AVAILABLE: |
| 25 | + from mmcls.models.necks.gap import GlobalAveragePooling |
28 | 26 |
|
29 | 27 |
|
30 | 28 | class BaseRecordingForwardHook(ABC): |
@@ -130,103 +128,6 @@ def func(feature_map: Union[torch.Tensor, Sequence[torch.Tensor]]) -> torch.Tens |
130 | 128 | return feature_vector |
131 | 129 |
|
132 | 130 |
|
133 | | -class DetSaliencyMapHook(BaseRecordingForwardHook): |
134 | | - """Saliency map hook for object detection models.""" |
135 | | - |
136 | | - def __init__(self, module: torch.nn.Module) -> None: |
137 | | - super().__init__(module) |
138 | | - self._neck = module.neck if module.with_neck else None |
139 | | - self._bbox_head = module.bbox_head |
140 | | - self._num_cls_out_channels = module.bbox_head.cls_out_channels # SSD-like heads also have background class |
141 | | - if hasattr(module.bbox_head, "anchor_generator"): |
142 | | - self._num_anchors = module.bbox_head.anchor_generator.num_base_anchors |
143 | | - else: |
144 | | - self._num_anchors = [1] * 10 |
145 | | - |
146 | | - def func( |
147 | | - self, |
148 | | - x: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], |
149 | | - _: int = -1, |
150 | | - cls_scores_provided: bool = False, |
151 | | - ) -> torch.Tensor: |
152 | | - """ |
153 | | - Generate the saliency map from raw classification head output, then normalizing to (0, 255). |
154 | | -
|
155 | | - :param x: Feature maps from backbone/FPN or classification scores from cls_head |
156 | | - :param cls_scores_provided: If True - use 'x' as is, otherwise forward 'x' through the classification head |
157 | | - :return: Class-wise Saliency Maps. One saliency map per each class - [batch, class_id, H, W] |
158 | | - """ |
159 | | - if cls_scores_provided: |
160 | | - cls_scores = x |
161 | | - else: |
162 | | - cls_scores = self._get_cls_scores_from_feature_map(x) |
163 | | - |
164 | | - bs, _, h, w = cls_scores[-1].size() |
165 | | - saliency_maps = torch.empty(bs, self._num_cls_out_channels, h, w) |
166 | | - for batch_idx in range(bs): |
167 | | - cls_scores_anchorless = [] |
168 | | - for scale_idx, cls_scores_per_scale in enumerate(cls_scores): |
169 | | - cls_scores_anchor_grouped = cls_scores_per_scale[batch_idx].reshape( |
170 | | - self._num_anchors[scale_idx], (self._num_cls_out_channels), *cls_scores_per_scale.shape[-2:] |
171 | | - ) |
172 | | - cls_scores_out, _ = cls_scores_anchor_grouped.max(dim=0) |
173 | | - cls_scores_anchorless.append(cls_scores_out.unsqueeze(0)) |
174 | | - cls_scores_anchorless_resized = [] |
175 | | - for cls_scores_anchorless_per_level in cls_scores_anchorless: |
176 | | - cls_scores_anchorless_resized.append( |
177 | | - F.interpolate(cls_scores_anchorless_per_level, (h, w), mode="bilinear") |
178 | | - ) |
179 | | - saliency_maps[batch_idx] = torch.cat(cls_scores_anchorless_resized, dim=0).mean(dim=0) |
180 | | - |
181 | | - saliency_maps = saliency_maps.reshape((bs, self._num_cls_out_channels, -1)) |
182 | | - max_values, _ = torch.max(saliency_maps, -1) |
183 | | - min_values, _ = torch.min(saliency_maps, -1) |
184 | | - saliency_maps = 255 * (saliency_maps - min_values[:, :, None]) / (max_values - min_values + 1e-12)[:, :, None] |
185 | | - saliency_maps = saliency_maps.reshape((bs, self._num_cls_out_channels, h, w)) |
186 | | - saliency_maps = saliency_maps.to(torch.uint8) |
187 | | - return saliency_maps |
188 | | - |
189 | | - def _get_cls_scores_from_feature_map(self, x: torch.Tensor) -> List: |
190 | | - """Forward features through the classification head of the detector.""" |
191 | | - with torch.no_grad(): |
192 | | - if self._neck is not None: |
193 | | - x = self._neck(x) |
194 | | - |
195 | | - if isinstance(self._bbox_head, CustomSSDHead): |
196 | | - cls_scores = [] |
197 | | - for feat, cls_conv in zip(x, self._bbox_head.cls_convs): |
198 | | - cls_scores.append(cls_conv(feat)) |
199 | | - elif isinstance(self._bbox_head, CustomATSSHead): |
200 | | - cls_scores = [] |
201 | | - for cls_feat in x: |
202 | | - for cls_conv in self._bbox_head.cls_convs: |
203 | | - cls_feat = cls_conv(cls_feat) |
204 | | - cls_score = self._bbox_head.atss_cls(cls_feat) |
205 | | - cls_scores.append(cls_score) |
206 | | - elif isinstance(self._bbox_head, CustomVFNetHead): |
207 | | - # Not clear how to separate cls_scores from bbox_preds |
208 | | - cls_scores, _, _ = self._bbox_head(x) |
209 | | - elif isinstance(self._bbox_head, CustomYOLOXHead): |
210 | | - |
211 | | - def forward_single(x, cls_convs, conv_cls): |
212 | | - """Forward feature of a single scale level.""" |
213 | | - cls_feat = cls_convs(x) |
214 | | - cls_score = conv_cls(cls_feat) |
215 | | - return cls_score |
216 | | - |
217 | | - map_results = map( |
218 | | - forward_single, x, self._bbox_head.multi_level_cls_convs, self._bbox_head.multi_level_conv_cls |
219 | | - ) |
220 | | - cls_scores = list(map_results) |
221 | | - else: |
222 | | - raise NotImplementedError( |
223 | | - "Not supported detection head provided. " |
224 | | - "DetSaliencyMapHook supports only the following single stage detectors: " |
225 | | - "YOLOXHead, ATSSHead, SSDHead, VFNetHead." |
226 | | - ) |
227 | | - return cls_scores |
228 | | - |
229 | | - |
230 | 131 | class ReciproCAMHook(BaseRecordingForwardHook): |
231 | 132 | """ |
232 | 133 | Implementation of recipro-cam for class-wise saliency map |
@@ -280,7 +181,7 @@ def _predict_from_feature_map(self, x: torch.Tensor) -> torch.Tensor: |
280 | 181 | return logits |
281 | 182 |
|
282 | 183 | def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor: |
283 | | - if self._neck is not None and isinstance(self._neck, GlobalAveragePooling): |
| 184 | + if MMCLS_AVAILABLE and self._neck is not None and isinstance(self._neck, GlobalAveragePooling): |
284 | 185 | """ |
285 | 186 | Optimization workaround for the GAP case (simulate GAP with more simple compute graph) |
286 | 187 | Possible due to static sparsity of mosaic_feature_map |
|
0 commit comments