@@ -169,10 +169,8 @@ def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInf
169169 dm_label_categories (LabelCategories): the label categories of datumaro.
170170 """
171171
172- def get_exclusive_group_info (all_groups : list [Label | list [Label ]]) -> dict [str , Any ]:
172+ def get_exclusive_group_info (exclusive_groups : list [Label | list [Label ]]) -> dict [str , Any ]:
173173 """Get exclusive group information."""
174- exclusive_groups = [g for g in all_groups if len (g ) > 1 ]
175-
176174 last_logits_pos = 0
177175 num_single_label_classes = 0
178176 head_idx_to_logits_range = {}
@@ -193,12 +191,10 @@ def get_exclusive_group_info(all_groups: list[Label | list[Label]]) -> dict[str,
193191 }
194192
195193 def get_single_label_group_info (
196- all_groups : list [Label | list [Label ]],
194+ single_label_groups : list [Label | list [Label ]],
197195 num_exclusive_groups : int ,
198196 ) -> dict [str , Any ]:
199197 """Get single label group information."""
200- single_label_groups = [g for g in all_groups if len (g ) == 1 ]
201-
202198 class_to_idx = {}
203199
204200 for i , group in enumerate (single_label_groups ):
@@ -256,8 +252,13 @@ def convert_labels_if_needed(
256252 label_names = [item .name for item in dm_label_categories .items ]
257253 all_groups = convert_labels_if_needed (dm_label_categories , label_names )
258254
259- exclusive_group_info = get_exclusive_group_info (all_groups )
260- single_label_group_info = get_single_label_group_info (all_groups , exclusive_group_info ["num_multiclass_heads" ])
255+ exclusive_groups = [g for g in all_groups if len (g ) > 1 ]
256+ exclusive_group_info = get_exclusive_group_info (exclusive_groups )
257+ single_label_groups = [g for g in all_groups if len (g ) == 1 ]
258+ single_label_group_info = get_single_label_group_info (
259+ single_label_groups ,
260+ exclusive_group_info ["num_multiclass_heads" ],
261+ )
261262
262263 merged_class_to_idx = merge_class_to_idx (
263264 exclusive_group_info ["class_to_idx" ],
@@ -268,13 +269,13 @@ def convert_labels_if_needed(
268269
269270 return HLabelInfo (
270271 label_names = label_names ,
271- label_groups = all_groups ,
272+ label_groups = exclusive_groups + single_label_groups ,
272273 num_multiclass_heads = exclusive_group_info ["num_multiclass_heads" ],
273274 num_multilabel_classes = single_label_group_info ["num_multilabel_classes" ],
274275 head_idx_to_logits_range = exclusive_group_info ["head_idx_to_logits_range" ],
275276 num_single_label_classes = exclusive_group_info ["num_single_label_classes" ],
276277 class_to_group_idx = merged_class_to_idx ,
277- all_groups = all_groups ,
278+ all_groups = exclusive_groups + single_label_groups ,
278279 label_to_idx = label_to_idx ,
279280 label_tree_edges = get_label_tree_edges (dm_label_categories .items ),
280281 empty_multiclass_head_indices = [], # consider the label removing case
0 commit comments