Skip to content

Commit 9e449b1

Browse files
committed
Merge onnx config to OV model right after loading
1 parent bdd44ca commit 9e449b1

File tree

14 files changed

+48
-50
lines changed

14 files changed

+48
-50
lines changed

src/cpp/include/adapters/openvino_adapter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class OpenVINOInferenceAdapter : public InferenceAdapter {
3939
virtual std::vector<std::string> getOutputNames() const override;
4040
virtual const ov::AnyMap& getModelConfig() const override;
4141

42-
void applyModelTransform(std::function<ov::AnyMap(std::shared_ptr<ov::Model>&, const ov::AnyMap&)> t);
42+
void applyModelTransform(std::function<void(std::shared_ptr<ov::Model>&)> t);
4343

4444
protected:
4545
void initInputsOutputs();

src/cpp/include/tasks/anomaly.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Anomaly {
3838
input_shape.height = utils::get_from_any_maps("orig_height", config, {}, input_shape.height);
3939
}
4040

41-
static ov::AnyMap serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config);
41+
static void serialize(std::shared_ptr<ov::Model>& ov_model);
4242
static Anomaly load(const std::string& model_path);
4343

4444
AnomalyResult infer(cv::Mat image);

src/cpp/include/tasks/classification.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class Classification {
5454
}
5555
}
5656

57-
static ov::AnyMap serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config);
57+
static void serialize(std::shared_ptr<ov::Model>& ov_model);
5858
static Classification load(const std::string& model_path);
5959

6060
ClassificationResult infer(cv::Mat image);

src/cpp/include/tasks/detection/ssd.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class SSD {
3838
std::map<std::string, ov::Tensor> preprocess(cv::Mat);
3939
DetectionResult postprocess(InferenceResult& infResult);
4040

41-
static ov::AnyMap serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config);
41+
static void serialize(std::shared_ptr<ov::Model>& ov_model);
4242

4343
SSDOutputMode output_mode;
4444

src/cpp/include/tasks/instance_segmentation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class InstanceSegmentation {
3535
input_shape.height = utils::get_from_any_maps("orig_height", config, {}, input_shape.width);
3636
}
3737

38-
static ov::AnyMap serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config);
38+
static void serialize(std::shared_ptr<ov::Model>& ov_model);
3939
static InstanceSegmentation load(const std::string& model_path);
4040

4141
InstanceSegmentationResult infer(cv::Mat image);

src/cpp/include/tasks/semantic_segmentation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class SemanticSegmentation {
3333
blur_strength = utils::get_from_any_maps("blur_strength", config, {}, blur_strength);
3434
}
3535

36-
static ov::AnyMap serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config);
36+
static void serialize(std::shared_ptr<ov::Model>& ov_model);
3737
static SemanticSegmentation load(const std::string& model_path);
3838

3939
std::map<std::string, ov::Tensor> preprocess(cv::Mat);

