@@ -229,7 +229,32 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
229229 """Get label tree edges information. Each edges represent [child, parent]."""
230230 return [[item .name , item .parent ] for item in dm_label_items if item .parent != "" ]
231231
232- all_groups = [label_group .labels for label_group in dm_label_categories .label_groups ]
232+ def convert_labels_if_needed (
233+ dm_label_categories : LabelCategories ,
234+ label_names : list [str ],
235+ ) -> list [list [str ]]:
236+ # Check if the labels need conversion and create name to ID mapping if required
237+ name_to_id_mapping = None
238+ for label_group in dm_label_categories .label_groups :
239+ if label_group .labels and label_group .labels [0 ] not in label_names :
240+ name_to_id_mapping = {
241+ attr [len ("__name__" ) :]: category .name
242+ for category in dm_label_categories .items
243+ for attr in category .attributes
244+ if attr .startswith ("__name__" )
245+ }
246+ break
247+
248+ # If mapping exists, update the labels
249+ if name_to_id_mapping :
250+ for label_group in dm_label_categories .label_groups :
251+ label_group .labels = [name_to_id_mapping .get (label , label ) for label in label_group .labels ]
252+
253+ # Retrieve all label groups after conversion
254+ return [group .labels for group in dm_label_categories .label_groups ]
255+
256+ label_names = [item .name for item in dm_label_categories .items ]
257+ all_groups = convert_labels_if_needed (dm_label_categories , label_names )
233258
234259 exclusive_group_info = get_exclusive_group_info (all_groups )
235260 single_label_group_info = get_single_label_group_info (all_groups , exclusive_group_info ["num_multiclass_heads" ])
@@ -240,7 +265,7 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
240265 )
241266
242267 return HLabelInfo (
243- label_names = [ item . name for item in dm_label_categories . items ] ,
268+ label_names = label_names ,
244269 label_groups = all_groups ,
245270 num_multiclass_heads = exclusive_group_info ["num_multiclass_heads" ],
246271 num_multilabel_classes = single_label_group_info ["num_multilabel_classes" ],
0 commit comments