Skip to content

Commit 1f79c94

Browse files
Change PascalVOC input interface to InstanceSegmentationInput
1 parent a8ebbf1 commit 1f79c94

File tree

4 files changed

+132
-34
lines changed

4 files changed

+132
-34
lines changed

src/labelformat/formats/semantic_segmentation/pascalvoc.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
"""Pascal VOC semantic segmentation input.
42
53
Assumptions:
@@ -8,6 +6,9 @@
86
- Masks are PNGs with pixel values equal to class IDs.
97
"""
108

9+
from __future__ import annotations
10+
11+
from argparse import ArgumentParser
1112
from collections.abc import Iterable, Mapping
1213
from dataclasses import dataclass
1314
from pathlib import Path
@@ -19,8 +20,12 @@
1920
from labelformat import utils
2021
from labelformat.model.category import Category
2122
from labelformat.model.image import Image
23+
from labelformat.model.instance_segmentation import (
24+
ImageInstanceSegmentation,
25+
InstanceSegmentationInput,
26+
SingleInstanceSegmentation,
27+
)
2228
from labelformat.model.semantic_segmentation import (
23-
SemanticSegmentationInput,
2429
SemanticSegmentationMask,
2530
)
2631

@@ -34,12 +39,19 @@
3439

3540

3641
@dataclass
37-
class PascalVOCSemanticSegmentationInput(SemanticSegmentationInput):
42+
class PascalVOCSemanticSegmentationInput(InstanceSegmentationInput):
43+
"""Pascal VOC semantic segmentation input format."""
44+
3845
_images_dir: Path
3946
_masks_dir: Path
4047
_filename_to_image: dict[str, Image]
4148
_categories: list[Category]
4249

