Skip to content

Commit decfdbe

Browse files
authored
Enable export of feature vectors for semantic segmentation task (#4055)
1 parent a837a1d commit decfdbe

File tree

8 files changed

+88
-29
lines changed

8 files changed

+88
-29
lines changed

src/otx/algo/segmentation/huggingface_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,8 @@ def _exporter(self) -> OTXModelExporter:
162162

163163
def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
164164
"""Model forward function used for the model tracing during model exportation."""
165+
if self.explain_mode:
166+
msg = "Explain mode is not supported for this model."
167+
raise NotImplementedError(msg)
168+
165169
return self.model(image)

src/otx/algo/segmentation/litehrnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def _exporter(self) -> OTXModelExporter:
8181
swap_rgb=False,
8282
via_onnx=False,
8383
onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK},
84-
output_names=None,
84+
output_names=["preds", "feature_vector"] if self.explain_mode else None,
8585
)
8686

8787
@property

src/otx/algo/segmentation/segmentors/base_model.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import torch.nn.functional as f
1111
from torch import Tensor, nn
1212

13+
from otx.algo.explain.explain_algo import feature_vector_fn
14+
1315
if TYPE_CHECKING:
1416
from otx.core.data.entity.base import ImageInfo
1517

@@ -58,7 +60,7 @@ def forward(
5860
- If mode is "predict", returns the predicted outputs.
5961
- Otherwise, returns the model outputs after interpolation.
6062
"""
61-
outputs = self.extract_features(inputs)
63+
enc_feats, outputs = self.extract_features(inputs)
6264
outputs = f.interpolate(outputs, size=inputs.size()[2:], mode="bilinear", align_corners=True)
6365

6466
if mode == "tensor":
@@ -76,12 +78,19 @@ def forward(
7678
if mode == "predict":
7779
return outputs.argmax(dim=1)
7880

81+
if mode == "explain":
82+
feature_vector = feature_vector_fn(enc_feats)
83+
return {
84+
"preds": outputs,
85+
"feature_vector": feature_vector,
86+
}
87+
7988
return outputs
8089

81-
def extract_features(self, inputs: Tensor) -> Tensor:
90+
def extract_features(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
8291
"""Extract features from the backbone and head."""
8392
enc_feats = self.backbone(inputs)
84-
return self.decode_head(enc_feats)
93+
return enc_feats, self.decode_head(enc_feats)
8594

8695
def calculate_loss(
8796
self,

src/otx/core/model/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def __init__(
124124
self.input_size = input_size
125125
self.classification_layers: dict[str, dict[str, Any]] = {}
126126
self.model = self._create_model()
127-
self._explain_mode = False
128127
self.optimizer_callable = ensure_callable(optimizer)
129128
self.scheduler_callable = ensure_callable(scheduler)
130129
self.metric_callable = ensure_callable(metric)

src/otx/core/model/segmentation.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,12 @@ def _build_model(self) -> nn.Module:
119119
"""
120120

121121
def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
122-
mode = "loss" if self.training else "predict"
122+
if self.training:
123+
mode = "loss"
124+
elif self.explain_mode:
125+
mode = "explain"
126+
else:
127+
mode = "predict"
123128

124129
if self.train_type == OTXTrainType.SEMI_SUPERVISED and mode == "loss":
125130
if not isinstance(entity, dict):
@@ -155,6 +160,16 @@ def _customize_outputs(
155160
losses[k] = v
156161
return losses
157162

163+
if self.explain_mode:
164+
return SegBatchPredEntity(
165+
batch_size=len(outputs["preds"]),
166+
images=inputs.images,
167+
imgs_info=inputs.imgs_info,
168+
scores=[],
169+
masks=outputs["preds"],
170+
feature_vector=outputs["feature_vector"],
171+
)
172+
158173
return SegBatchPredEntity(
159174
batch_size=len(outputs),
160175
images=inputs.images,
@@ -199,14 +214,24 @@ def _exporter(self) -> OTXModelExporter:
199214
swap_rgb=False,
200215
via_onnx=False,
201216
onnx_export_configuration=None,
202-
output_names=None,
217+
output_names=["preds", "feature_vector"] if self.explain_mode else None,
203218
)
204219

205220
def _convert_pred_entity_to_compute_metric(
206221
self,
207222
preds: SegBatchPredEntity,
208223
inputs: SegBatchDataEntity,
209224
) -> MetricInput:
225+
"""Convert prediction and input entities to a format suitable for metric computation.
226+
227+
Args:
228+
preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
229+
inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.
230+
231+
Returns:
232+
MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
233+
corresponding to the predicted and target masks for metric evaluation.
234+
"""
210235
return [
211236
{
212237
"preds": pred_mask,
@@ -228,8 +253,26 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:
228253

229254
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
230255
"""Model forward function used for the model tracing during model exportation."""
231-
raw_outputs = self.model(inputs=image, mode="tensor")
232-
return torch.softmax(raw_outputs, dim=1)
256+
if self.explain_mode:
257+
outputs = self.model(inputs=image, mode="explain")
258+
outputs["preds"] = torch.softmax(outputs["preds"], dim=1)
259+
return outputs
260+
261+
outputs = self.model(inputs=image, mode="tensor")
262+
return torch.softmax(outputs, dim=1)
263+
264+
def forward_explain(self, inputs: SegBatchDataEntity) -> SegBatchPredEntity:
265+
"""Model forward explain function."""
266+
outputs = self.model(inputs=inputs.images, mode="explain")
267+
268+
return SegBatchPredEntity(
269+
batch_size=len(outputs["preds"]),
270+
images=inputs.images,
271+
imgs_info=inputs.imgs_info,
272+
scores=[],
273+
masks=outputs["preds"],
274+
feature_vector=outputs["feature_vector"],
275+
)
233276

234277
def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
235278
"""Returns a dummy input for semantic segmentation model."""
@@ -308,32 +351,34 @@ def _customize_outputs(
308351
outputs: list[ImageResultWithSoftPrediction],
309352
inputs: SegBatchDataEntity,
310353
) -> SegBatchPredEntity | OTXBatchLossEntity:
311-
if outputs and outputs[0].saliency_map.size != 1:
312-
predicted_s_maps = [out.saliency_map for out in outputs]
313-
predicted_f_vectors = [out.feature_vector for out in outputs]
314-
return SegBatchPredEntity(
315-
batch_size=len(outputs),
316-
images=inputs.images,
317-
imgs_info=inputs.imgs_info,
318-
scores=[],
319-
masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
320-
saliency_map=predicted_s_maps,
321-
feature_vector=predicted_f_vectors,
322-
)
323-
354+
masks = [tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs]
355+
predicted_f_vectors = (
356+
[out.feature_vector for out in outputs] if outputs and outputs[0].feature_vector.size != 1 else []
357+
)
324358
return SegBatchPredEntity(
325359
batch_size=len(outputs),
326360
images=inputs.images,
327361
imgs_info=inputs.imgs_info,
328362
scores=[],
329-
masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
363+
masks=masks,
364+
feature_vector=predicted_f_vectors,
330365
)
331366

332367
def _convert_pred_entity_to_compute_metric(
333368
self,
334369
preds: SegBatchPredEntity,
335370
inputs: SegBatchDataEntity,
336371
) -> MetricInput:
372+
"""Convert prediction and input entities to a format suitable for metric computation.
373+
374+
Args:
375+
preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
376+
inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.
377+
378+
Returns:
379+
MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
380+
corresponding to the predicted and target masks for metric evaluation.
381+
"""
337382
return [
338383
{
339384
"preds": pred_mask,

tests/e2e/cli/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ def test_otx_e2e_cli(
220220
# 5) otx export with XAI
221221
if "instance_segmentation/rtmdet_inst_tiny" in recipe:
222222
return
223-
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
224-
return # Supported only for classification, detection and instance segmentation task.
223+
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]):
224+
return # Supported only for classification, detection and segmentation tasks.
225225

226226
unsupported_models = ["dino", "rtdetr"]
227227
if any(model in model_name for model in unsupported_models):

tests/integration/cli/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,8 @@ def test_otx_e2e(
242242
# 5) otx export with XAI
243243
if "instance_segmentation/rtmdet_inst_tiny" in recipe:
244244
return
245-
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
246-
return # Supported only for classification, detection and instance segmentation task.
245+
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]):
246+
return # Supported only for classification, detection and segmentation tasks.
247247

248248
if "dino" in model_name:
249249
return # DINO is not supported.

tests/unit/algo/segmentation/segmentors/test_base_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,10 @@ def test_forward_returns_prediction(self, model, inputs):
4343
def test_extract_features(self, model, inputs):
4444
images = inputs[0]
4545
features = model.extract_features(images)
46-
assert isinstance(features, torch.Tensor)
47-
assert features.shape == (1, 2, 256, 256)
46+
assert isinstance(features, tuple)
47+
assert isinstance(features[0], torch.Tensor)
48+
assert isinstance(features[1], torch.Tensor)
49+
assert features[1].shape == (1, 2, 256, 256)
4850

4951
def test_calculate_loss(self, model, inputs):
5052
model.criterion.name = "CrossEntropyLoss"

0 commit comments

Comments
 (0)