
Commit 6fb59db

Bug fix: difference between "otx eval" score and evaluation while training (#2307)
* sort labels in right order
* fix data adapter
* refine bugfix
* refine bugfix
* implement unit test
* align with pre-commit
1 parent b3fc87d commit 6fb59db
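
The fix centers on label ordering: the segmentation task previously built its `_label_dictionary` from `sorted(self._labels)`, which can disagree with the order of the model's output classes, so "otx eval" and in-training evaluation could end up scoring labels against different channels. A minimal sketch of the mismatch, using plain strings in place of OTX `LabelEntity` objects (an illustrative assumption, not the library's types):

    # Plain-string sketch of the label-order mismatch (illustrative only).
    model_class_order = [f"class_{i + 1}" for i in range(11)]  # order the model was trained with

    old_map = dict(enumerate(sorted(model_class_order), 1))  # old behaviour: sort first
    new_map = dict(enumerate(model_class_order, 1))          # new behaviour: keep model order

    print(old_map[2])  # "class_10" -- lexicographic sort shifts the assignment
    print(new_map[2])  # "class_2"  -- matches the model's output channel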

File tree

6 files changed: +23, -10 lines

src/otx/algorithms/segmentation/adapters/openvino/task.py

Lines changed: 0 additions & 2 deletions
@@ -260,8 +260,6 @@ def add_prediction(
 
         if dump_soft_prediction:
             for label_index, label in self._label_dictionary.items():
-                if label_index == 0:
-                    continue
                 current_label_soft_prediction = soft_prediction[:, :, label_index]
                 if process_soft_prediction:
                     current_label_soft_prediction = get_activation_map(current_label_soft_prediction)
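
The removed `label_index == 0` guard looks like dead code here: assuming this task builds `self._label_dictionary` the same way as the training task below, with `dict(enumerate(..., 1))`, the keys start at 1 and the guard can never fire. A quick sketch of that key layout (plain strings stand in for label entities, purely for illustration):

    # dict(enumerate(labels, 1)) produces 1-based keys, so a
    # "label_index == 0" check never matches (illustrative sketch).
    labels = ["road", "car"]  # hypothetical label names
    label_dictionary = dict(enumerate(labels, 1))
    print(label_dictionary)       # {1: 'road', 2: 'car'}
    print(0 in label_dictionary)  # False -- the removed guard could not trigger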

src/otx/algorithms/segmentation/task.py

Lines changed: 1 addition & 3 deletions
@@ -91,7 +91,7 @@ def __init__(self, task_environment: TaskEnvironment, output_path: Optional[str]
         self._model_name = task_environment.model_template.name
         self._train_type = self._hyperparams.algo_backend.train_type
         self.metric = "mDice"
-        self._label_dictionary = dict(enumerate(sorted(self._labels), 1))
+        self._label_dictionary = dict(enumerate(self._labels, 1))  # It should have same order as model class order
 
         self._model_dir = os.path.join(
             os.path.abspath(os.path.dirname(self._task_environment.model_template.model_template_path)),
@@ -287,8 +287,6 @@ def _add_predictions_to_dataset(self, prediction_results, dataset, dump_soft_pre
 
         if dump_soft_prediction:
             for label_index, label in self._label_dictionary.items():
-                if label_index == 0:
-                    continue
                 current_label_soft_prediction = soft_prediction[:, :, label_index]
                 if process_soft_prediction:
                     current_label_soft_prediction = get_activation_map(current_label_soft_prediction)
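
This `_label_dictionary` is the mapping consumed in `_add_predictions_to_dataset` above: each key indexes a channel of `soft_prediction`, so if the keys stop following the model's class order, each label's soft map is read from another class's channel. A rough sketch of that lookup with a toy array (NumPy, shapes and class names are assumptions for illustration):

    import numpy as np

    # Toy soft prediction: H x W x C, one channel per model class,
    # with the background assumed to sit in channel 0 (illustrative only).
    soft_prediction = np.zeros((4, 4, 3))
    soft_prediction[:, :, 1] = 0.9  # pretend the model is confident about "car"

    model_class_order = ["background", "car", "person"]  # hypothetical order
    label_dictionary = dict(enumerate(model_class_order[1:], 1))  # {1: 'car', 2: 'person'}

    for label_index, label in label_dictionary.items():
        channel = soft_prediction[:, :, label_index]
        print(label, float(channel.max()))  # car -> 0.9, person -> 0.0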

src/otx/api/usecases/evaluation/dice.py

Lines changed: 2 additions & 1 deletion
@@ -114,7 +114,6 @@ def __compute_dice_averaged_over_pixels(
         resultset_labels = set(resultset.prediction_dataset.get_labels() + resultset.ground_truth_dataset.get_labels())
         model_labels = set(resultset.model.configuration.get_label_schema().get_labels(include_empty=False))
         labels = sorted(resultset_labels.intersection(model_labels))
-        labels_map = {label: i + 1 for i, label in enumerate(labels)}
         hard_predictions = []
         hard_references = []
         for prediction_item, reference_item in zip(
@@ -126,6 +125,8 @@
            except:
                # when item consists of masks with Image properties
                # TODO (sungchul): how to add condition to check if polygon or mask?
+               labels_map = {label: i + 1 for i, label in enumerate(labels)}
+
                def combine_masks(annotations):
                    combined_mask = None
                    for annotation in annotations:
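
The `labels_map` comprehension itself is unchanged; it is only moved so the mapping is built in the branch that actually combines per-label masks, right before `combine_masks` needs it. For reference, it assigns each label a 1-based index in sorted order (sketched with plain strings in place of label objects, an illustrative assumption):

    # Sketch of the labels_map comprehension with plain strings
    # standing in for label objects (illustrative assumption).
    labels = sorted(["person", "car", "road"])
    labels_map = {label: i + 1 for i, label in enumerate(labels)}
    print(labels_map)  # {'car': 1, 'person': 2, 'road': 3}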

src/otx/core/data/adapter/segmentation_dataset_adapter.py

Lines changed: 2 additions & 3 deletions
@@ -53,10 +53,9 @@ def get_otx_dataset(self) -> DatasetEntity:
         self.updated_label_id: Dict[int, int] = {}
 
         if hasattr(self, "data_type_candidates"):
-            if self.data_type_candidates[0] == "voc":
+            if "voc" in self.data_type_candidates[0]:
                 self.set_voc_labels()
-
-            if self.data_type_candidates[0] == "common_semantic_segmentation":
+            elif self.data_type_candidates[0] == "common_semantic_segmentation":
                 self.set_common_labels()
 
         else:
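
The adapter now matches VOC by substring rather than exact equality, presumably so format names that merely contain "voc" (for example a more specific VOC variant reported by Datumaro) still route to `set_voc_labels`, and the second check becomes an `elif`. A tiny sketch of the difference, with a hypothetical candidate list (the exact format strings are an assumption):

    data_type_candidates = ["voc_segmentation"]  # hypothetical Datumaro format name

    print(data_type_candidates[0] == "voc")   # False -- old exact match misses the variant
    print("voc" in data_type_candidates[0])   # True  -- new substring match catches it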

tests/unit/algorithms/segmentation/adapters/mmseg/datasets/test_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def test_classes_sorted(self, mocker) -> None:
         self.otx_dataset: DatasetEntity = DatasetEntity(items=[dataset_item()])
         self.pipeline: list[dict] = [{"type": "LoadImageFromOTXDataset", "to_float32": True}]
         self.classes: list[str] = [f"class_{i+1}" for i in range(11)]
-        labels_entities = [label_entity(name, i) for i, name in enumerate(self.classes)]
+        labels_entities = [label_entity(name, str(i)) for i, name in enumerate(self.classes)]
 
         mocker.patch.object(MPASegDataset, "filter_labels", return_value=labels_entities)
 

tests/unit/algorithms/segmentation/adapters/test_otx_segmentation_task.py

Lines changed: 17 additions & 0 deletions
@@ -8,6 +8,8 @@
 import pytest
 from mmcv import ConfigDict
 
+from otx.api.entities.label import LabelEntity, Domain
+from otx.api.entities.id import ID
 from otx.algorithms.segmentation.adapters.mmseg.task import MMSegmentationTask
 from otx.algorithms.segmentation.adapters.mmseg.models.heads import otx_head_factory
 from otx.api.configuration.helper import create
@@ -67,3 +69,18 @@ def test_save_model(self, otx_model, mocker):
 
         mocker_load.assert_called_once()
         mocker_save.assert_called_once()
+
+    @e2e_pytest_unit
+    def test_label_order(self, mocker):
+        mocker.patch("otx.algorithms.segmentation.task.os")
+        mocker.patch("otx.algorithms.segmentation.task.TRAIN_TYPE_DIR_PATH")
+        mock_environemnt = mocker.MagicMock()
+
+        fake_label = []
+        for i in range(20):
+            fake_label.append(LabelEntity(name=f"class_{i}", domain=Domain.SEGMENTATION, id=ID(str(i))))
+        mock_environemnt.get_labels.return_value = fake_label
+        task = MMSegmentationTask(mock_environemnt)
+
+        for i, label_entity in task._label_dictionary.items():
+            assert label_entity.name == f"class_{i-1}"
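
The new test builds 20 fake labels and asserts that the 1-based dictionary keys line up with the original label order. Stripped of the OTX mocks, the property it checks reduces to this stand-alone sketch:

    # What test_label_order asserts, reduced to plain Python.
    fake_label_names = [f"class_{i}" for i in range(20)]
    label_dictionary = dict(enumerate(fake_label_names, 1))

    for i, name in label_dictionary.items():
        assert name == f"class_{i - 1}"  # key 1 -> class_0, key 20 -> class_19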
