Skip to content

Commit 7bdf708

Browse files
authored
Fix labels names in hierarchical config (#3879)
* Fix hierarchical config * Add exceptions handling * Add exceptions checks to other tasks * Fix black
1 parent 4bb9e1c commit 7bdf708

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

src/otx/algorithms/classification/adapters/openvino/task.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def _async_callback(self, request: Any, callback_args: tuple) -> None:
150150
result_handler(id, annotation, aux_data)
151151

152152
except Exception as e:
153+
logger.exception(e)
153154
self.callback_exceptions.append(e)
154155

155156
def predict(self, image: np.ndarray) -> Tuple[ClassificationResult, AnnotationSceneEntity]:
@@ -280,6 +281,9 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu
280281

281282
self.inferencer.await_all()
282283

284+
if self.inferencer.callback_exceptions:
285+
raise RuntimeError("Inference failed, check the exceptions log.")
286+
283287
self._avg_time_per_image = total_time / len(dataset)
284288
logger.info(f"Avg time per image: {self._avg_time_per_image} secs")
285289
logger.info(f"Total time: {total_time} secs")

src/otx/algorithms/classification/utils/cls_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,15 @@
2323
from otx.api.entities.label import LabelEntity
2424
from otx.api.entities.label_schema import LabelSchemaEntity
2525
from otx.api.serialization.label_mapper import LabelSchemaMapper
26+
from otx.api.utils.labels_utils import get_normalized_label_name
2627

2728

2829
def get_multihead_class_info(label_schema: LabelSchemaEntity): # pylint: disable=too-many-locals
2930
"""Get multihead info by label schema."""
3031
all_groups = label_schema.get_groups(include_empty=False)
3132
all_groups_str = []
3233
for g in all_groups:
33-
group_labels_str = [lbl.name for lbl in g.labels]
34+
group_labels_str = [get_normalized_label_name(lbl) for lbl in g.labels]
3435
all_groups_str.append(group_labels_str)
3536

3637
single_label_groups = [g for g in all_groups_str if len(g) == 1]
@@ -112,7 +113,7 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
112113
all_labels = ""
113114
all_label_ids = ""
114115
for lbl in label_entities:
115-
all_labels += lbl.name.replace(" ", "_") + " "
116+
all_labels += get_normalized_label_name(lbl) + " "
116117
all_label_ids += f"{lbl.id_} "
117118

118119
mapi_config[("model_info", "labels")] = all_labels.strip()
@@ -122,7 +123,9 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
122123
hierarchical_config["cls_heads_info"] = get_multihead_class_info(label_schema)
123124
hierarchical_config["label_tree_edges"] = []
124125
for edge in label_schema.label_tree.edges: # (child, parent)
125-
hierarchical_config["label_tree_edges"].append((edge[0].name, edge[1].name))
126+
hierarchical_config["label_tree_edges"].append(
127+
(get_normalized_label_name(edge[0]), get_normalized_label_name(edge[1]))
128+
)
126129

127130
mapi_config[("model_info", "hierarchical_config")] = json.dumps(hierarchical_config)
128131
return mapi_config
@@ -137,7 +140,7 @@ def get_hierarchical_label_list(hierarchical_cls_heads_info: Dict, labels: List)
137140
hierarchical_labels = []
138141
for label_str, _ in label_to_idx.items():
139142
for label_entity in labels:
140-
if label_entity.name == label_str:
143+
if get_normalized_label_name(label_entity) == label_str:
141144
hierarchical_labels.append(label_entity)
142145
break
143146
return hierarchical_labels

src/otx/algorithms/detection/adapters/openvino/task.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def _async_callback(self, request: Any, callback_args: tuple) -> None:
155155
result_handler(id, processed_prediciton, features)
156156

157157
except Exception as e:
158+
logger.exception(e)
158159
self.callback_exceptions.append(e)
159160

160161
def enqueue_prediction(self, image: np.ndarray, id: int, result_handler: Any) -> None:
@@ -557,6 +558,9 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu
557558

558559
self.inferencer.await_all()
559560

561+
if self.inferencer.callback_exceptions:
562+
raise RuntimeError("Inference failed, check the exceptions log.")
563+
560564
self._avg_time_per_image = total_time / len(dataset)
561565
logger.info(f"Avg time per image: {self._avg_time_per_image} secs")
562566
logger.info(f"Total time: {total_time} secs")

src/otx/algorithms/segmentation/adapters/openvino/task.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def _async_callback(self, request: Any, callback_args: tuple) -> None:
151151
result_handler(id, annotation, processed_prediciton.feature_vector, processed_prediciton.saliency_map)
152152

153153
except Exception as e:
154+
logger.exception(e)
154155
self.callback_exceptions.append(e)
155156

156157

@@ -254,6 +255,9 @@ def add_prediction(
254255

255256
self.inferencer.await_all()
256257

258+
if self.inferencer.callback_exceptions:
259+
raise RuntimeError("Inference failed, check the exceptions log.")
260+
257261
self._avg_time_per_image = total_time / len(dataset)
258262
logger.info(f"Avg time per image: {self._avg_time_per_image} secs")
259263
logger.info(f"Total time: {total_time} secs")

src/otx/api/utils/labels_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,8 @@ def get_empty_label(label_schema: LabelSchemaEntity) -> Optional[LabelEntity]:
1818
if empty_candidates:
1919
return empty_candidates[0]
2020
return None
21+
22+
23+
def get_normalized_label_name(label: LabelEntity) -> str:
24+
"""Gets a nomalized label name"""
25+
return label.name.replace(" ", "_")

0 commit comments

Comments
 (0)