
Commit 7088095

[FIX] Add stability to explain detection (#1901)

1 parent f7c2f45 commit 7088095

File tree: 2 files changed, +9 −9 lines

otx/algorithms/detection/adapters/mmdet/models/heads/custom_atss_head.py

Lines changed: 6 additions & 8 deletions
@@ -172,18 +172,16 @@ def loss_single(
             pos_centerness = centerness[pos_inds]

             centerness_targets = self.centerness_target(pos_anchors, pos_bbox_targets)
-            pos_decode_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
-            pos_decode_bbox_targets = self.bbox_coder.decode(pos_anchors, pos_bbox_targets)
+            if self.reg_decoded_bbox:
+                pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)

             if self.use_qfl:
-                quality[pos_inds] = bbox_overlaps(
-                    pos_decode_bbox_pred.detach(), pos_decode_bbox_targets, is_aligned=True
-                ).clamp(min=1e-6)
+                quality[pos_inds] = bbox_overlaps(pos_bbox_pred.detach(), pos_bbox_targets, is_aligned=True).clamp(
+                    min=1e-6
+                )

             # regression loss
-            loss_bbox = self.loss_bbox(
-                pos_decode_bbox_pred, pos_decode_bbox_targets, weight=centerness_targets, avg_factor=1.0
-            )
+            loss_bbox = self.loss_bbox(pos_bbox_pred, pos_bbox_targets, weight=centerness_targets, avg_factor=1.0)

             # centerness loss
             loss_centerness = self.loss_centerness(pos_centerness, centerness_targets, avg_factor=num_total_samples)
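In short, the head now decodes pos_bbox_pred only when self.reg_decoded_bbox is set, and both the QFL quality target and the regression loss compare the prediction directly against pos_bbox_targets, so predictions and targets are kept in the same box representation. As a rough illustration of what the aligned bbox_overlaps(...).clamp(min=1e-6) call produces, here is a minimal plain-PyTorch sketch; the aligned_iou helper below is hypothetical and stands in for mmdet's bbox_overlaps, it is not the OTX code:

# Sketch only: aligned IoU between N predicted and N target boxes in
# (x1, y1, x2, y2) format, clamped away from zero so it can serve as a
# quality-focal-loss target.
import torch

def aligned_iou(pred, target, eps=1e-6):
    # Intersection corners for each aligned (pred[i], target[i]) pair.
    lt = torch.max(pred[:, :2], target[:, :2])
    rb = torch.min(pred[:, 2:], target[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, 0] * wh[:, 1]
    area_p = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
    area_t = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
    return (inter / (area_p + area_t - inter + eps)).clamp(min=eps)

pred = torch.tensor([[0.0, 0.0, 10.0, 10.0], [5.0, 5.0, 15.0, 15.0]])
target = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
print(aligned_iou(pred, target))  # tensor([1.0000, 0.1429])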

otx/mpa/det/explainer.py

Lines changed: 3 additions & 1 deletion
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 #

+import torch
 from mmcv.utils import Config, ConfigDict
 from mmdet.datasets import build_dataloader as mmdet_build_dataloader
 from mmdet.datasets import build_dataset as mmdet_build_dataset
@@ -153,7 +154,8 @@ def explain(self, cfg, model_builder=None):
         eval_predictions = []
         with self.explainer_hook(feature_model) as saliency_hook:
             for data in test_dataloader:
-                result = model(return_loss=False, rescale=True, **data)
+                with torch.no_grad():
+                    result = model(return_loss=False, rescale=True, **data)
                 eval_predictions.extend(result)
             saliency_maps = saliency_hook.records

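The stability fix here is to run the forward pass under torch.no_grad() while the saliency hook collects its records, so autograd does not build a graph or keep activations alive across the whole dataloader loop. A minimal, self-contained sketch of the pattern, using a generic nn.Linear stand-in rather than the OTX detector:

import torch
import torch.nn as nn

model = nn.Linear(16, 4).eval()  # stand-in for the detector
batch = torch.randn(2, 16)

# Without no_grad, each forward records a graph whose activations stay
# referenced until the output is freed; over a long explain loop this can
# exhaust memory. Inside no_grad, no graph is recorded at all.
with torch.no_grad():
    out = model(batch)

print(out.requires_grad)  # False: nothing accumulates across iterations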
