1010from functools import partial
1111from typing import TYPE_CHECKING
1212
13- from datumaro .components .annotation import Annotation , Bbox , Ellipse , Polygon
13+ from datumaro .components .annotation import Annotation , AnnotationType , Bbox , Ellipse , Points , Polygon
1414from datumaro .components .dataset import Dataset as DmDataset
1515
1616from otx .types .task import OTXTaskType
1919 from datumaro .components .dataset_base import DatasetItem
2020
2121
22+ def get_labels (dataset : DmDataset , task : OTXTaskType ) -> list [str ]:
23+ """Get the labels from the dataset."""
24+ # label is funky from arrow dataset
25+ if task == OTXTaskType .KEYPOINT_DETECTION :
26+ return dataset .categories ()[AnnotationType .points ][0 ].labels
27+ return dataset .categories ()[AnnotationType .label ]
28+
29+
2230def pre_filtering (
2331 dataset : DmDataset ,
2432 data_format : str ,
@@ -42,7 +50,16 @@ def pre_filtering(
4250 used_background_items = set ()
4351 msg = f"There are empty annotation items in train set, Of these, only { unannotated_items_ratio * 100 } % are used."
4452 warnings .warn (msg , stacklevel = 2 )
45- dataset = DmDataset .filter (dataset , partial (is_valid_anno_for_task , task = task ), filter_annotations = True )
53+
54+ labels = get_labels (dataset , task )
55+
56+ dataset = DmDataset .filter (
57+ dataset ,
58+ partial (is_valid_anno_for_task , task = task , labels = labels ),
59+ filter_annotations = True ,
60+ )
61+ if task == OTXTaskType .KEYPOINT_DETECTION :
62+ return dataset
4663 dataset = remove_unused_labels (dataset , data_format , ignore_index )
4764 if unannotated_items_ratio > 0 :
4865 empty_items = [
@@ -61,7 +78,7 @@ def pre_filtering(
6178 )
6279
6380
64- def is_valid_annot (item : DatasetItem , annotation : Annotation ) -> bool : # noqa: ARG001
81+ def is_valid_annot (item : DatasetItem , annotation : Annotation , labels : list [ str ] ) -> bool : # noqa: ARG001
6582 """Return whether DatasetItem's annotation is valid."""
6683 if isinstance (annotation , Bbox ):
6784 x1 , y1 , x2 , y2 = annotation .points
@@ -79,28 +96,45 @@ def is_valid_annot(item: DatasetItem, annotation: Annotation) -> bool: # noqa:
7996 return True
8097 msg = "There are invalid polygon, they will be filtered out before training."
8198 return False
99+ if isinstance (annotation , Points ):
100+ # For keypoint detection, num of (x, y) points should be equal to num of labels
101+ if len (annotation .points ) == 0 :
102+ msg = "There are invalid points, they will be filtered out before training."
103+ warnings .warn (msg , stacklevel = 2 )
104+ return False
105+ return len (annotation .points ) // 2 == len (labels )
106+
82107 return True
83108
84109
85- def is_valid_anno_for_task (item : DatasetItem , annotation : Annotation , task : OTXTaskType ) -> bool :
110+ def is_valid_anno_for_task (
111+ item : DatasetItem ,
112+ annotation : Annotation ,
113+ task : OTXTaskType ,
114+ labels : list [str ],
115+ ) -> bool :
86116 """Return whether DatasetItem's annotation is valid for a specific task.
87117
88118 Args:
89119 item (DatasetItem): The item to be checked.
90120 annotation (Annotation): The annotation to be checked.
91121 task (OTXTaskType): The task type of the dataset.
122+ labels (list[str]): The labels of the dataset.
92123
93124 Returns:
94125 bool: True if the annotation is valid for the task, False otherwise.
95126 """
96127 if task == OTXTaskType .DETECTION :
97- return isinstance (annotation , Bbox ) and is_valid_annot (item , annotation )
128+ return isinstance (annotation , Bbox ) and is_valid_annot (item , annotation , labels )
98129
99130 # Rotated detection is a subset of instance segmentation
100131 if task in [OTXTaskType .INSTANCE_SEGMENTATION , OTXTaskType .ROTATED_DETECTION ]:
101- return isinstance (annotation , (Polygon , Bbox , Ellipse )) and is_valid_annot (item , annotation )
132+ return isinstance (annotation , (Polygon , Bbox , Ellipse )) and is_valid_annot (item , annotation , labels )
133+
134+ if task == OTXTaskType .KEYPOINT_DETECTION :
135+ return isinstance (annotation , Points ) and is_valid_annot (item , annotation , labels )
102136
103- return is_valid_annot (item , annotation )
137+ return is_valid_annot (item , annotation , labels )
104138
105139
106140def remove_unused_labels (
0 commit comments