|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +from labelformat.model.category import Category |
| 6 | +from labelformat.model.instance_segmentation_track import ( |
| 7 | + SingleInstanceSegmentationTrack, |
| 8 | + VideoInstanceSegmentationTrack, |
| 9 | +) |
| 10 | +from labelformat.model.multipolygon import MultiPolygon |
| 11 | +from labelformat.model.video import Video |
| 12 | + |
| 13 | + |
| 14 | +class TestVideoInstanceSegmentationTrack: |
| 15 | + def test_post_init__frames_equal_segmentations_length__valid(self) -> None: |
| 16 | + track_a = SingleInstanceSegmentationTrack( |
| 17 | + category=Category(id=0, name="cat"), |
| 18 | + segmentations=[ |
| 19 | + MultiPolygon(polygons=[[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]]), |
| 20 | + None, |
| 21 | + ], |
| 22 | + ) |
| 23 | + |
| 24 | + track_b = SingleInstanceSegmentationTrack( |
| 25 | + category=Category(id=1, name="dog"), |
| 26 | + segmentations=[ |
| 27 | + MultiPolygon(polygons=[[(2.0, 2.0), (3.0, 2.0), (3.0, 3.0)]]), |
| 28 | + MultiPolygon(polygons=[[(4.0, 4.0), (5.0, 4.0), (5.0, 5.0)]]), |
| 29 | + ], |
| 30 | + ) |
| 31 | + |
| 32 | + video = Video(id=0, filename="test.mov", width=1, height=1, number_of_frames=2) |
| 33 | + |
| 34 | + instance_seg = VideoInstanceSegmentationTrack( |
| 35 | + video=video, |
| 36 | + objects=[track_a, track_b], |
| 37 | + ) |
| 38 | + assert len(instance_seg.objects) == 2 |
| 39 | + assert len(instance_seg.objects[0].segmentations) == 2 |
| 40 | + |
| 41 | + def test_post_init__frames_equal_segmentations_length___invalid(self) -> None: |
| 42 | + track_a = SingleInstanceSegmentationTrack( |
| 43 | + category=Category(id=0, name="cat"), |
| 44 | + segmentations=[ |
| 45 | + MultiPolygon(polygons=[[(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)]]), |
| 46 | + None, |
| 47 | + None, |
| 48 | + ], |
| 49 | + ) |
| 50 | + |
| 51 | + video = Video(id=0, filename="test.mov", width=1, height=1, number_of_frames=2) |
| 52 | + |
| 53 | + with pytest.raises( |
| 54 | + ValueError, |
| 55 | + match="Length of instance segmentation track does not match the number of frames in the video.", |
| 56 | + ): |
| 57 | + VideoInstanceSegmentationTrack( |
| 58 | + video=video, |
| 59 | + objects=[track_a], |
| 60 | + ) |
0 commit comments