Skip to content

Commit b4b9724

Browse files
committed
Reduce already known hcls config fields
1 parent 5a06e0c commit b4b9724

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

model_api/cpp/models/src/classification_model.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,14 +555,17 @@ std::vector<std::unique_ptr<ClassificationResult>> ClassificationModel::inferBat
555555
return clsResults;
556556
}
557557

558-
HierarchicalConfig::HierarchicalConfig(const std::string& json_repr) {
558+
HierarchicalConfig::HierarchicalConfig(const std::string& json_repr, const std::vector<std::string>& labels) {
559559
nlohmann::json data = nlohmann::json::parse(json_repr);
560560

561561
num_multilabel_heads = data.at("cls_heads_info").at("num_multilabel_classes");
562562
num_multiclass_heads = data.at("cls_heads_info").at("num_multiclass_heads");
563563
num_single_label_classes = data.at("cls_heads_info").at("num_single_label_classes");
564564

565-
data.at("cls_heads_info").at("label_to_idx").get_to(label_to_idx);
565+
int idx = 0;
566+
for (const auto& lbl_name : labels) {
567+
label_to_idx[lbl_name] = idx++;
568+
}
566569
data.at("cls_heads_info").at("all_groups").get_to(all_groups);
567570
data.at("label_tree_edges").get_to(label_tree_edges);
568571

model_api/python/model_api/models/classification.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def __init__(self, inference_adapter, configuration=dict(), preload=False):
5050
self.raise_error("Hierarchical classification config is empty.")
5151
self.raw_scores_name = self.out_layer_names[0]
5252
self.hierarchical_info = json.loads(self.hierarchical_config)
53+
self.hierarchical_info["cls_heads_info"]["label_to_idx"] = \
54+
{label_name: i for i, label_name in enumerate(self.labels)}
5355

5456
if self.hierarchical_postproc == "probabilistic":
5557
self.labels_resolver = ProbabilisticLabelsResolver(

0 commit comments

Comments
 (0)