Skip to content

Commit 4f9c2f1

Browse files
author
Galina Zalesskaya
authored
Fix label list order for h-label classification (#2440)
* Fix label list for h-label cls * Fix unit tests
1 parent 6b09e65 commit 4f9c2f1

File tree

5 files changed

+50
-4
lines changed

5 files changed

+50
-4
lines changed

src/otx/algorithms/classification/adapters/openvino/task.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from otx.algorithms.classification.utils import (
3737
get_cls_deploy_config,
3838
get_cls_inferencer_configuration,
39+
get_hierarchical_label_list,
3940
)
4041
from otx.algorithms.common.utils import OTXOpenVinoDataLoader
4142
from otx.algorithms.common.utils.ir import check_if_quantized
@@ -228,12 +229,18 @@ def add_prediction(id: int, predicted_scene: AnnotationSceneEntity, aux_data: tu
228229
if saliency_map is not None and repr_vector is not None:
229230
feature_vec_media = TensorEntity(name="representation_vector", numpy=repr_vector.reshape(-1))
230231
dataset_item.append_metadata_item(feature_vec_media, model=self.model)
232+
label_list = self.task_environment.get_labels()
233+
# Fix the order for hierarchical labels to adjust classes with model outputs
234+
if self.inferencer.model.hierarchical:
235+
label_list = get_hierarchical_label_list(
236+
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
237+
)
231238

232239
add_saliency_maps_to_dataset_item(
233240
dataset_item=dataset_item,
234241
saliency_map=saliency_map,
235242
model=self.model,
236-
labels=self.task_environment.get_labels(),
243+
labels=label_list,
237244
predicted_scored_labels=item_labels,
238245
explain_predicted_classes=explain_predicted_classes,
239246
process_saliency_maps=process_saliency_maps,
@@ -284,6 +291,12 @@ def explain(
284291
explain_predicted_classes = explain_parameters.explain_predicted_classes
285292

286293
dataset_size = len(dataset)
294+
label_list = self.task_environment.get_labels()
295+
# Fix the order for hierarchical labels to adjust classes with model outputs
296+
if self.inferencer.model.hierarchical:
297+
label_list = get_hierarchical_label_list(
298+
self.inferencer.model.hierarchical_info["cls_heads_info"], label_list
299+
)
287300
for i, dataset_item in enumerate(dataset, 1):
288301
predicted_scene, _, saliency_map, _, _ = self.inferencer.predict(dataset_item.numpy)
289302
if saliency_map is None:
@@ -298,7 +311,7 @@ def explain(
298311
dataset_item=dataset_item,
299312
saliency_map=saliency_map,
300313
model=self.model,
301-
labels=self.task_environment.get_labels(),
314+
labels=label_list,
302315
predicted_scored_labels=item_labels,
303316
explain_predicted_classes=explain_predicted_classes,
304317
process_saliency_maps=process_saliency_maps,

src/otx/algorithms/classification/task.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
get_cls_deploy_config,
2929
get_cls_inferencer_configuration,
3030
get_cls_model_api_configuration,
31+
get_hierarchical_label_list,
3132
)
3233
from otx.algorithms.classification.utils import (
3334
get_multihead_class_info as get_hierarchical_info,
@@ -345,6 +346,10 @@ def _add_predictions_to_dataset(
345346

346347
dataset_size = len(dataset)
347348
pos_thr = 0.5
349+
label_list = self._labels
350+
# Fix the order for hierarchical labels to adjust classes with model outputs
351+
if self._hierarchical:
352+
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
348353
for i, (dataset_item, prediction_items) in enumerate(zip(dataset, prediction_results)):
349354
prediction_item, feature_vector, saliency_map = prediction_items
350355
if any(np.isnan(prediction_item)):
@@ -373,7 +378,7 @@ def _add_predictions_to_dataset(
373378
dataset_item=dataset_item,
374379
saliency_map=saliency_map,
375380
model=self._task_environment.model,
376-
labels=self._labels,
381+
labels=label_list,
377382
predicted_scored_labels=item_labels,
378383
explain_predicted_classes=explain_predicted_classes,
379384
process_saliency_maps=process_saliency_maps,
@@ -436,13 +441,17 @@ def _add_explanations_to_dataset(
436441
):
437442
"""Loop over dataset again and assign saliency maps."""
438443
dataset_size = len(dataset)
444+
label_list = self._labels
445+
# Fix the order for hierarchical labels to adjust classes with model outputs
446+
if self._hierarchical:
447+
label_list = get_hierarchical_label_list(self._hierarchical_info, label_list)
439448
for i, (dataset_item, prediction_item, saliency_map) in enumerate(zip(dataset, predictions, saliency_maps)):
440449
item_labels = self._get_item_labels(prediction_item, pos_thr=0.5)
441450
add_saliency_maps_to_dataset_item(
442451
dataset_item=dataset_item,
443452
saliency_map=saliency_map,
444453
model=self._task_environment.model,
445-
labels=self._labels,
454+
labels=label_list,
446455
predicted_scored_labels=item_labels,
447456
explain_predicted_classes=explain_predicted_classes,
448457
process_saliency_maps=process_saliency_maps,

src/otx/algorithms/classification/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
get_cls_deploy_config,
99
get_cls_inferencer_configuration,
1010
get_cls_model_api_configuration,
11+
get_hierarchical_label_list,
1112
get_multihead_class_info,
1213
)
1314

1415
__all__ = [
16+
"get_hierarchical_label_list",
1517
"get_multihead_class_info",
1618
"get_cls_inferencer_configuration",
1719
"get_cls_deploy_config",

src/otx/algorithms/classification/utils/cls_utils.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,24 @@ def get_cls_model_api_configuration(label_schema: LabelSchemaEntity, inference_c
117117

118118
mapi_config[("model_info", "hierarchical_config")] = json.dumps(hierarchical_config)
119119
return mapi_config
120+
121+
122+
def get_hierarchical_label_list(hierarchical_info, labels):
123+
"""Return hierarchical labels list which is adjusted to model outputs classes."""
124+
hierarchical_labels = []
125+
for head_idx in range(hierarchical_info["num_multiclass_heads"]):
126+
logits_begin, logits_end = hierarchical_info["head_idx_to_logits_range"][str(head_idx)]
127+
for logit in range(0, logits_end - logits_begin):
128+
label_str = hierarchical_info["all_groups"][head_idx][logit]
129+
label_idx = hierarchical_info["label_to_idx"][label_str]
130+
hierarchical_labels.append(labels[label_idx])
131+
132+
if hierarchical_info["num_multilabel_classes"]:
133+
logits_begin = hierarchical_info["num_single_label_classes"]
134+
logits_end = len(labels)
135+
for logit_idx, logit in enumerate(range(0, logits_end - logits_begin)):
136+
label_str_idx = hierarchical_info["num_multiclass_heads"] + logit_idx
137+
label_str = hierarchical_info["all_groups"][label_str_idx][0]
138+
label_idx = hierarchical_info["label_to_idx"][label_str]
139+
hierarchical_labels.append(labels[label_idx])
140+
return hierarchical_labels

tests/unit/algorithms/classification/tasks/test_classification_openvino_task.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def test_explain(self, mocker):
182182
self.fake_input,
183183
),
184184
)
185+
self.cls_ov_task.inferencer.model.hierarchical = False
185186
updpated_dataset = self.cls_ov_task.explain(self.dataset)
186187

187188
assert updpated_dataset is not None

0 commit comments

Comments
 (0)