Skip to content

Commit b9192de

Browse files
committed
Clean up config retrievals using get_from_any_maps
1 parent d46033b commit b9192de

File tree

4 files changed

+11
-56
lines changed

4 files changed

+11
-56
lines changed

src/cpp/include/tasks/detection/ssd.h

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "adapters/inference_adapter.h"
1111
#include "tasks/results.h"
12+
#include "utils/config.h"
1213
#include "utils/preprocessing.h"
1314

1415
enum SSDOutputMode { single, multi };
@@ -29,20 +30,8 @@ class SSD {
2930

3031
SSD(std::shared_ptr<InferenceAdapter> adapter, cv::Size input_shape) : adapter(adapter), input_shape(input_shape) {
3132
auto config = adapter->getModelConfig();
32-
{
33-
auto iter = config.find("labels");
34-
if (iter != config.end()) {
35-
labels = iter->second.as<std::vector<std::string>>();
36-
} else {
37-
std::cout << "could not find labels from model config" << std::endl;
38-
}
39-
}
40-
{
41-
auto iter = config.find("confidence_threshold");
42-
if (iter != config.end()) {
43-
confidence_threshold = iter->second.as<float>();
44-
}
45-
}
33+
labels = utils::get_from_any_maps("labels", config, {}, labels);
34+
confidence_threshold = utils::get_from_any_maps("confidence_threshold", config, {}, confidence_threshold);
4635
}
4736
std::map<std::string, ov::Tensor> preprocess(cv::Mat);
4837
DetectionResult postprocess(InferenceResult& infResult);

src/cpp/include/tasks/instance_segmentation.h

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include "adapters/inference_adapter.h"
1212
#include "tasks/results.h"
13+
#include "utils/config.h"
1314
#include "utils/vision_pipeline.h"
1415

1516
class InstanceSegmentation {
@@ -30,20 +31,10 @@ class InstanceSegmentation {
3031
});
3132

3233
auto config = adapter->getModelConfig();
33-
auto iter = config.find("labels");
34-
if (iter != config.end()) {
35-
labels = iter->second.as<std::vector<std::string>>();
36-
} else {
37-
std::cout << "could not find labels from model config" << std::endl;
38-
}
39-
40-
{
41-
auto iter = config.find("confidence_threshold");
42-
if (iter != config.end()) {
43-
confidence_threshold = iter->second.as<float>();
44-
}
45-
}
34+
labels = utils::get_from_any_maps("labels", config, {}, labels);
35+
confidence_threshold = utils::get_from_any_maps("confidence_threshold", config, {}, confidence_threshold);
4636
}
37+
4738
static cv::Size serialize(std::shared_ptr<ov::Model>& ov_model);
4839
static InstanceSegmentation load(const std::string& model_path);
4940

src/cpp/include/tasks/semantic_segmentation.h

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include "adapters/inference_adapter.h"
1111
#include "tasks/results.h"
12+
#include "utils/config.h"
1213
#include "utils/preprocessing.h"
1314
#include "utils/vision_pipeline.h"
1415

@@ -27,26 +28,9 @@ class SemanticSegmentation {
2728
});
2829

2930
auto config = adapter->getModelConfig();
30-
auto iter = config.find("labels");
31-
if (iter != config.end()) {
32-
labels = iter->second.as<std::vector<std::string>>();
33-
} else {
34-
std::cout << "could not find labels from model config" << std::endl;
35-
}
36-
37-
{
38-
auto iter = config.find("soft_threshold");
39-
if (iter != config.end()) {
40-
soft_threshold = iter->second.as<float>();
41-
}
42-
}
43-
44-
{
45-
auto iter = config.find("blur_strength");
46-
if (iter != config.end()) {
47-
blur_strength = iter->second.as<int>();
48-
}
49-
}
31+
labels = utils::get_from_any_maps("labels", config, {}, labels);
32+
soft_threshold = utils::get_from_any_maps("soft_threshold", config, {}, soft_threshold);
33+
blur_strength = utils::get_from_any_maps("blur_strength", config, {}, blur_strength);
5034
}
5135

5236
static cv::Size serialize(std::shared_ptr<ov::Model>& ov_model);

src/cpp/src/tasks/detection/ssd.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -148,15 +148,6 @@ void SSD::prepareMultipleOutputs(std::shared_ptr<ov::Model> ov_model) {
148148
for (auto& name : output_names) {
149149
std::cout << "output name: " << name << std::endl;
150150
}
151-
152-
// ov::preprocess::PrePostProcessor ppp(ov_model);
153-
154-
// for (const auto& output_name : output_names) {
155-
// if (output_name != "labels") { //TODO: Discover why this isnt needed in original?
156-
// ppp.output(output_name).tensor().set_element_type(ov::element::f32);
157-
// }
158-
// }
159-
// ov_model = ppp.build();
160151
}
161152

162153
std::vector<std::string> SSD::filterOutXai(const std::vector<std::string>& names) {

0 commit comments

Comments
 (0)