Skip to content

Commit 7040faf

Browse files
authored
Remove background label from RT Info for segmentation task (#4011)
* remove background from rt_info * provide another solution * fix unit test
1 parent 7744c89 commit 7040faf

File tree

5 files changed

+31
-18
lines changed

5 files changed

+31
-18
lines changed

src/otx/core/data/dataset/segmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _extract_class_mask(item: DatasetItem, img_shape: tuple[int, int], ignore_in
9898
msg = "It is not currently support an ignore index which is more than 255."
9999
raise ValueError(msg, ignore_index)
100100

101-
# fill mask with background label if we have Polygon/Ellipse annotations
101+
# fill mask with background label if we have Polygon/Ellipse/Bbox annotations
102102
fill_value = 0 if isinstance(item.annotations[0], (Ellipse, Polygon, Bbox, RotatedBbox)) else ignore_index
103103
class_mask = np.full(shape=img_shape[:2], fill_value=fill_value, dtype=np.uint8)
104104

@@ -179,9 +179,9 @@ def __init__(
179179
to_tv_image,
180180
)
181181

182-
if self.has_polygons and "background" not in [label_name.lower() for label_name in self.label_info.label_names]:
182+
if self.has_polygons:
183183
# insert background class at index 0 since polygons represent only objects
184-
self.label_info.label_names.insert(0, "background")
184+
self.label_info.label_names.insert(0, "otx_background_lbl")
185185

186186
self.label_info = SegLabelInfo(
187187
label_names=self.label_info.label_names,

src/otx/core/model/base.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,11 +1095,6 @@ def model_adapter_parameters(self) -> dict:
10951095
def _set_label_info(self, label_info: LabelInfoTypes) -> None:
10961096
"""Set this model label information."""
10971097
new_label_info = self._dispatch_label_info(label_info)
1098-
1099-
if self._label_info != new_label_info:
1100-
msg = "OVModel strictly does not allow overwrite label_info if they are different each other."
1101-
raise ValueError(msg)
1102-
11031098
self._label_info = new_label_info
11041099

11051100
def _create_label_info_from_ov_ir(self) -> LabelInfo:

src/otx/core/model/segmentation.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import copy
89
import json
910
from abc import abstractmethod
1011
from collections.abc import Sequence
@@ -165,12 +166,20 @@ def _customize_outputs(
165166
@property
166167
def _export_parameters(self) -> TaskLevelExportParameters:
167168
"""Defines parameters required to export a particular model implementation."""
169+
if self.label_info.label_names[0] == "otx_background_lbl":
170+
# remove otx background label for export
171+
modified_label_info = copy.deepcopy(self.label_info)
172+
modified_label_info.label_names.pop(0)
173+
else:
174+
modified_label_info = self.label_info
175+
168176
return super()._export_parameters.wrap(
169177
model_type="Segmentation",
170178
task_type="segmentation",
171179
return_soft_prediction=True,
172180
soft_threshold=0.5,
173181
blur_strength=-1,
182+
label_info=modified_label_info,
174183
)
175184

176185
@property

src/otx/engine/engine.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
import copy
89
import csv
910
import inspect
1011
import logging
@@ -370,14 +371,22 @@ def test(
370371
model = model_cls.load_from_checkpoint(checkpoint_path=checkpoint, **model.hparams)
371372

372373
if model.label_info != self.datamodule.label_info:
373-
msg = (
374-
"To launch a test pipeline, the label information should be same "
375-
"between the training and testing datasets. "
376-
"Please check whether you use the same dataset: "
377-
f"model.label_info={model.label_info}, "
378-
f"datamodule.label_info={self.datamodule.label_info}"
379-
)
380-
raise ValueError(msg)
374+
if (
375+
self.task == "SEMANTIC_SEGMENTATION"
376+
and "otx_background_lbl" in self.datamodule.label_info.label_names
377+
and (len(self.datamodule.label_info.label_names) - len(model.label_info.label_names) == 1)
378+
):
379+
# workaround for background label
380+
model.label_info = copy.deepcopy(self.datamodule.label_info)
381+
else:
382+
msg = (
383+
"To launch a test pipeline, the label information should be same "
384+
"between the training and testing datasets. "
385+
"Please check whether you use the same dataset: "
386+
f"model.label_info={model.label_info}, "
387+
f"datamodule.label_info={self.datamodule.label_info}"
388+
)
389+
raise ValueError(msg)
381390

382391
self._build_trainer(**kwargs)
383392

tests/unit/core/data/dataset/test_segmentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_get_item(
1919
max_refetch=3,
2020
)
2121
assert isinstance(dataset[0], SegDataEntity)
22-
assert "background" in [label_name.lower() for label_name in dataset.label_info.label_names]
22+
assert "otx_background_lbl" in [label_name.lower() for label_name in dataset.label_info.label_names]
2323

2424
def test_get_item_from_bbox_dataset(
2525
self,
@@ -33,4 +33,4 @@ def test_get_item_from_bbox_dataset(
3333
)
3434
assert isinstance(dataset[0], SegDataEntity)
3535
# OTXSegmentationDataset should add background when getting a dataset which includes only bbox annotations
36-
assert "background" in [label_name.lower() for label_name in dataset.label_info.label_names]
36+
assert "otx_background_lbl" in [label_name.lower() for label_name in dataset.label_info.label_names]

0 commit comments

Comments
 (0)