src/cpp/include/utils/config.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ inline bool get_from_any_maps(const std::string& key,
4444

4545
ov::AnyMap get_config_from_onnx(const std::string& model_path);
4646

47+
void add_ov_model_info(std::shared_ptr<ov::Model> model, const ov::AnyMap& config);
48+
4749
inline bool model_has_embedded_processing(std::shared_ptr<ov::Model> model) {
4850
if (model->has_rt_info("model_info")) {
4951
auto model_info = model->get_rt_info<ov::AnyMap>("model_info");

src/cpp/src/adapters/openvino_adapter.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,18 +50,18 @@ void OpenVINOInferenceAdapter::loadModel(const std::string& modelPath,
5050
modelConfig = model->get_rt_info<ov::AnyMap>("model_info");
5151
} else if (modelPath.find("onnx") != std::string::npos || modelPath.find("ONNX") != std::string::npos) {
5252
modelConfig = utils::get_config_from_onnx(modelPath);
53+
utils::add_ov_model_info(model, modelConfig);
5354
}
5455
if (preCompile) {
5556
compileModel(device, adapterConfig);
5657
}
5758
}
5859

59-
void OpenVINOInferenceAdapter::applyModelTransform(
60-
std::function<ov::AnyMap(std::shared_ptr<ov::Model>&, const ov::AnyMap&)> t) {
60+
void OpenVINOInferenceAdapter::applyModelTransform(std::function<void(std::shared_ptr<ov::Model>&)> t) {
6161
if (!model) {
6262
throw std::runtime_error("Model is not loaded");
6363
}
64-
modelConfig = t(model, modelConfig);
64+
t(model);
6565
}
6666

6767
void OpenVINOInferenceAdapter::infer(const InferenceInput& input, InferenceOutput& output) {

src/cpp/src/tasks/anomaly.cpp

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
#include "utils/preprocessing.h"
55
#include "utils/tensor.h"
66

7-
ov::AnyMap Anomaly::serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config) {
7+
void Anomaly::serialize(std::shared_ptr<ov::Model>& ov_model) {
88
if (utils::model_has_embedded_processing(ov_model)) {
99
std::cout << "model already was serialized" << std::endl;
10-
return input_config;
10+
return;
1111
}
1212

1313
auto input = ov_model->inputs().front();
@@ -26,13 +26,13 @@ ov::AnyMap Anomaly::serialize(std::shared_ptr<ov::Model>& ov_model, const ov::An
2626

2727
std::vector<float> scale_values;
2828
std::vector<float> mean_values;
29-
30-
auto config(input_config);
31-
32-
reverse_input_channels =
33-
utils::get_from_any_maps("reverse_input_channels", config, ov::AnyMap{}, reverse_input_channels);
34-
scale_values = utils::get_from_any_maps("scale_values", config, ov::AnyMap{}, scale_values);
35-
mean_values = utils::get_from_any_maps("mean_values", config, ov::AnyMap{}, mean_values);
29+
if (ov_model->has_rt_info("model_info")) {
30+
auto config = ov_model->get_rt_info<ov::AnyMap>("model_info");
31+
reverse_input_channels =
32+
utils::get_from_any_maps("reverse_input_channels", config, ov::AnyMap{}, reverse_input_channels);
33+
scale_values = utils::get_from_any_maps("scale_values", config, ov::AnyMap{}, scale_values);
34+
mean_values = utils::get_from_any_maps("mean_values", config, ov::AnyMap{}, mean_values);
35+
}
3636

3737
auto input_shape = ov::Shape{shape[ov::layout::width_idx(layout)], shape[ov::layout::height_idx(layout)]};
3838

@@ -47,10 +47,8 @@ ov::AnyMap Anomaly::serialize(std::shared_ptr<ov::Model>& ov_model, const ov::An
4747
mean_values,
4848
scale_values);
4949

50-
config["orig_width"] = std::to_string(input_shape[0]);
51-
config["orig_height"] = std::to_string(input_shape[1]);
52-
53-
return config;
50+
ov_model->set_rt_info(input_shape[0], "model_info", "orig_width");
51+
ov_model->set_rt_info(input_shape[1], "model_info", "orig_height");
5452
}
5553

5654
Anomaly Anomaly::load(const std::string& model_path) {

src/cpp/src/tasks/classification.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,14 @@ std::vector<size_t> get_non_xai_output_indices(const std::vector<ov::Output<ov::
8282
}
8383
} // namespace
8484

85-
ov::AnyMap Classification::serialize(std::shared_ptr<ov::Model>& ov_model, const ov::AnyMap& input_config) {
85+
void Classification::serialize(std::shared_ptr<ov::Model>& ov_model) {
8686
if (utils::model_has_embedded_processing(ov_model)) {
8787
std::cout << "model already was serialized" << std::endl;
88-
return input_config;
88+
return;
8989
}
9090
// --------------------------- Configure input & output -------------------------------------------------
9191
// --------------------------- Prepare input ------------------------------------------------------
92-
auto config(input_config);
92+
auto config = ov_model->has_rt_info("model_info") ? ov_model->get_rt_info<ov::AnyMap>("model_info") : ov::AnyMap{};
9393
std::string layout = "";
9494
layout = utils::get_from_any_maps("layout", config, {}, layout);
9595
auto inputsLayouts = utils::parseLayoutString(layout);
@@ -176,10 +176,8 @@ ov::AnyMap Classification::serialize(std::shared_ptr<ov::Model>& ov_model, const
176176
addOrFindSoftmaxAndTopkOutputs(ov_model, topk, output_raw_scores);
177177
}
178178

179-
config["orig_width"] = std::to_string(input_shape[0]);
180-
config["orig_height"] = std::to_string(input_shape[1]);
181-
182-
return config;
179+
ov_model->set_rt_info(input_shape[0], "model_info", "orig_width");
180+
ov_model->set_rt_info(input_shape[1], "model_info", "orig_height");
183181
}
184182

185183
Classification Classification::load(const std::string& model_path) {

0 commit comments

Comments
 (0)