50+
@staticmethod
51+
def add_cli_arguments(parser: ArgumentParser) -> None:
52+
# TODO(Michal, 01/2026): Implement when needed.
53+
raise NotImplementedError()
54+
4355
@classmethod
4456
def from_dirs(
4557
cls,
@@ -91,7 +103,30 @@ def get_categories(self) -> Iterable[Category]:
91103
def get_images(self) -> Iterable[Image]:
92104
yield from self._filename_to_image.values()
93105

94-
def get_mask(self, image_filepath: str) -> SemanticSegmentationMask:
106+
def get_labels(self) -> Iterable[ImageInstanceSegmentation]:
107+
"""Get semantic segmentation labels.
108+
109+
Yields an object per image, with one binary mask per category present in the mask.
110+
The order of objects is sorted by category ID. Reuses the ImageInstanceSegmentation
111+
as the return type for convenience.
112+
"""
113+
category_id_to_category = {c.id: c for c in self._categories}
114+
for image in self.get_images():
115+
mask = self._get_mask(image_filepath=image.filename)
116+
category_ids_in_mask = mask.category_ids()
117+
objects = [
118+
SingleInstanceSegmentation(
119+
category=category_id_to_category[cid],
120+
segmentation=mask.to_binary_mask(category_id=cid),
121+
)
122+
for cid in sorted(category_ids_in_mask)
123+
]
124+
yield ImageInstanceSegmentation(
125+
image=image,
126+
objects=objects,
127+
)
128+
129+
def _get_mask(self, image_filepath: str) -> SemanticSegmentationMask:
95130
# Validate image exists in our index.
96131
image_obj = self._filename_to_image.get(image_filepath)
97132
if image_obj is None:

src/labelformat/model/semantic_segmentation.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import List, Optional, Tuple
4-
53
from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
6-
from labelformat.model.instance_segmentation import SingleInstanceSegmentation
74

85
"""Semantic segmentation core types and input interface.
96
"""
107

11-
from abc import ABC, abstractmethod
12-
from collections.abc import Iterable
138
from dataclasses import dataclass
149

1510
import numpy as np
1611
from numpy.typing import NDArray
1712

18-
from labelformat.model.category import Category
19-
from labelformat.model.image import Image
20-
2113

2214
@dataclass
2315
class SemanticSegmentationMask:
@@ -29,7 +21,7 @@ class SemanticSegmentationMask:
2921
array: The 2D numpy array with integer class IDs of shape (H, W).
3022
"""
3123

32-
category_id_rle: List[Tuple[int, int]]
24+
category_id_rle: list[tuple[int, int]]
3325
"""The mask as a run-length encoding (RLE) list of (category_id, run_length) tuples."""
3426
width: int
3527
height: int
@@ -40,9 +32,9 @@ def from_array(cls, array: NDArray[np.int_]) -> "SemanticSegmentationMask":
4032
if array.ndim != 2:
4133
raise ValueError("SemSegMask.array must be 2D with shape (H, W).")
4234

43-
category_id_rle: List[Tuple[int, int]] = []
35+
category_id_rle: list[tuple[int, int]] = []
4436

45-
cur_cat_id: Optional[int] = None
37+
cur_cat_id: int | None = None
4638
cur_run_length = 0
4739
for cat_id in array.flatten():
4840
if cat_id == cur_cat_id:
@@ -81,19 +73,6 @@ def to_binary_mask(self, category_id: int) -> BinaryMaskSegmentation:
8173
height=self.height,
8274
)
8375

84-
85-
class SemanticSegmentationInput(ABC):
86-
87-
# TODO(Malte, 11/2025): Add a CLI interface later if needed.
88-
89-
@abstractmethod
90-
def get_categories(self) -> Iterable[Category]:
91-
raise NotImplementedError()
92-
93-
@abstractmethod
94-
def get_images(self) -> Iterable[Image]:
95-
raise NotImplementedError()
96-
97-
@abstractmethod
98-
def get_mask(self, image_filepath: str) -> SemanticSegmentationMask:
99-
raise NotImplementedError()
76+
def category_ids(self) -> set[int]:
77+
"""Get the set of category IDs present in the mask."""
78+
return {cat_id for cat_id, _ in self.category_id_rle}

tests/unit/formats/semantic_segmentation/test_pascalvoc.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pathlib import Path
55
from typing import Dict
66

7+
import cv2
78
import numpy as np
89
import pytest
910
from PIL import Image as PILImage
@@ -12,6 +13,7 @@
1213
from labelformat.formats.semantic_segmentation.pascalvoc import (
1314
PascalVOCSemanticSegmentationInput,
1415
)
16+
from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
1517
from labelformat.model.image import Image
1618
from tests.unit.test_utils import FIXTURES_DIR
1719

@@ -52,7 +54,7 @@ def test_get_mask__returns_rle_and_matches_image_length(self) -> None:
5254
)
5355

5456
for img in ds.get_images():
55-
mask = ds.get_mask(img.filename)
57+
mask = ds._get_mask(img.filename)
5658
length = sum(run_length for _, run_length in mask.category_id_rle)
5759
assert length == img.width * img.height
5860

@@ -83,7 +85,77 @@ def test_get_mask__unknown_image_raises(self) -> None:
8385
ValueError,
8486
match=r"Unknown image filepath does_not_exist\.jpg",
8587
):
86-
ds.get_mask("does_not_exist.jpg")
88+
ds._get_mask("does_not_exist.jpg")
89+
90+
def test_get_labels(self, tmp_path: Path) -> None:
91+
images_dir = tmp_path / "images"
92+
images_dir.mkdir()
93+
masks_dir = tmp_path / "masks"
94+
masks_dir.mkdir()
95+
96+
# Create a simple image and mask
97+
image0_bgr = np.full((3, 4, 3), (255, 0, 0), dtype=np.uint8)
98+
cv2.imwrite(str(images_dir / "image0.jpg"), image0_bgr)
99+
mask0 = np.array([[1, 0, 0, 0], [1, 0, 2, 2], [0, 0, 2, 0]], dtype=np.uint8)
100+
cv2.imwrite(str(masks_dir / "image0.png"), mask0)
101+
102+
# Create another image and mask
103+
image1_bgr = np.full((2, 2, 3), (0, 255, 0), dtype=np.uint8)
104+
cv2.imwrite(str(images_dir / "image1.jpg"), image1_bgr)
105+
mask1 = np.array([[1, 1], [1, 1]], dtype=np.uint8)
106+
cv2.imwrite(str(masks_dir / "image1.png"), mask1)
107+
108+
# Create input instance
109+
label_input = PascalVOCSemanticSegmentationInput.from_dirs(
110+
images_dir=images_dir,
111+
masks_dir=masks_dir,
112+
class_id_to_name={0: "a", 1: "b", 2: "c", 3: "d"},
113+
)
114+
115+
# Call get_labels
116+
labels = sorted(label_input.get_labels(), key=lambda x: x.image.filename)
117+
assert len(labels) == 2
118+
119+
# Verify first image labels
120+
assert labels[0].image.filename == "image0.jpg"
121+
objects = labels[0].objects
122+
assert len(objects) == 3
123+
assert objects[0].category.id == 0
124+
assert objects[0].category.name == "a"
125+
assert isinstance(objects[0].segmentation, BinaryMaskSegmentation)
126+
assert objects[0].segmentation.get_binary_mask().tolist() == [
127+
[0, 1, 1, 1],
128+
[0, 1, 0, 0],
129+
[1, 1, 0, 1],
130+
]
131+
assert objects[1].category.id == 1
132+
assert objects[1].category.name == "b"
133+
assert isinstance(objects[1].segmentation, BinaryMaskSegmentation)
134+
assert objects[1].segmentation.get_binary_mask().tolist() == [
135+
[1, 0, 0, 0],
136+
[1, 0, 0, 0],
137+
[0, 0, 0, 0],
138+
]
139+
assert objects[2].category.id == 2
140+
assert objects[2].category.name == "c"
141+
assert isinstance(objects[2].segmentation, BinaryMaskSegmentation)
142+
assert objects[2].segmentation.get_binary_mask().tolist() == [
143+
[0, 0, 0, 0],
144+
[0, 0, 1, 1],
145+
[0, 0, 1, 0],
146+
]
147+
148+
# Verify second image labels
149+
assert labels[1].image.filename == "image1.jpg"
150+
assert len(labels[1].objects) == 1
151+
obj = labels[1].objects[0]
152+
assert obj.category.id == 1
153+
assert obj.category.name == "b"
154+
assert isinstance(obj.segmentation, BinaryMaskSegmentation)
155+
assert obj.segmentation.get_binary_mask().tolist() == [
156+
[1, 1],
157+
[1, 1],
158+
]
87159

88160

89161
def test__validate_mask__unknown_class_value_raises() -> None:

tests/unit/model/test_semantic_segmentation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,15 @@ def test_to_binary_mask(self) -> None:
5757
[0, 0, 0, 0],
5858
[0, 0, 0, 0],
5959
]
60+
61+
def test_category_ids(self) -> None:
62+
mask = SemanticSegmentationMask.from_array(
63+
array=np.array(
64+
[
65+
[1, 1, 4],
66+
[4, 1, 1],
67+
],
68+
dtype=np.int_,
69+
)
70+
)
71+
assert mask.category_ids() == {1, 4}

0 commit comments

Comments
 (0)