diff --git a/model_api/cpp/models/src/classification_model.cpp b/model_api/cpp/models/src/classification_model.cpp index 80b7fa75..29ffad84 100644 --- a/model_api/cpp/models/src/classification_model.cpp +++ b/model_api/cpp/models/src/classification_model.cpp @@ -555,14 +555,17 @@ std::vector> ClassificationModel::inferBat return clsResults; } -HierarchicalConfig::HierarchicalConfig(const std::string& json_repr) { +HierarchicalConfig::HierarchicalConfig(const std::string& json_repr, const std::vector& labels) { nlohmann::json data = nlohmann::json::parse(json_repr); num_multilabel_heads = data.at("cls_heads_info").at("num_multilabel_classes"); num_multiclass_heads = data.at("cls_heads_info").at("num_multiclass_heads"); num_single_label_classes = data.at("cls_heads_info").at("num_single_label_classes"); - data.at("cls_heads_info").at("label_to_idx").get_to(label_to_idx); + int idx = 0; + for (const auto& lbl_name : labels) { + label_to_idx[lbl_name] = idx++; + } data.at("cls_heads_info").at("all_groups").get_to(all_groups); data.at("label_tree_edges").get_to(label_tree_edges); diff --git a/model_api/python/model_api/models/classification.py b/model_api/python/model_api/models/classification.py index 18a7beb2..29d52044 100644 --- a/model_api/python/model_api/models/classification.py +++ b/model_api/python/model_api/models/classification.py @@ -50,6 +50,9 @@ def __init__(self, inference_adapter, configuration=dict(), preload=False): self.raise_error("Hierarchical classification config is empty.") self.raw_scores_name = self.out_layer_names[0] self.hierarchical_info = json.loads(self.hierarchical_config) + self.hierarchical_info["cls_heads_info"]["label_to_idx"] = { + label_name: i for i, label_name in enumerate(self.labels) + } if self.hierarchical_postproc == "probabilistic": self.labels_resolver = ProbabilisticLabelsResolver(