|
26 | 26 | from otx.api.utils.labels_utils import get_normalized_label_name |
27 | 27 |
|
28 | 28 |
|
29 | | -def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disable=too-many-locals |
| 29 | +def get_multihead_class_info( |
| 30 | + label_schema: LabelSchemaEntity, normalize_labels: bool = False |
| 31 | +): # pylint: disable=too-many-locals |
30 | 32 | """Get multihead info by label schema.""" |
31 | 33 | all_groups = label_schema.get_groups(include_empty=False) |
32 | 34 | all_groups_str = [] |
33 | 35 | for g in all_groups: |
34 | | - group_labels_str = [get_normalized_label_name(lbl) for lbl in g.labels] |
| 36 | + if normalize_labels: |
| 37 | + group_labels_str = [get_normalized_label_name(lbl) for lbl in g.labels] |
| 38 | + else: |
| 39 | + group_labels_str = [lbl.name for lbl in g.labels] |
35 | 40 | all_groups_str.append(group_labels_str) |
36 | 41 |
|
37 | 42 | single_label_groups = [g for g in all_groups_str if len(g) == 1] |
@@ -77,7 +82,7 @@ def get_cls_inferencer_configuration(label_schema: LabelSchemaEntity): |
77 | 82 | hierarchical = not multilabel and len(label_schema.get_groups(False)) > 1 |
78 | 83 | multihead_class_info = {} |
79 | 84 | if hierarchical: |
80 | | - multihead_class_info = get_multihead_class_info(label_schema) |
| 85 | + multihead_class_info = get_multihead_class_info(label_schema, normalize_labels=True) |
81 | 86 | return { |
82 | 87 | "multilabel": multilabel, |
83 | 88 | "hierarchical": hierarchical, |
@@ -120,7 +125,7 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c |
120 | 125 | mapi_config[("model_info", "label_ids")] = all_label_ids.strip() |
121 | 126 |
|
122 | 127 | hierarchical_config = {} |
123 | | - hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema) |
| 128 | + hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema, normalize_labels=True) |
124 | 129 | hierarchical_config["label_tree_edges"] = [] |
125 | 130 | for edge in label_schema.label_tree.edges: # (child, parent) |
126 | 131 | hierarchical_config["label_tree_edges"].append( |
|
0 commit comments