Skip to content

Commit c3749e3

Browse files
authored
Correct Keyerror for h-label cls in label_groups for dm_label_categories using label's id/key (#3932)
Modify label_groups for dm_label_categories with id/key of label
1 parent 53a7d9a commit c3749e3

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

src/otx/core/types/label.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,32 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
229229
"""Get label tree edges information. Each edges represent [child, parent]."""
230230
return [[item.name, item.parent] for item in dm_label_items if item.parent != ""]
231231

232-
all_groups = [label_group.labels for label_group in dm_label_categories.label_groups]
232+
def convert_labels_if_needed(
233+
dm_label_categories: LabelCategories,
234+
label_names: list[str],
235+
) -> list[list[str]]:
236+
# Check if the labels need conversion and create name to ID mapping if required
237+
name_to_id_mapping = None
238+
for label_group in dm_label_categories.label_groups:
239+
if label_group.labels and label_group.labels[0] not in label_names:
240+
name_to_id_mapping = {
241+
attr[len("__name__") :]: category.name
242+
for category in dm_label_categories.items
243+
for attr in category.attributes
244+
if attr.startswith("__name__")
245+
}
246+
break
247+
248+
# If mapping exists, update the labels
249+
if name_to_id_mapping:
250+
for label_group in dm_label_categories.label_groups:
251+
label_group.labels = [name_to_id_mapping.get(label, label) for label in label_group.labels]
252+
253+
# Retrieve all label groups after conversion
254+
return [group.labels for group in dm_label_categories.label_groups]
255+
256+
label_names = [item.name for item in dm_label_categories.items]
257+
all_groups = convert_labels_if_needed(dm_label_categories, label_names)
233258

234259
exclusive_group_info = get_exclusive_group_info(all_groups)
235260
single_label_group_info = get_single_label_group_info(all_groups, exclusive_group_info["num_multiclass_heads"])
@@ -240,7 +265,7 @@ def get_label_tree_edges(dm_label_items: list[LabelCategories]) -> list[list[str
240265
)
241266

242267
return HLabelInfo(
243-
label_names=[item.name for item in dm_label_categories.items],
268+
label_names=label_names,
244269
label_groups=all_groups,
245270
num_multiclass_heads=exclusive_group_info["num_multiclass_heads"],
246271
num_multilabel_classes=single_label_group_info["num_multilabel_classes"],

0 commit comments

Comments
 (0)