Skip to content

Commit 95c651c

Browse files
authored
Fix keypoint annotation filter logic (#4685)
1 parent 88930ca commit 95c651c

File tree

3 files changed

+196
-38
lines changed

3 files changed

+196
-38
lines changed

lib/src/otx/data/module.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,7 @@ def __init__(
9898
self.save_hyperparameters(ignore=["input_size"])
9999

100100
dataset = DmDataset.import_from(self.data_root, format=self.data_format)
101-
if self.task != OTXTaskType.H_LABEL_CLS and not (
102-
self.task == OTXTaskType.KEYPOINT_DETECTION and self.data_format == "arrow"
103-
):
101+
if self.task != OTXTaskType.H_LABEL_CLS:
104102
dataset = pre_filtering(
105103
dataset,
106104
self.data_format,

lib/src/otx/data/utils/pre_filtering.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from functools import partial
1111
from typing import TYPE_CHECKING
1212

13-
from datumaro.components.annotation import Annotation, Bbox, Ellipse, Polygon
13+
from datumaro.components.annotation import Annotation, AnnotationType, Bbox, Ellipse, Points, Polygon
1414
from datumaro.components.dataset import Dataset as DmDataset
1515

1616
from otx.types.task import OTXTaskType
@@ -19,6 +19,14 @@
1919
from datumaro.components.dataset_base import DatasetItem
2020

2121

22+
def get_labels(dataset: DmDataset, task: OTXTaskType) -> list[str]:
23+
"""Get the labels from the dataset."""
24+
# label is funky from arrow dataset
25+
if task == OTXTaskType.KEYPOINT_DETECTION:
26+
return dataset.categories()[AnnotationType.points][0].labels
27+
return dataset.categories()[AnnotationType.label]
28+
29+
2230
def pre_filtering(
2331
dataset: DmDataset,
2432
data_format: str,
@@ -42,7 +50,16 @@ def pre_filtering(
4250
used_background_items = set()
4351
msg = f"There are empty annotation items in train set, Of these, only {unannotated_items_ratio*100}% are used."
4452
warnings.warn(msg, stacklevel=2)
45-
dataset = DmDataset.filter(dataset, partial(is_valid_anno_for_task, task=task), filter_annotations=True)
53+
54+
labels = get_labels(dataset, task)
55+
56+
dataset = DmDataset.filter(
57+
dataset,
58+
partial(is_valid_anno_for_task, task=task, labels=labels),
59+
filter_annotations=True,
60+
)
61+
if task == OTXTaskType.KEYPOINT_DETECTION:
62+
return dataset
4663
dataset = remove_unused_labels(dataset, data_format, ignore_index)
4764
if unannotated_items_ratio > 0:
4865
empty_items = [
@@ -61,7 +78,7 @@ def pre_filtering(
6178
)
6279

6380

64-
def is_valid_annot(item: DatasetItem, annotation: Annotation) -> bool: # noqa: ARG001
81+
def is_valid_annot(item: DatasetItem, annotation: Annotation, labels: list[str]) -> bool: # noqa: ARG001
6582
"""Return whether DatasetItem's annotation is valid."""
6683
if isinstance(annotation, Bbox):
6784
x1, y1, x2, y2 = annotation.points
@@ -79,28 +96,45 @@ def is_valid_annot(item: DatasetItem, annotation: Annotation) -> bool: # noqa:
7996
return True
8097
msg = "There are invalid polygon, they will be filtered out before training."
8198
return False
99+
if isinstance(annotation, Points):
100+
# For keypoint detection, num of (x, y) points should be equal to num of labels
101+
if len(annotation.points) == 0:
102+
msg = "There are invalid points, they will be filtered out before training."
103+
warnings.warn(msg, stacklevel=2)
104+
return False
105+
return len(annotation.points) // 2 == len(labels)
106+
82107
return True
83108

84109

85-
def is_valid_anno_for_task(item: DatasetItem, annotation: Annotation, task: OTXTaskType) -> bool:
110+
def is_valid_anno_for_task(
111+
item: DatasetItem,
112+
annotation: Annotation,
113+
task: OTXTaskType,
114+
labels: list[str],
115+
) -> bool:
86116
"""Return whether DatasetItem's annotation is valid for a specific task.
87117
88118
Args:
89119
item (DatasetItem): The item to be checked.
90120
annotation (Annotation): The annotation to be checked.
91121
task (OTXTaskType): The task type of the dataset.
122+
labels (list[str]): The labels of the dataset.
92123
93124
Returns:
94125
bool: True if the annotation is valid for the task, False otherwise.
95126
"""
96127
if task == OTXTaskType.DETECTION:
97-
return isinstance(annotation, Bbox) and is_valid_annot(item, annotation)
128+
return isinstance(annotation, Bbox) and is_valid_annot(item, annotation, labels)
98129

99130
# Rotated detection is a subset of instance segmentation
100131
if task in [OTXTaskType.INSTANCE_SEGMENTATION, OTXTaskType.ROTATED_DETECTION]:
101-
return isinstance(annotation, (Polygon, Bbox, Ellipse)) and is_valid_annot(item, annotation)
132+
return isinstance(annotation, (Polygon, Bbox, Ellipse)) and is_valid_annot(item, annotation, labels)
133+
134+
if task == OTXTaskType.KEYPOINT_DETECTION:
135+
return isinstance(annotation, Points) and is_valid_annot(item, annotation, labels)
102136

103-
return is_valid_annot(item, annotation)
137+
return is_valid_annot(item, annotation, labels)
104138

105139

106140
def remove_unused_labels(

0 commit comments

Comments
 (0)