Skip to content

Commit 3c21662

Browse files
committed
Merge branch 'jonas-lig-8150-instance-segmentation-youtubevis-format-coco-helpers' of github.com:lightly-ai/labelformat into jonas-lig-8150-instance-segmentation-youtubevis-format-coco-helpers
2 parents 87549b6 + be0c221 commit 3c21662

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from argparse import ArgumentParser
5+
from dataclasses import dataclass
6+
from typing import Iterable
7+
8+
from labelformat.model.binary_mask_segmentation import BinaryMaskSegmentation
9+
from labelformat.model.category import Category
10+
from labelformat.model.multipolygon import MultiPolygon
11+
from labelformat.model.video import Video
12+
13+
14+
@dataclass(frozen=True)
15+
class SingleInstanceSegmentationTrack:
16+
category: Category
17+
segmentations: list[MultiPolygon | BinaryMaskSegmentation | None]
18+
19+
20+
@dataclass(frozen=True)
21+
class VideoInstanceSegmentationTrack:
22+
"""
23+
The base class for a video alongside with its object detection track annotations.
24+
A video consists of N frames and M objects. Each object is defined by N boxes - one for each frame.
25+
If an object is not present on a frame, the corresponding entry is set to None.
26+
"""
27+
28+
video: Video
29+
objects: list[SingleInstanceSegmentationTrack]
30+
31+
def __post_init__(self) -> None:
32+
number_of_frames = self.video.number_of_frames
33+
34+
for obj in self.objects:
35+
if len(obj.segmentations) != number_of_frames:
36+
raise ValueError(
37+
"Length of instance segmentation track does not match the number of frames in the video."
38+
)
39+
40+
41+
class InstanceSegmentationTrackInput(ABC):
42+
@staticmethod
43+
@abstractmethod
44+
def add_cli_arguments(parser: ArgumentParser) -> None:
45+
raise NotImplementedError()
46+
47+
@abstractmethod
48+
def get_categories(self) -> Iterable[Category]:
49+
raise NotImplementedError()
50+
51+
@abstractmethod
52+
def get_videos(self) -> Iterable[Video]:
53+
raise NotImplementedError()
54+
55+
@abstractmethod
56+
def get_labels(self) -> Iterable[VideoInstanceSegmentationTrack]:
57+
raise NotImplementedError()
58+
59+
60+
class InstanceSegmentationTrackOutput(ABC):
61+
@staticmethod
62+
@abstractmethod
63+
def add_cli_arguments(parser: ArgumentParser) -> None:
64+
raise NotImplementedError()
65+
66+
def save(self, label_input: InstanceSegmentationTrackInput) -> None:
67+
raise NotImplementedError()
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from labelformat.model.category import Category
6+
from labelformat.model.instance_segmentation_track import (
7+
SingleInstanceSegmentationTrack,
8+
VideoInstanceSegmentationTrack,
9+
)
10+
from labelformat.model.multipolygon import MultiPolygon
11+
from labelformat.model.video import Video
12+
13+
14+
class TestVideoInstanceSegmentationTrack:
15+
def test_post_init__frames_equal_segmentations_length__valid(self) -> None:
16+
track_a = SingleInstanceSegmentationTrack(
17+
category=Category(id=0, name="cat"),
18+
segmentations=[
19+
MultiPolygon(polygons=[[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]]),
20+
None,
21+
],
22+
)
23+
24+
track_b = SingleInstanceSegmentationTrack(
25+
category=Category(id=1, name="dog"),
26+
segmentations=[
27+
MultiPolygon(polygons=[[(2.0, 2.0), (3.0, 2.0), (3.0, 3.0)]]),
28+
MultiPolygon(polygons=[[(4.0, 4.0), (5.0, 4.0), (5.0, 5.0)]]),
29+
],
30+
)
31+
32+
video = Video(id=0, filename="test.mov", width=1, height=1, number_of_frames=2)
33+
34+
instance_seg = VideoInstanceSegmentationTrack(
35+
video=video,
36+
objects=[track_a, track_b],
37+
)
38+
assert len(instance_seg.objects) == 2
39+
assert len(instance_seg.objects[0].segmentations) == 2
40+
41+
def test_post_init__frames_equal_segmentations_length___invalid(self) -> None:
42+
track_a = SingleInstanceSegmentationTrack(
43+
category=Category(id=0, name="cat"),
44+
segmentations=[
45+
MultiPolygon(polygons=[[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]]),
46+
None,
47+
None,
48+
],
49+
)
50+
51+
video = Video(id=0, filename="test.mov", width=1, height=1, number_of_frames=2)
52+
53+
with pytest.raises(
54+
ValueError,
55+
match="Length of instance segmentation track does not match the number of frames in the video.",
56+
):
57+
VideoInstanceSegmentationTrack(
58+
video=video,
59+
objects=[track_a],
60+
)

0 commit comments

Comments
 (0)