Skip to content

Commit 49df95c

Browse files
committed
Merge commit '325a20b1b4f61030fd5c7ebc954af85c0fd07d2b' into jonas-lig-8150-instance-segmentation-model
2 parents 3fc1180 + 325a20b commit 49df95c

File tree

10 files changed

+433
-54
lines changed

10 files changed

+433
-54
lines changed

.github/pull_request_template.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## What has changed and why?
2+
3+
(Delete this: Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.)
4+
5+
## How has it been tested?
6+
7+
(Delete this: Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration.)

.python-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.7.16
1+
3.8

src/labelformat/formats/semantic_segmentation/pascalvoc.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from __future__ import annotations
2-
31
"""Pascal VOC semantic segmentation input.
42
53
Assumptions:
@@ -8,6 +6,9 @@
86
- Masks are PNGs with pixel values equal to class IDs.
97
"""
108

9+
from __future__ import annotations
10+
11+
from argparse import ArgumentParser
1112
from collections.abc import Iterable, Mapping
1213
from dataclasses import dataclass
1314
from pathlib import Path
@@ -19,10 +20,12 @@
1920
from labelformat import utils
2021
from labelformat.model.category import Category
2122
from labelformat.model.image import Image
22-
from labelformat.model.semantic_segmentation import (
23-
SemanticSegmentationInput,
24-
SemanticSegmentationMask,
23+
from labelformat.model.instance_segmentation import (
24+
ImageInstanceSegmentation,
25+
InstanceSegmentationInput,
26+
SingleInstanceSegmentation,
2527
)
28+
from labelformat.model.semantic_segmentation import SemanticSegmentationMask
2629

2730
"""TODO(Malte, 11/2025):
2831
Support what is already supported in LightlyTrain. https://docs.lightly.ai/train/stable/semantic_segmentation.html#data
@@ -34,12 +37,19 @@
3437

3538

3639
@dataclass
37-
class PascalVOCSemanticSegmentationInput(SemanticSegmentationInput):
40+
class PascalVOCSemanticSegmentationInput(InstanceSegmentationInput):
41+
"""Pascal VOC semantic segmentation input format."""
42+
3843
_images_dir: Path
3944
_masks_dir: Path
4045
_filename_to_image: dict[str, Image]
4146
_categories: list[Category]
4247

48+
@staticmethod
def add_cli_arguments(parser: ArgumentParser) -> None:
    """Register CLI arguments for this input format (not implemented yet).

    Args:
        parser: The argument parser that would receive the arguments.

    Raises:
        NotImplementedError: Always; there is no CLI support for this
            format at the moment.
    """
    # TODO(Michal, 01/2026): Implement when needed.
    raise NotImplementedError()
52+
4353
@classmethod
4454
def from_dirs(
4555
cls,
@@ -91,7 +101,30 @@ def get_categories(self) -> Iterable[Category]:
91101
def get_images(self) -> Iterable[Image]:
92102
yield from self._filename_to_image.values()
93103

94-
def get_mask(self, image_filepath: str) -> SemanticSegmentationMask:
104+
def get_labels(self) -> Iterable[ImageInstanceSegmentation]:
    """Yield the semantic segmentation labels of every image.

    For each image returned by ``get_images()`` the mask is loaded and split
    into one binary mask per category ID that occurs in it. The objects of an
    image are ordered by ascending category ID. ``ImageInstanceSegmentation``
    is reused as the container type for convenience.
    """
    categories_by_id = {category.id: category for category in self._categories}
    for image in self.get_images():
        mask = self._get_mask(image_filepath=image.filename)
        objects = []
        # Sorting the IDs gives a deterministic object order per image.
        for category_id in sorted(mask.category_ids()):
            objects.append(
                SingleInstanceSegmentation(
                    category=categories_by_id[category_id],
                    segmentation=mask.to_binary_mask(category_id=category_id),
                )
            )
        yield ImageInstanceSegmentation(image=image, objects=objects)
126+
127+
def _get_mask(self, image_filepath: str) -> SemanticSegmentationMask:
95128
# Validate image exists in our index.
96129
image_obj = self._filename_to_image.get(image_filepath)
97130
if image_obj is None:
@@ -114,7 +147,7 @@ def get_mask(self, image_filepath: str) -> SemanticSegmentationMask:
114147
valid_class_ids={c.id for c in self._categories},
115148
)
116149

117-
return SemanticSegmentationMask(array=mask_np)
150+
return SemanticSegmentationMask.from_array(array=mask_np)
118151

119152

120153
def _validate_mask(

src/labelformat/formats/youtubevis.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ def _get_object_track_boxes(
8383
for bbox in ann["bboxes"]:
8484
if bbox is None or len(bbox) == 0:
8585
boxes.append(None)
86-
continue
87-
boxes.append(
88-
BoundingBox.from_format(
89-
bbox=[float(x) for x in bbox],
90-
format=BoundingBoxFormat.XYWH,
86+
else:
87+
boxes.append(
88+
BoundingBox.from_format(
89+
bbox=[float(x) for x in bbox],
90+
format=BoundingBoxFormat.XYWH,
91+
)
9192
)
92-
)
9393
return boxes

src/labelformat/model/binary_mask_segmentation.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,28 @@ def from_binary_mask(
4242
bounding_box=bounding_box,
4343
)
4444

45+
@classmethod
def from_rle(
    cls,
    rle_row_wise: list[int],
    width: int,
    height: int,
    bounding_box: BoundingBox | None = None,
) -> "BinaryMaskSegmentation":
    """Build a BinaryMaskSegmentation directly from a row-wise RLE.

    Args:
        rle_row_wise: Run lengths in row-wise order. The list is handed to
            the constructor as-is.
        width: Width of the mask in pixels.
        height: Height of the mask in pixels.
        bounding_box: Bounding box to attach to the mask. When omitted it is
            derived from the RLE.

    Returns:
        The new instance.
    """
    resolved_box = bounding_box
    if resolved_box is None:
        # No box supplied by the caller: derive it from the runs.
        resolved_box = _compute_bbox_from_rle(
            rle_row_wise=rle_row_wise, width=width, height=height
        )
    return cls(
        _rle_row_wise=rle_row_wise,
        width=width,
        height=height,
        bounding_box=resolved_box,
    )
66+
4567
def get_binary_mask(self) -> NDArray[np.int_]:
4668
"""
4769
Get the binary mask (2D numpy array) from the RLE format.
@@ -50,6 +72,15 @@ def get_binary_mask(self) -> NDArray[np.int_]:
5072
self._rle_row_wise, self.height, self.width
5173
)
5274

