
Commit ab56cdf: Filter invalid annotation by task (#4515)

* Add task parameter to pre-filtering and enhance annotation validation logic
* fix unit test

1 parent: 17d2efb

File tree: 5 files changed (+210, -22 lines)


src/otx/data/module.py

Lines changed: 1 addition & 0 deletions
@@ -105,6 +105,7 @@ def __init__(
             dataset,
             self.data_format,
             self.unannotated_items_ratio,
+            self.task,
             ignore_index=self.ignore_index if self.task == "SEMANTIC_SEGMENTATION" else None,
         )
         if isinstance(input_size, str) and input_size == "auto":
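
The data module now forwards its task type into pre-filtering, so the task-aware validation below also applies when the dataset is built through the module. A minimal standalone sketch of the updated call, assuming a Datumaro-importable dataset on disk (the path below is a placeholder, not part of this commit):

from datumaro.components.dataset import Dataset as DmDataset

from otx.data.utils.pre_filtering import pre_filtering
from otx.types.task import OTXTaskType

# Hypothetical dataset location; any format Datumaro can import works here.
dataset = DmDataset.import_from("path/to/dataset", format="datumaro")

# With task=DETECTION, annotations that are not valid Bbox instances are dropped
# during pre-filtering instead of reaching the training pipeline.
filtered = pre_filtering(
    dataset,
    data_format="datumaro",
    unannotated_items_ratio=0.0,
    task=OTXTaskType.DETECTION,
)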

src/otx/data/utils/pre_filtering.py

Lines changed: 29 additions & 3 deletions
@@ -1,4 +1,4 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 """Pre filtering data for OTX."""
@@ -7,11 +7,14 @@
 
 import secrets
 import warnings
+from functools import partial
 from typing import TYPE_CHECKING
 
-from datumaro.components.annotation import Annotation, Bbox, Polygon
+from datumaro.components.annotation import Annotation, Bbox, Ellipse, Polygon
 from datumaro.components.dataset import Dataset as DmDataset
 
+from otx.types.task import OTXTaskType
+
 if TYPE_CHECKING:
     from datumaro.components.dataset_base import DatasetItem
 
@@ -20,6 +23,7 @@ def pre_filtering(
     dataset: DmDataset,
     data_format: str,
     unannotated_items_ratio: float,
+    task: OTXTaskType,
     ignore_index: int | None = None,
 ) -> DmDataset:
     """Pre-filtering function to filter the dataset based on certain criteria.
@@ -29,6 +33,7 @@
         data_format (str): The format of the dataset.
         unannotated_items_ratio (float): The ratio of background unannotated items to be used.
             This must be a float between 0 and 1.
+        task (OTXTaskType): The task type of the dataset.
         ignore_index (int | None, optional): The index to be used for the ignored label. Defaults to None.
 
     Returns:
@@ -37,7 +42,7 @@
     used_background_items = set()
     msg = f"There are empty annotation items in train set, Of these, only {unannotated_items_ratio*100}% are used."
     warnings.warn(msg, stacklevel=2)
-    dataset = DmDataset.filter(dataset, is_valid_annot, filter_annotations=True)
+    dataset = DmDataset.filter(dataset, partial(is_valid_anno_for_task, task=task), filter_annotations=True)
     dataset = remove_unused_labels(dataset, data_format, ignore_index)
     if unannotated_items_ratio > 0:
         empty_items = [
@@ -77,6 +82,27 @@ def is_valid_annot(item: DatasetItem, annotation: Annotation) -> bool: # noqa:
     return True
 
 
+def is_valid_anno_for_task(item: DatasetItem, annotation: Annotation, task: OTXTaskType) -> bool:
+    """Return whether DatasetItem's annotation is valid for a specific task.
+
+    Args:
+        item (DatasetItem): The item to be checked.
+        annotation (Annotation): The annotation to be checked.
+        task (OTXTaskType): The task type of the dataset.
+
+    Returns:
+        bool: True if the annotation is valid for the task, False otherwise.
+    """
+    if task == OTXTaskType.DETECTION:
+        return isinstance(annotation, Bbox) and is_valid_annot(item, annotation)
+
+    # Rotated detection is a subset of instance segmentation
+    if task in [OTXTaskType.INSTANCE_SEGMENTATION, OTXTaskType.ROTATED_DETECTION]:
+        return isinstance(annotation, (Polygon, Bbox, Ellipse)) and is_valid_annot(item, annotation)
+
+    return is_valid_annot(item, annotation)
+
+
 def remove_unused_labels(
     dataset: DmDataset,
     data_format: str,
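
Datumaro's DmDataset.filter calls its predicate with (item, annotation) when filter_annotations=True, so pre_filtering binds the extra task argument with functools.partial before handing the callable over. A small sketch of the same pattern in isolation, assuming dm_dataset is an already-loaded DmDataset (this is not code from the commit):

from functools import partial

from datumaro.components.dataset import Dataset as DmDataset

from otx.data.utils.pre_filtering import is_valid_anno_for_task
from otx.types.task import OTXTaskType

# Bind the task once; Datumaro then evaluates the predicate per (item, annotation) pair.
keep_annotation = partial(is_valid_anno_for_task, task=OTXTaskType.INSTANCE_SEGMENTATION)

# For instance segmentation (and rotated detection), only Polygon, Bbox, and Ellipse
# annotations that also pass is_valid_annot survive the filter.
dm_dataset = DmDataset.filter(dm_dataset, keep_annotation, filter_annotations=True)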

tests/unit/data/test_module.py

Lines changed: 2 additions & 0 deletions
@@ -81,11 +81,13 @@ def func(
         dataset: DmDataset,
         data_format: str,
         unannotated_items_ratio: float,
+        task: OTXTaskType,
         ignore_index: int | None,
     ) -> DmDataset:
         del data_format
         del unannotated_items_ratio
         del ignore_index
+        del task
         return dataset
 
     return mocker.patch("otx.data.module.pre_filtering", side_effect=func)

tests/unit/data/test_pre_filtering.py

Lines changed: 163 additions & 3 deletions
@@ -1,12 +1,13 @@
-# Copyright (C) 2024 Intel Corporation
+# Copyright (C) 2024-2025 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
 import pytest
-from datumaro.components.annotation import AnnotationType, Bbox, Label, Polygon
+from datumaro.components.annotation import AnnotationType, Bbox, Ellipse, Label, Polygon
 from datumaro.components.dataset import Dataset as DmDataset
 from datumaro.components.dataset_base import DatasetItem
 
-from otx.data.utils.pre_filtering import pre_filtering
+from otx.data.utils.pre_filtering import is_valid_anno_for_task, pre_filtering
+from otx.types.task import OTXTaskType
 
 
 @pytest.fixture()
@@ -80,7 +81,166 @@ def test_pre_filtering(fxt_dm_dataset_with_unannotated: DmDataset, unannotated_i
     filtered_dataset = pre_filtering(
         dataset=fxt_dm_dataset_with_unannotated,
         data_format="datumaro",
+        task=OTXTaskType.MULTI_CLASS_CLS,
         unannotated_items_ratio=unannotated_items_ratio,
     )
     assert len(filtered_dataset) == 82 + int(len(empty_items) * unannotated_items_ratio)
     assert len(filtered_dataset.categories()[AnnotationType.label]) == 3
+
+
+@pytest.fixture()
+def fxt_dataset_item() -> DatasetItem:
+    """Create a sample dataset item for testing."""
+    return DatasetItem(
+        id="test_item",
+        subset="train",
+        media=None,
+        annotations=[],
+    )
+
+
+class TestIsValidAnnoForTask:
+    """Test cases for is_valid_anno_for_task function."""
+
+    @pytest.mark.parametrize(
+        ("task", "annotation", "expected"),
+        [
+            # DETECTION task tests
+            (OTXTaskType.DETECTION, Bbox(x=0, y=0, w=10, h=10, label=0), True),
+            (OTXTaskType.DETECTION, Bbox(x=0, y=0, w=-1, h=-1, label=0), False),  # Invalid bbox
+            (OTXTaskType.DETECTION, Bbox(x=10, y=10, w=5, h=5, label=0), True),
+            (OTXTaskType.DETECTION, Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0), False),  # Wrong type
+            (OTXTaskType.DETECTION, Ellipse(x1=0, y1=0, x2=10, y2=10, label=0), False),
+            (OTXTaskType.DETECTION, Label(label=0), False),  # Wrong type
+            # INSTANCE_SEGMENTATION task tests
+            (OTXTaskType.INSTANCE_SEGMENTATION, Bbox(x=0, y=0, w=10, h=10, label=0), True),
+            (OTXTaskType.INSTANCE_SEGMENTATION, Bbox(x=0, y=0, w=-1, h=-1, label=0), False),  # Invalid bbox
+            (OTXTaskType.INSTANCE_SEGMENTATION, Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0), True),
+            (OTXTaskType.INSTANCE_SEGMENTATION, Polygon(points=[0, 0, 0, 0, 0, 0], label=0), False),  # Invalid polygon
+            (OTXTaskType.INSTANCE_SEGMENTATION, Ellipse(x1=0, y1=0, x2=10, y2=10, label=0), True),
+            (OTXTaskType.INSTANCE_SEGMENTATION, Label(label=0), False),  # Wrong type
+            # Other task types (should use default is_valid_annot behavior)
+            (OTXTaskType.MULTI_LABEL_CLS, Bbox(x=0, y=0, w=10, h=10, label=0), True),
+            (OTXTaskType.MULTI_LABEL_CLS, Bbox(x=0, y=0, w=-1, h=-1, label=0), False),  # Invalid bbox
+            (OTXTaskType.MULTI_LABEL_CLS, Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0), True),
+            (OTXTaskType.MULTI_LABEL_CLS, Polygon(points=[0, 0, 0, 0, 0, 0], label=0), False),  # Invalid polygon
+            (OTXTaskType.MULTI_LABEL_CLS, Ellipse(x1=0, y1=0, x2=10, y2=10, label=0), True),
+            (OTXTaskType.MULTI_LABEL_CLS, Label(label=0), True),  # Label is always valid
+            (OTXTaskType.SEMANTIC_SEGMENTATION, Bbox(x=0, y=0, w=10, h=10, label=0), True),
+            (OTXTaskType.SEMANTIC_SEGMENTATION, Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0), True),
+            (OTXTaskType.SEMANTIC_SEGMENTATION, Ellipse(x1=0, y1=0, x2=10, y2=10, label=0), True),
+            (OTXTaskType.SEMANTIC_SEGMENTATION, Label(label=0), True),
+            (OTXTaskType.ANOMALY, Bbox(x=0, y=0, w=10, h=10, label=0), True),
+            (OTXTaskType.ANOMALY, Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0), True),
+            (OTXTaskType.ANOMALY, Ellipse(x1=0, y1=0, x2=10, y2=10, label=0), True),
+            (OTXTaskType.ROTATED_DETECTION, Bbox(x=0, y=0, w=10, h=10, label=0), True),
+            (OTXTaskType.ROTATED_DETECTION, Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0), True),
+            (OTXTaskType.ROTATED_DETECTION, Ellipse(x1=0, y1=0, x2=10, y2=10, label=0), True),
+            (OTXTaskType.ROTATED_DETECTION, Label(label=0), False),
+        ],
+    )
+    def test_is_valid_anno_for_task(
+        self,
+        fxt_dataset_item: DatasetItem,
+        task: OTXTaskType,
+        annotation,
+        expected: bool,
+    ) -> None:
+        """Test is_valid_anno_for_task with various task types and annotations.
+
+        Args:
+            fxt_dataset_item: The dataset item to test with
+            task: The task type to test
+            annotation: The annotation to test
+            expected: Expected result (True if valid, False if invalid)
+        """
+        result = is_valid_anno_for_task(fxt_dataset_item, annotation, task)
+        assert result == expected, f"Expected {expected} for task {task} with annotation {type(annotation).__name__}"
+
+    def test_detection_task_with_valid_bbox(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test DETECTION task with valid bounding box."""
+        bbox = Bbox(x=5, y=5, w=20, h=15, label=0)
+        result = is_valid_anno_for_task(fxt_dataset_item, bbox, OTXTaskType.DETECTION)
+        assert result is True
+
+    def test_detection_task_with_invalid_bbox(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test DETECTION task with invalid bounding box (negative dimensions)."""
+        bbox = Bbox(x=10, y=10, w=-5, h=-5, label=0)
+        result = is_valid_anno_for_task(fxt_dataset_item, bbox, OTXTaskType.DETECTION)
+        assert result is False
+
+    def test_detection_task_with_zero_dimension_bbox(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test DETECTION task with zero dimension bounding box."""
+        bbox = Bbox(x=10, y=10, w=0, h=0, label=0)
+        result = is_valid_anno_for_task(fxt_dataset_item, bbox, OTXTaskType.DETECTION)
+        assert result is False
+
+    def test_detection_task_with_wrong_annotation_type(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test DETECTION task with non-bbox annotation types."""
+        polygon = Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0)
+        ellipse = Ellipse(x1=0, y1=0, x2=10, y2=10, label=0)
+        label = Label(label=0)
+
+        assert is_valid_anno_for_task(fxt_dataset_item, polygon, OTXTaskType.DETECTION) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, ellipse, OTXTaskType.DETECTION) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, label, OTXTaskType.DETECTION) is False
+
+    def test_instance_segmentation_task_with_valid_annotations(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test INSTANCE_SEGMENTATION task with valid annotation types."""
+        bbox = Bbox(x=0, y=0, w=10, h=10, label=0)
+        polygon = Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0)
+        ellipse = Ellipse(x1=0, y1=0, x2=10, y2=10, label=0)
+
+        assert is_valid_anno_for_task(fxt_dataset_item, bbox, OTXTaskType.INSTANCE_SEGMENTATION) is True
+        assert is_valid_anno_for_task(fxt_dataset_item, polygon, OTXTaskType.INSTANCE_SEGMENTATION) is True
+        assert is_valid_anno_for_task(fxt_dataset_item, ellipse, OTXTaskType.INSTANCE_SEGMENTATION) is True
+
+    def test_instance_segmentation_task_with_invalid_annotations(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test INSTANCE_SEGMENTATION task with invalid annotation types."""
+        invalid_bbox = Bbox(x=0, y=0, w=-1, h=-1, label=0)
+        invalid_polygon = Polygon(points=[0, 0, 0, 0, 0, 0], label=0)  # Degenerate polygon
+        label = Label(label=0)  # Wrong type
+
+        assert is_valid_anno_for_task(fxt_dataset_item, invalid_bbox, OTXTaskType.INSTANCE_SEGMENTATION) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, invalid_polygon, OTXTaskType.INSTANCE_SEGMENTATION) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, label, OTXTaskType.INSTANCE_SEGMENTATION) is False
+
+    def test_other_task_types_use_default_validation(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test that other task types use the default is_valid_annot behavior."""
+        valid_bbox = Bbox(x=0, y=0, w=10, h=10, label=0)
+        invalid_bbox = Bbox(x=0, y=0, w=-1, h=-1, label=0)
+        valid_polygon = Polygon(points=[0, 0, 10, 0, 10, 10, 0, 10], label=0)
+        invalid_polygon = Polygon(points=[0, 0, 0, 0, 0, 0], label=0)
+        label = Label(label=0)
+
+        # Test with CLASSIFICATION task
+        assert is_valid_anno_for_task(fxt_dataset_item, valid_bbox, OTXTaskType.MULTI_CLASS_CLS) is True
+        assert is_valid_anno_for_task(fxt_dataset_item, invalid_bbox, OTXTaskType.MULTI_CLASS_CLS) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, valid_polygon, OTXTaskType.MULTI_CLASS_CLS) is True
+        assert is_valid_anno_for_task(fxt_dataset_item, invalid_polygon, OTXTaskType.MULTI_CLASS_CLS) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, label, OTXTaskType.MULTI_CLASS_CLS) is True
+
+        # Test with SEMANTIC_SEGMENTATION task
+        assert is_valid_anno_for_task(fxt_dataset_item, valid_bbox, OTXTaskType.SEMANTIC_SEGMENTATION) is True
+        assert is_valid_anno_for_task(fxt_dataset_item, invalid_bbox, OTXTaskType.SEMANTIC_SEGMENTATION) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, valid_polygon, OTXTaskType.SEMANTIC_SEGMENTATION) is True
+        assert is_valid_anno_for_task(fxt_dataset_item, invalid_polygon, OTXTaskType.SEMANTIC_SEGMENTATION) is False
+        assert is_valid_anno_for_task(fxt_dataset_item, label, OTXTaskType.SEMANTIC_SEGMENTATION) is True
+
+    def test_edge_cases(self, fxt_dataset_item: DatasetItem) -> None:
+        """Test edge cases for annotation validation."""
+        # Very small but valid bbox
+        small_bbox = Bbox(x=0, y=0, w=0.1, h=0.1, label=0)
+        assert is_valid_anno_for_task(fxt_dataset_item, small_bbox, OTXTaskType.DETECTION) is True
+
+        # Bbox with equal coordinates (should be invalid)
+        equal_bbox = Bbox(x=5, y=5, w=0, h=0, label=0)
+        assert is_valid_anno_for_task(fxt_dataset_item, equal_bbox, OTXTaskType.DETECTION) is False
+
+        # Polygon with minimal valid area
+        minimal_polygon = Polygon(points=[0, 0, 1, 0, 1, 1, 0, 1], label=0)
+        assert is_valid_anno_for_task(fxt_dataset_item, minimal_polygon, OTXTaskType.INSTANCE_SEGMENTATION) is True
+
+        # Degenerate polygon (should be invalid)
+        degenerate_polygon = Polygon(points=[0, 0, 0, 0, 0, 0], label=0)
+        assert is_valid_anno_for_task(fxt_dataset_item, degenerate_polygon, OTXTaskType.INSTANCE_SEGMENTATION) is False
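
The new validation tests can be run in isolation; a small sketch, assuming a development install of OTX with pytest available:

import pytest

# Run only the task-aware annotation validation tests added in this commit.
pytest.main(["tests/unit/data/test_pre_filtering.py::TestIsValidAnnoForTask", "-v"])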

tests/unit/data/test_robust_dataset_statistics.py

Lines changed: 15 additions & 16 deletions
@@ -8,8 +8,8 @@
 import numpy as np
 import pytest
 from datumaro import Dataset as DmDataset
-from datumaro import DatasetSubset, DatasetItem
-from datumaro.components.annotation import AnnotationType, ExtractedMask, LabelCategories, Polygon, Bbox
+from datumaro import DatasetItem, DatasetSubset
+from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, LabelCategories, Polygon
 from datumaro.components.media import Image
 
 from otx.data.utils.utils import compute_robust_dataset_statistics
@@ -19,17 +19,17 @@
 class TestComputeRobustDatasetStatistics:
     """Test cases for compute_robust_dataset_statistics function."""
 
-    @pytest.fixture
+    @pytest.fixture()
     def mock_semantic_seg_dataset(self):
         """Create a mock semantic segmentation dataset with mixed annotation types."""
         dataset = DmDataset(media_type=Image)
-
+
         # Create label categories
         categories = LabelCategories()
         categories.add("background")
         categories.add("foreground")
         dataset.categories()[AnnotationType.label] = categories
-
+
         for i in range(5):
             image = Image.from_numpy(np.zeros((100, 100, 3), dtype=np.uint8))
 
@@ -47,35 +47,34 @@ def mock_semantic_seg_dataset(self):
 
             # Bbox annotation (background, should be ignored for SEMANTIC_SEGMENTATION)
             bbox = Bbox(60, 60, 20, 20, label=0)
-
 
             dataset.put(
                 DatasetItem(
                     id=str(i),
                     media=image,
                     annotations=[ann_mask, polygon, bbox],
                     subset="train",
-                )
+                ),
             )
         return dataset
 
     def test_compute_robust_dataset_statistics_semantic_segmentation(self, mock_semantic_seg_dataset):
         """Test that semantic segmentation with ExtractedMask annotations is handled correctly."""
         # Get the train subset
         train_subset = DatasetSubset(mock_semantic_seg_dataset, "train")
-
+
         # Compute statistics
         stats = compute_robust_dataset_statistics(
             dataset=train_subset,
             task=OTXTaskType.SEMANTIC_SEGMENTATION,
             max_samples=10,
         )
-
+
         # Verify the function doesn't crash and returns expected structure
         assert isinstance(stats, dict)
         assert "image" in stats
         assert "annotation" in stats
-
+
         image_statistics_keys = ["avg", "min", "max", "std", "robust_min", "robust_max"]
         annotation_statistics_keys = ["avg", "min", "max", "std", "robust_min", "robust_max"]
 
@@ -87,35 +86,35 @@ def test_compute_robust_dataset_statistics_semantic_segmentation(self, mock_sema
 
         for key in stats["annotation"]["num_per_image"]:
             assert key in annotation_statistics_keys
-
+
        for key in stats["annotation"]["size_of_shape"]:
             assert key in annotation_statistics_keys
 
     def test_compute_robust_dataset_statistics_empty_dataset(self):
         """Test handling of empty dataset."""
         empty_dataset = DmDataset(media_type=Image)
         train_subset = DatasetSubset(empty_dataset, "train")
-
+
         stats = compute_robust_dataset_statistics(
             dataset=train_subset,
             task=OTXTaskType.SEMANTIC_SEGMENTATION,
         )
-
+
         # Should return empty statistics
         assert stats == {"image": {}, "annotation": {}}
 
     def test_compute_robust_dataset_statistics_max_samples_limit(self, mock_semantic_seg_dataset):
         """Test that max_samples parameter limits the number of processed samples."""
         train_subset = DatasetSubset(mock_semantic_seg_dataset, "train")
-
+
         # Test with max_samples=2 (should only process 2 items)
         stats = compute_robust_dataset_statistics(
             dataset=train_subset,
             task=OTXTaskType.SEMANTIC_SEGMENTATION,
             max_samples=2,
         )
-
+
         # Should still return valid statistics
         assert isinstance(stats, dict)
         assert "image" in stats
-        assert "annotation" in stats
+        assert "annotation" in stats
