Skip to content

Commit 87e10c7

Browse files
authored
Support Ellipse Shape for InstSeg algo (#4152)
* ellipse shape * Update changelog * update transform * update * Allow empty anno * Update todo
1 parent f4ffcbf commit 87e10c7

File tree

3 files changed

+120
-25
lines changed

3 files changed

+120
-25
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ All notable changes to this project will be documented in this file.
146146
(<https://github.com/openvinotoolkit/training_extensions/pull/4105>)
147147
- Disable tiling classifier toggle in configurable parameters
148148
(<https://github.com/openvinotoolkit/training_extensions/pull/4107>)
149+
- Fix Ellipse shapes for Instance Segmentation
150+
(<https://github.com/openvinotoolkit/training_extensions/pull/4152>)
149151

150152
## \[v2.1.0\]
151153

src/otx/core/data/dataset/instance_segmentation.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55

66
from __future__ import annotations
77

8+
import warnings
9+
from collections import defaultdict
810
from functools import partial
911
from typing import Callable
1012

1113
import numpy as np
1214
import torch
15+
from datumaro import Bbox, Ellipse, Image, Polygon
1316
from datumaro import Dataset as DmDataset
14-
from datumaro import Image, Polygon
1517
from torchvision import tv_tensors
1618

1719
from otx.core.data.entity.base import ImageInfo
@@ -42,23 +44,49 @@ def _get_item_impl(self, index: int) -> InstanceSegDataEntity | None:
4244
ignored_labels: list[int] = []
4345
img_data, img_shape, _ = self._get_img_data_and_shape(img)
4446

47+
anno_collection: dict[str, list] = defaultdict(list)
48+
for anno in item.annotations:
49+
anno_collection[anno.__class__.__name__].append(anno)
50+
4551
gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []
4652

47-
for annotation in item.annotations:
48-
if isinstance(annotation, Polygon):
49-
bbox = np.array(annotation.get_bbox(), dtype=np.float32)
53+
# TODO(Eugene): https://jira.devtools.intel.com/browse/CVS-159363
54+
# Temporary solution to handle multiple annotation types.
55+
# Ideally, we should pre-filter annotations during initialization of the dataset.
56+
if Polygon.__name__ in anno_collection: # Polygon for InstSeg has higher priority
57+
for poly in anno_collection[Polygon.__name__]:
58+
bbox = Bbox(*poly.get_bbox()).points
5059
gt_bboxes.append(bbox)
51-
gt_labels.append(annotation.label)
60+
gt_labels.append(poly.label)
5261

5362
if self.include_polygons:
54-
gt_polygons.append(annotation)
63+
gt_polygons.append(poly)
5564
else:
56-
gt_masks.append(polygon_to_bitmap([annotation], *img_shape)[0])
57-
58-
# convert xywh to xyxy format
59-
bboxes = np.array(gt_bboxes, dtype=np.float32) if gt_bboxes else np.empty((0, 4))
60-
bboxes[:, 2:] += bboxes[:, :2]
65+
gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
66+
elif Bbox.__name__ in anno_collection:
67+
bboxes = anno_collection[Bbox.__name__]
68+
gt_bboxes = [ann.points for ann in bboxes]
69+
gt_labels = [ann.label for ann in bboxes]
70+
for box in bboxes:
71+
poly = Polygon(box.as_polygon())
72+
if self.include_polygons:
73+
gt_polygons.append(poly)
74+
else:
75+
gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
76+
elif Ellipse.__name__ in anno_collection:
77+
for ellipse in anno_collection[Ellipse.__name__]:
78+
bbox = Bbox(*ellipse.get_bbox()).points
79+
gt_bboxes.append(bbox)
80+
gt_labels.append(ellipse.label)
81+
poly = Polygon(ellipse.as_polygon(num_points=10))
82+
if self.include_polygons:
83+
gt_polygons.append(poly)
84+
else:
85+
gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
86+
else:
87+
warnings.warn(f"No valid annotations found for image {item.id}!", stacklevel=2)
6188

89+
bboxes = np.stack(gt_bboxes, dtype=np.float32, axis=0) if gt_bboxes else np.empty((0, 4))
6290
masks = np.stack(gt_masks, axis=0) if gt_masks else np.zeros((0, *img_shape), dtype=bool)
6391
labels = np.array(gt_labels, dtype=np.int64)
6492

src/otx/core/data/dataset/tile.py

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
import logging as log
99
import operator
1010
import warnings
11+
from collections import defaultdict
1112
from copy import deepcopy
1213
from itertools import product
1314
from typing import TYPE_CHECKING, Callable
1415

1516
import numpy as np
1617
import shapely.geometry as sg
1718
import torch
18-
from datumaro import Bbox, DatasetItem, Image, Polygon
19+
from datumaro import Bbox, DatasetItem, Ellipse, Image, Polygon
1920
from datumaro import Dataset as DmDataset
2021
from datumaro.components.annotation import AnnotationType
2122
from datumaro.plugins.tiling import Tile
@@ -92,6 +93,7 @@ def __init__(
9293
)
9394
self._tile_size = tile_size
9495
self._tile_ann_func_map[AnnotationType.polygon] = OTXTileTransform._tile_polygon
96+
self._tile_ann_func_map[AnnotationType.ellipse] = OTXTileTransform._tile_ellipse
9597
self.with_full_img = with_full_img
9698

9799
@staticmethod
@@ -132,6 +134,45 @@ def _tile_polygon(
132134
attributes=deepcopy(ann.attributes),
133135
)
134136

137+
@staticmethod
138+
def _tile_ellipse(
139+
ann: Ellipse,
140+
roi_box: sg.Polygon,
141+
threshold_drop_ann: float = 0.8,
142+
*args, # noqa: ARG004
143+
**kwargs, # noqa: ARG004
144+
) -> Polygon | None:
145+
polygon = sg.Polygon(ann.get_points(num_points=10))
146+
147+
# NOTE: polygon may be invalid, e.g. self-intersecting
148+
if not roi_box.intersects(polygon) or not polygon.is_valid:
149+
return None
150+
151+
# NOTE: intersection may return a GeometryCollection or MultiPolygon
152+
inter = polygon.intersection(roi_box)
153+
if isinstance(inter, (sg.GeometryCollection, sg.MultiPolygon)):
154+
shapes = [(geom, geom.area) for geom in list(inter.geoms) if geom.is_valid]
155+
if not shapes:
156+
return None
157+
158+
inter, _ = max(shapes, key=operator.itemgetter(1))
159+
160+
if not isinstance(inter, sg.Polygon) and not inter.is_valid:
161+
return None
162+
163+
prop_area = inter.area / polygon.area
164+
165+
if prop_area < threshold_drop_ann:
166+
return None
167+
168+
inter = _apply_offset(inter, roi_box)
169+
170+
return Polygon(
171+
points=[p for xy in inter.exterior.coords for p in xy],
172+
attributes=deepcopy(ann.attributes),
173+
label=ann.label,
174+
)
175+
135176
def _extract_rois(self, image: Image) -> list[BboxIntCoords]:
136177
"""Extracts Tile ROIs from the given image.
137178
@@ -467,26 +508,50 @@ def _get_item_impl(self, index: int) -> TileInstSegDataEntity: # type: ignore[o
467508
img = item.media_as(Image)
468509
img_data, img_shape, _ = self._get_img_data_and_shape(img)
469510

511+
anno_collection: dict[str, list] = defaultdict(list)
512+
for anno in item.annotations:
513+
anno_collection[anno.__class__.__name__].append(anno)
514+
470515
gt_bboxes, gt_labels, gt_masks, gt_polygons = [], [], [], []
471516

472-
for annotation in item.annotations:
473-
if isinstance(annotation, Polygon):
474-
bbox = np.array(annotation.get_bbox(), dtype=np.float32)
517+
# TODO(Eugene): https://jira.devtools.intel.com/browse/CVS-159363
518+
# Temporary solution to handle multiple annotation types.
519+
# Ideally, we should pre-filter annotations during initialization of the dataset.
520+
521+
if Polygon.__name__ in anno_collection: # Polygon for InstSeg has higher priority
522+
for poly in anno_collection[Polygon.__name__]:
523+
bbox = Bbox(*poly.get_bbox()).points
475524
gt_bboxes.append(bbox)
476-
gt_labels.append(annotation.label)
525+
gt_labels.append(poly.label)
477526

478527
if self._dataset.include_polygons:
479-
gt_polygons.append(annotation)
528+
gt_polygons.append(poly)
480529
else:
481-
gt_masks.append(polygon_to_bitmap([annotation], *img_shape)[0])
482-
483-
if empty_anno := len(gt_bboxes) == 0:
484-
warnings.warn(f"Empty annotation for image {item.id}", stacklevel=2)
485-
486-
# convert xywh to xyxy format
487-
bboxes = np.empty((0, 4), dtype=np.float32) if empty_anno else np.stack(gt_bboxes, dtype=np.float32)
488-
bboxes[:, 2:] += bboxes[:, :2]
530+
gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
531+
elif Bbox.__name__ in anno_collection:
532+
boxes = anno_collection[Bbox.__name__]
533+
gt_bboxes = [ann.points for ann in boxes]
534+
gt_labels = [ann.label for ann in boxes]
535+
for box in boxes:
536+
poly = Polygon(box.as_polygon())
537+
if self._dataset.include_polygons:
538+
gt_polygons.append(poly)
539+
else:
540+
gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
541+
elif Ellipse.__name__ in anno_collection:
542+
for ellipse in anno_collection[Ellipse.__name__]:
543+
bbox = Bbox(*ellipse.get_bbox()).points
544+
gt_bboxes.append(bbox)
545+
gt_labels.append(ellipse.label)
546+
poly = Polygon(ellipse.as_polygon(num_points=10))
547+
if self._dataset.include_polygons:
548+
gt_polygons.append(poly)
549+
else:
550+
gt_masks.append(polygon_to_bitmap([poly], *img_shape)[0])
551+
else:
552+
warnings.warn(f"No valid annotations found for image {item.id}!", stacklevel=2)
489553

554+
bboxes = np.stack(gt_bboxes, dtype=np.float32) if gt_bboxes else np.empty((0, 4), dtype=np.float32)
490555
masks = np.stack(gt_masks, axis=0) if gt_masks else np.empty((0, *img_shape), dtype=bool)
491556
labels = np.array(gt_labels, dtype=np.int64)
492557

0 commit comments

Comments
 (0)