75+
def get_rle(self) -> list[int]:
    """Return the row-wise run-length encoding (RLE) of the binary mask.

    The first element is the number of 0s at the start of the mask; it is 0
    when the mask starts with a 1. No other zeros can appear.

    The stored list itself is returned, not a copy.
    """
    row_wise_runs = self._rle_row_wise
    return row_wise_runs
83+
5384

5485
class RLEDecoderEncoder:
5586
"""
@@ -112,3 +143,49 @@ def decode_column_wise_rle(
112143
decoded.extend([run_val] * count)
113144
run_val = 1 - run_val
114145
return np.array(decoded, dtype=np.int_).reshape((height, width), order="F")
146+
147+
148+
def _compute_bbox_from_rle(
    rle_row_wise: list[int], width: int, height: int
) -> BoundingBox:
    """Derive the bounding box of the '1' pixels from a row-wise RLE.

    The runs are walked once, alternating between runs of 0s (even positions)
    and runs of 1s (odd positions), so the cost is linear in the number of
    runs. ``xmax``/``ymax`` are the column/row indices of the last '1' pixel
    (inclusive).

    If the RLE contains no run of 1s, the initial sentinel values are returned
    unchanged: ``xmin=width``, ``ymin=height``, ``xmax=0``, ``ymax=0``.
    """
    # Sentinels: any real run of 1s shrinks the minima and grows the maxima.
    left, top = width, height
    right, bottom = 0, 0

    # Position of the first pixel of the current run.
    col, row = 0, 0
    for index, length in enumerate(rle_row_wise):
        # Odd positions hold runs of 1s; the encoding starts with a run of 0s.
        if index % 2 == 1:
            # Locate the last pixel of this run.
            last_col = col + length - 1
            last_row = row
            if last_col >= width:
                extra_rows, last_col = divmod(last_col, width)
                last_row += extra_rows

            top = min(top, row)
            bottom = max(bottom, last_row)
            if last_row > row:
                # A run that wraps to a later row touches both the right edge
                # of its first row and the left edge of the next one.
                left, right = 0, width - 1
            else:
                left = min(left, col)
                right = max(right, last_col)

        # Advance to the first pixel of the next run.
        col += length
        if col >= width:
            wrapped_rows, col = divmod(col, width)
            row += wrapped_rows

    return BoundingBox(xmin=left, ymin=top, xmax=right, ymax=bottom)
Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,79 @@
11
from __future__ import annotations
22

3+
from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
4+
35
"""Semantic segmentation core types and input interface.
46
"""
57

6-
from abc import ABC, abstractmethod
7-
from collections.abc import Iterable
88
from dataclasses import dataclass
99

1010
import numpy as np
1111
from numpy.typing import NDArray
1212

13-
from labelformat.model.category import Category
14-
from labelformat.model.image import Image
15-
1613

1714
@dataclass
1815
class SemanticSegmentationMask:
1916
"""Semantic segmentation mask with integer class IDs.
2017
21-
The mask is stored as a 2D numpy array of integer class IDs with shape (H, W).
18+
For internal purposes only, interface might change between minor versions!
2219
23-
Args:
24-
array: The 2D numpy array with integer class IDs of shape (H, W).
20+
The mask is stored as multiclass run-length encoding (RLE).
2521
"""
2622

27-
array: NDArray[np.int_]
28-
29-
def __post_init__(self) -> None:
30-
if self.array.ndim != 2:
23+
category_id_rle: list[tuple[int, int]]
24+
"""The mask as a run-length encoding (RLE) list of (category_id, run_length) tuples."""
25+
width: int
26+
"""Width of the mask in pixels."""
27+
height: int
28+
"""Height of the mask in pixels."""
29+
30+
@classmethod
def from_array(cls, array: NDArray[np.int_]) -> "SemanticSegmentationMask":
    """Create a SemanticSegmentationMask from a 2D numpy array.

    The array is read in row-major order and consecutive equal class IDs are
    collapsed into ``(category_id, run_length)`` pairs. Runs continue across
    row boundaries.

    Args:
        array: 2D array of class IDs with shape (H, W).

    Returns:
        A mask with ``width == array.shape[1]`` and
        ``height == array.shape[0]``. An array without pixels yields an empty
        RLE.

    Raises:
        ValueError: If ``array`` is not 2-dimensional.
    """
    if array.ndim != 2:
        raise ValueError("SemSegMask.array must be 2D with shape (H, W).")

    height, width = array.shape
    flat = array.ravel()

    category_id_rle: list[tuple[int, int]] = []
    if flat.size > 0:
        # A run starts at index 0 and wherever a pixel differs from its
        # predecessor. Finding the boundaries in numpy avoids a Python-level
        # loop over every pixel.
        boundaries = np.flatnonzero(flat[1:] != flat[:-1]) + 1
        run_starts = np.concatenate(([0], boundaries))
        run_lengths = np.diff(np.append(run_starts, flat.size))
        # tolist() converts numpy scalars to plain Python ints, matching the
        # declared list[tuple[int, int]] type.
        category_id_rle = list(
            zip(flat[run_starts].tolist(), run_lengths.tolist())
        )

    return cls(category_id_rle=category_id_rle, width=width, height=height)
54+
55+
def to_binary_mask(self, category_id: int) -> BinaryMaskSegmentation:
    """Extract the binary mask of a single category.

    Runs whose category equals ``category_id`` become 1s, all other runs
    become 0s, and neighbouring runs with the same binary value are merged.
    The resulting row-wise RLE starts with the length of the leading run of
    0s, which is 0 when the mask starts with the requested category.

    Args:
        category_id: The category to extract.

    Returns:
        A BinaryMaskSegmentation with the same width and height as this mask.
    """
    runs = []

    # The binary RLE always opens with a (possibly empty) run of 0s.
    current_bit = 0
    pending = 0
    for run_category, length in self.category_id_rle:
        bit = int(run_category == category_id)
        if bit != current_bit:
            # The binary value flips: close the run collected so far.
            runs.append(pending)
            current_bit = bit
            pending = length
        else:
            pending += length

    # Close the final run.
    runs.append(pending)
    return BinaryMaskSegmentation.from_rle(
        rle_row_wise=runs,
        width=self.width,
        height=self.height,
    )
76+
77+
def category_ids(self) -> set[int]:
    """Return the distinct category IDs that occur in the mask."""
    present = set()
    for run in self.category_id_rle:
        # Each run is a (category_id, run_length) pair; only the ID matters.
        present.add(run[0])
    return present

0 commit comments

Comments
 (0)