Skip to content

Commit a8ebbf1

Browse files
Convert segmentation mask to binary mask (#55)
1 parent 212287b commit a8ebbf1

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/labelformat/model/semantic_segmentation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
from typing import List, Optional, Tuple
44

5+
from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
6+
from labelformat.model.instance_segmentation import SingleInstanceSegmentation
7+
58
"""Semantic segmentation core types and input interface.
69
"""
710

@@ -56,6 +59,28 @@ def from_array(cls, array: NDArray[np.int_]) -> "SemanticSegmentationMask":
5659
category_id_rle=category_id_rle, width=array.shape[1], height=array.shape[0]
5760
)
5861

62+
def to_binary_mask(self, category_id: int) -> BinaryMaskSegmentation:
63+
"""Get a binary mask for a given category ID."""
64+
binary_rle = []
65+
66+
symbol = 0
67+
run_length = 0
68+
for cat_id, cur_run_length in self.category_id_rle:
69+
cur_symbol = 1 if cat_id == category_id else 0
70+
if symbol == cur_symbol:
71+
run_length += cur_run_length
72+
else:
73+
binary_rle.append(run_length)
74+
symbol = cur_symbol
75+
run_length = cur_run_length
76+
77+
binary_rle.append(run_length)
78+
return BinaryMaskSegmentation.from_rle(
79+
rle_row_wise=binary_rle,
80+
width=self.width,
81+
height=self.height,
82+
)
83+
5984

6085
class SemanticSegmentationInput(ABC):
6186

tests/unit/model/test_semantic_segmentation.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,35 @@ def test_from_array(self) -> None:
2525
assert mask.category_id_rle == expected_rle
2626
assert mask.width == 4
2727
assert mask.height == 3
28+
29+
def test_to_binary_mask(self) -> None:
30+
mask = SemanticSegmentationMask.from_array(
31+
array=np.array(
32+
[
33+
[1, 1, 2, 2],
34+
[2, 1, 1, 1],
35+
[3, 3, 3, 3],
36+
],
37+
dtype=np.int_,
38+
)
39+
)
40+
binary_mask = mask.to_binary_mask(category_id=1)
41+
assert binary_mask.get_binary_mask().tolist() == [
42+
[1, 1, 0, 0],
43+
[0, 1, 1, 1],
44+
[0, 0, 0, 0],
45+
]
46+
47+
binary_mask = mask.to_binary_mask(category_id=2)
48+
assert binary_mask.get_binary_mask().tolist() == [
49+
[0, 0, 1, 1],
50+
[1, 0, 0, 0],
51+
[0, 0, 0, 0],
52+
]
53+
54+
binary_mask = mask.to_binary_mask(category_id=4)
55+
assert binary_mask.get_binary_mask().tolist() == [
56+
[0, 0, 0, 0],
57+
[0, 0, 0, 0],
58+
[0, 0, 0, 0],
59+
]

0 commit comments

Comments
 (0)