
Commit 46a9395

Fix semantic segmentation annotation handling for ExtractedMask type (#4511)
* Fix tiling when polygons are given
1 parent: 4fb7e61

File tree

2 files changed (+123, -2 lines)


src/otx/data/utils/utils.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@
 import cv2
 import numpy as np
 import torch
-from datumaro.components.annotation import AnnotationType, Bbox, LabelCategories, Polygon
+from datumaro.components.annotation import AnnotationType, Bbox, ExtractedMask, LabelCategories, Polygon
 from datumaro.components.annotation import Shape as _Shape

 from otx.types import OTXTaskType

@@ -145,7 +145,7 @@ def compute_robust_dataset_statistics(
         data = dataset.get(id=idx, subset=dataset.name)
         annotations: dict[str, list] = defaultdict(list)
         for ann in data.annotations:
-            if task is OTXTaskType.SEMANTIC_SEGMENTATION:
+            if task is OTXTaskType.SEMANTIC_SEGMENTATION and isinstance(ann, ExtractedMask):
                 # Skip background class
                 if label_names and label_names[AnnotationType.label][ann.label].name == "background":
                     continue
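
For context, here is a minimal sketch of what the added isinstance(ann, ExtractedMask) guard filters. It is not part of the commit; it only relies on the datumaro constructors already used in the new unit test below, and the other annotation types are presumably handled by the task-agnostic code elsewhere in compute_robust_dataset_statistics.

    import numpy as np
    from datumaro.components.annotation import Bbox, ExtractedMask, Polygon

    # A semantic-segmentation item may carry mixed annotation types.
    annotations = [
        ExtractedMask(index_mask=np.ones((4, 4), dtype=np.uint8), index=1, label=1),
        Polygon([0, 0, 2, 0, 2, 2], label=1),
        Bbox(0, 0, 2, 2, label=0),
    ]

    # Only ExtractedMask instances pass the new guard; Polygon and Bbox
    # annotations no longer reach the mask-specific statistics branch.
    masks = [ann for ann in annotations if isinstance(ann, ExtractedMask)]
    print(len(masks))  # 1
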
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+# Copyright (C) 2025 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for compute_robust_dataset_statistics function."""
+
+from __future__ import annotations
+
+import numpy as np
+import pytest
+from datumaro import Dataset as DmDataset
+from datumaro import DatasetSubset, DatasetItem
+from datumaro.components.annotation import AnnotationType, ExtractedMask, LabelCategories, Polygon, Bbox
+from datumaro.components.media import Image
+
+from otx.data.utils.utils import compute_robust_dataset_statistics
+from otx.types import OTXTaskType
+
+
+class TestComputeRobustDatasetStatistics:
+    """Test cases for compute_robust_dataset_statistics function."""
+
+    @pytest.fixture
+    def mock_semantic_seg_dataset(self):
+        """Create a mock semantic segmentation dataset with mixed annotation types."""
+        dataset = DmDataset(media_type=Image)
+
+        # Create label categories
+        categories = LabelCategories()
+        categories.add("background")
+        categories.add("foreground")
+        dataset.categories()[AnnotationType.label] = categories
+
+        for i in range(5):
+            image = Image.from_numpy(np.zeros((100, 100, 3), dtype=np.uint8))
+
+            # ExtractedMask annotation (foreground)
+            mask = np.zeros((100, 100), dtype=np.uint8)
+            mask[20:40, 20:40] = 1
+            ann_mask = ExtractedMask(
+                index_mask=mask,
+                index=0,
+                label=1,  # foreground
+            )
+
+            # Polygon annotation (foreground)
+            polygon = Polygon([10, 10, 50, 10, 50, 50, 10, 50], label=1)
+
+            # Bbox annotation (background, should be ignored for SEMANTIC_SEGMENTATION)
+            bbox = Bbox(60, 60, 20, 20, label=0)
+
+            dataset.put(
+                DatasetItem(
+                    id=str(i),
+                    media=image,
+                    annotations=[ann_mask, polygon, bbox],
+                    subset="train",
+                )
+            )
+        return dataset
+
+    def test_compute_robust_dataset_statistics_semantic_segmentation(self, mock_semantic_seg_dataset):
+        """Test that semantic segmentation with ExtractedMask annotations is handled correctly."""
+        # Get the train subset
+        train_subset = DatasetSubset(mock_semantic_seg_dataset, "train")
+
+        # Compute statistics
+        stats = compute_robust_dataset_statistics(
+            dataset=train_subset,
+            task=OTXTaskType.SEMANTIC_SEGMENTATION,
+            max_samples=10,
+        )
+
+        # Verify the function doesn't crash and returns expected structure
+        assert isinstance(stats, dict)
+        assert "image" in stats
+        assert "annotation" in stats
+
+        image_statistics_keys = ["avg", "min", "max", "std", "robust_min", "robust_max"]
+        annotation_statistics_keys = ["avg", "min", "max", "std", "robust_min", "robust_max"]
+
+        for key in stats["image"]["height"]:
+            assert key in image_statistics_keys
+
+        for key in stats["image"]["width"]:
+            assert key in image_statistics_keys
+
+        for key in stats["annotation"]["num_per_image"]:
+            assert key in annotation_statistics_keys
+
+        for key in stats["annotation"]["size_of_shape"]:
+            assert key in annotation_statistics_keys
+
+    def test_compute_robust_dataset_statistics_empty_dataset(self):
+        """Test handling of empty dataset."""
+        empty_dataset = DmDataset(media_type=Image)
+        train_subset = DatasetSubset(empty_dataset, "train")
+
+        stats = compute_robust_dataset_statistics(
+            dataset=train_subset,
+            task=OTXTaskType.SEMANTIC_SEGMENTATION,
+        )
+
+        # Should return empty statistics
+        assert stats == {"image": {}, "annotation": {}}
+
+    def test_compute_robust_dataset_statistics_max_samples_limit(self, mock_semantic_seg_dataset):
+        """Test that max_samples parameter limits the number of processed samples."""
+        train_subset = DatasetSubset(mock_semantic_seg_dataset, "train")
+
+        # Test with max_samples=2 (should only process 2 items)
+        stats = compute_robust_dataset_statistics(
+            dataset=train_subset,
+            task=OTXTaskType.SEMANTIC_SEGMENTATION,
+            max_samples=2,
+        )
+
+        # Should still return valid statistics
+        assert isinstance(stats, dict)
+        assert "image" in stats
+        assert "annotation" in stats
