@@ -80,17 +80,21 @@ def _get_item_impl(self, index: int) -> MultilabelClsDataEntity | None:
8080 ignored_labels : list [int ] = [] # This should be assigned from item
8181 img_data , img_shape , _ = self ._get_img_data_and_shape (img )
8282
83- label_anns = []
83+ label_ids = set ()
8484 for ann in item .annotations :
85+ # multilabel information stored in 'multi_label_ids' attribute when the source format is arrow
86+ if "multi_label_ids" in ann .attributes :
87+ for lbl_idx in ann .attributes ["multi_label_ids" ]:
88+ label_ids .add (lbl_idx )
89+
8590 if isinstance (ann , Label ):
86- label_anns . append (ann )
91+ label_ids . add (ann . label )
8792 else :
8893 # If the annotation is not Label, it should be converted to Label.
8994 # For Chained Task: Detection (Bbox) -> Classification (Label)
9095 label = Label (label = ann .label )
91- if label not in label_anns :
92- label_anns .append (label )
93- labels = torch .as_tensor ([ann .label for ann in label_anns ])
96+ label_ids .add (label .label )
97+ labels = torch .as_tensor (list (label_ids ))
9498
9599 entity = MultilabelClsDataEntity (
96100 image = img_data ,
@@ -128,13 +132,22 @@ def __init__(self, **kwargs) -> None:
128132 self .dm_categories = self .dm_subset .categories ()[AnnotationType .label ]
129133
130134 # Hlabel classification used HLabelInfo to insert the HLabelData.
131- self .label_info = HLabelInfo .from_dm_label_groups (self .dm_categories )
135+ if self .data_format == "arrow" :
136+ # The arrow format stores label IDs as names, so a dedicated constructor is used here
137+ self .label_info = HLabelInfo .from_dm_label_groups_arrow (self .dm_categories )
138+ else :
139+ self .label_info = HLabelInfo .from_dm_label_groups (self .dm_categories )
140+
141+ self .id_to_name_mapping = dict (zip (self .label_info .label_ids , self .label_info .label_names ))
142+ self .id_to_name_mapping ["" ] = ""
143+
132144 if self .label_info .num_multiclass_heads == 0 :
133145 msg = "The number of multiclass heads should be larger than 0."
134146 raise ValueError (msg )
135147
136- for dm_item in self .dm_subset :
137- self ._add_ancestors (dm_item .annotations )
148+ if self .data_format != "arrow" :
149+ for dm_item in self .dm_subset :
150+ self ._add_ancestors (dm_item .annotations )
138151
139152 def _add_ancestors (self , label_anns : list [Label ]) -> None :
140153 """Add ancestors recursively if some label miss the ancestor information.
@@ -149,14 +162,16 @@ def _add_ancestors(self, label_anns: list[Label]) -> None:
149162 """
150163
151164 def _label_idx_to_name (idx : int ) -> str :
152- return self .label_info . label_names [idx ]
165+ return self .dm_categories [idx ]. name
153166
154167 def _label_name_to_idx (name : str ) -> int :
155168 indices = [idx for idx , val in enumerate (self .label_info .label_names ) if val == name ]
156169 return indices [0 ]
157170
158171 def _get_label_group_idx (label_name : str ) -> int :
159172 if isinstance (self .label_info , HLabelInfo ):
173+ if self .data_format == "arrow" :
174+ return self .label_info .class_to_group_idx [self .id_to_name_mapping [label_name ]][0 ]
160175 return self .label_info .class_to_group_idx [label_name ][0 ]
161176 msg = f"self.label_info should have HLabelInfo type, got { type (self .label_info )} "
162177 raise ValueError (msg )
@@ -197,17 +212,22 @@ def _get_item_impl(self, index: int) -> HlabelClsDataEntity | None:
197212 ignored_labels : list [int ] = [] # This should be assigned from item
198213 img_data , img_shape , _ = self ._get_img_data_and_shape (img )
199214
200- label_anns = []
215+ label_ids = set ()
201216 for ann in item .annotations :
217+ # In the h-cls scenario, multilabel information is stored in the 'multi_label_ids' attribute
218+ if "multi_label_ids" in ann .attributes :
219+ for lbl_idx in ann .attributes ["multi_label_ids" ]:
220+ label_ids .add (lbl_idx )
221+
202222 if isinstance (ann , Label ):
203- label_anns . append (ann )
223+ label_ids . add (ann . label )
204224 else :
205225 # If the annotation is not Label, it should be converted to Label.
206226 # For Chained Task: Detection (Bbox) -> Classification (Label)
207227 label = Label (label = ann .label )
208- if label not in label_anns :
209- label_anns . append ( label )
210- hlabel_labels = self ._convert_label_to_hlabel_format (label_anns , ignored_labels )
228+ label_ids . add ( label . label )
229+
230+ hlabel_labels = self ._convert_label_to_hlabel_format ([ Label ( label = idx ) for idx in label_ids ] , ignored_labels )
211231
212232 entity = HlabelClsDataEntity (
213233 image = img_data ,
@@ -256,18 +276,18 @@ def _convert_label_to_hlabel_format(self, label_anns: list[Label], ignored_label
256276 class_indices [i ] = - 1
257277
258278 for ann in label_anns :
259- ann_name = self .dm_categories .items [ann .label ].name
260- ann_parent = self .dm_categories .items [ann .label ].parent
279+ if self .data_format == "arrow" :
280+ # Skip unknown labels (for instance, the empty one)
281+ if self .dm_categories .items [ann .label ].name not in self .id_to_name_mapping :
282+ continue
283+ ann_name = self .id_to_name_mapping [self .dm_categories .items [ann .label ].name ]
284+ else :
285+ ann_name = self .dm_categories .items [ann .label ].name
261286 group_idx , in_group_idx = self .label_info .class_to_group_idx [ann_name ]
262- (parent_group_idx , parent_in_group_idx ) = (
263- self .label_info .class_to_group_idx [ann_parent ] if ann_parent else (None , None )
264- )
265287
266288 if group_idx < num_multiclass_heads :
267289 class_indices [group_idx ] = in_group_idx
268- if parent_group_idx is not None and parent_in_group_idx is not None :
269- class_indices [parent_group_idx ] = parent_in_group_idx
270- elif not ignored_labels or ann .label not in ignored_labels :
290+ elif ann .label not in ignored_labels :
271291 class_indices [num_multiclass_heads + in_group_idx ] = 1
272292 else :
273293 class_indices [num_multiclass_heads + in_group_idx ] = - 1
0 commit comments