Skip to content

Commit abf9dca

Browse files
committed
Instantiate all tasks with user config and replicate old api
Use old create_model so that api doesn't change.
1 parent d9cec5c commit abf9dca

File tree

14 files changed

+88
-79
lines changed

14 files changed

+88
-79
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Training Extensions embed all the metadata required for inference into model fil
5353
```
5454

5555
- Build library:
56+
5657
- Create `build` folder and navigate into it:
5758
<!-- prettier-ignore-start -->
5859

@@ -61,6 +62,7 @@ Training Extensions embed all the metadata required for inference into model fil
6162
```
6263

6364
<!-- prettier-ignore-end -->
65+
6466
- Run cmake:
6567

6668
```bash

examples/cpp/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ This example demonstrates how to use a C++ API of OpenVINO Model API for synchro
1616
```
1717

1818
- Build example:
19+
1920
- Create `build` folder and navigate into it:
2021
<!-- prettier-ignore-start -->
2122

examples/cpp/main.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ int main(int argc, char* argv[]) try {
3030
}
3131

3232
// Instantiate Object Detection model
33-
auto model = DetectionModel::load(argv[1], {}); // works with SSD models. Download it using Python Model API
33+
auto model =
34+
DetectionModel::create_model(argv[1], {}); // works with SSD models. Download it using Python Model API
3435

3536
// Run the inference
3637
auto result = model.infer(image);

src/cpp/include/tasks/anomaly.h

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Anomaly {
1818
std::shared_ptr<InferenceAdapter> adapter;
1919
VisionPipeline<AnomalyResult> pipeline;
2020

21-
Anomaly(std::shared_ptr<InferenceAdapter> adapter) : adapter(adapter) {
21+
Anomaly(std::shared_ptr<InferenceAdapter> adapter, const ov::AnyMap& user_config) : adapter(adapter) {
2222
pipeline = VisionPipeline<AnomalyResult>(
2323
adapter,
2424
[&](cv::Mat image) {
@@ -28,18 +28,19 @@ class Anomaly {
2828
return postprocess(result);
2929
});
3030

31-
auto config = adapter->getModelConfig();
32-
image_threshold = utils::get_from_any_maps("image_threshold", config, {}, image_threshold);
33-
pixel_threshold = utils::get_from_any_maps("pixel_threshold", config, {}, pixel_threshold);
34-
normalization_scale = utils::get_from_any_maps("normalization_scale", config, {}, normalization_scale);
35-
task = utils::get_from_any_maps("pixel_threshold", config, {}, task);
36-
labels = utils::get_from_any_maps("labels", config, {}, labels);
37-
input_shape.width = utils::get_from_any_maps("orig_width", config, {}, input_shape.width);
38-
input_shape.height = utils::get_from_any_maps("orig_height", config, {}, input_shape.height);
31+
auto model_config = adapter->getModelConfig();
32+
image_threshold = utils::get_from_any_maps("image_threshold", user_config, model_config, image_threshold);
33+
pixel_threshold = utils::get_from_any_maps("pixel_threshold", user_config, model_config, pixel_threshold);
34+
normalization_scale =
35+
utils::get_from_any_maps("normalization_scale", user_config, model_config, normalization_scale);
36+
task = utils::get_from_any_maps("pixel_threshold", user_config, model_config, task);
37+
labels = utils::get_from_any_maps("labels", user_config, model_config, labels);
38+
input_shape.width = utils::get_from_any_maps("orig_width", user_config, model_config, input_shape.width);
39+
input_shape.height = utils::get_from_any_maps("orig_height", user_config, model_config, input_shape.height);
3940
}
4041

4142
static void serialize(std::shared_ptr<ov::Model>& ov_model);
42-
static Anomaly load(const std::string& model_path);
43+
static Anomaly create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
4344

4445
AnomalyResult infer(cv::Mat image);
4546
std::vector<AnomalyResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/classification.h

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class Classification {
1919
std::shared_ptr<InferenceAdapter> adapter;
2020
VisionPipeline<ClassificationResult> pipeline;
2121

22-
Classification(std::shared_ptr<InferenceAdapter> adapter) : adapter(adapter) {
22+
Classification(std::shared_ptr<InferenceAdapter> adapter, const ov::AnyMap& user_config) : adapter(adapter) {
2323
pipeline = VisionPipeline<ClassificationResult>(
2424
adapter,
2525
[&](cv::Mat image) {
@@ -29,16 +29,19 @@ class Classification {
2929
return postprocess(result);
3030
});
3131

32-
auto config = adapter->getModelConfig();
33-
labels = utils::get_from_any_maps("labels", config, {}, labels);
34-
35-
topk = utils::get_from_any_maps("topk", config, {}, topk);
36-
multilabel = utils::get_from_any_maps("multilabel", config, {}, multilabel);
37-
output_raw_scores = utils::get_from_any_maps("output_raw_scores", config, {}, output_raw_scores);
38-
confidence_threshold = utils::get_from_any_maps("confidence_threshold", config, {}, confidence_threshold);
39-
hierarchical = utils::get_from_any_maps("hierarchical", config, {}, hierarchical);
40-
hierarchical_config = utils::get_from_any_maps("hierarchical_config", config, {}, hierarchical_config);
41-
hierarchical_postproc = utils::get_from_any_maps("hierarchical_postproc", config, {}, hierarchical_postproc);
32+
auto model_config = adapter->getModelConfig();
33+
labels = utils::get_from_any_maps("labels", user_config, model_config, labels);
34+
35+
topk = utils::get_from_any_maps("topk", user_config, model_config, topk);
36+
multilabel = utils::get_from_any_maps("multilabel", user_config, model_config, multilabel);
37+
output_raw_scores = utils::get_from_any_maps("output_raw_scores", user_config, model_config, output_raw_scores);
38+
confidence_threshold =
39+
utils::get_from_any_maps("confidence_threshold", user_config, model_config, confidence_threshold);
40+
hierarchical = utils::get_from_any_maps("hierarchical", user_config, model_config, hierarchical);
41+
hierarchical_config =
42+
utils::get_from_any_maps("hierarchical_config", user_config, model_config, hierarchical_config);
43+
hierarchical_postproc =
44+
utils::get_from_any_maps("hierarchical_postproc", user_config, model_config, hierarchical_postproc);
4245
if (hierarchical) {
4346
if (hierarchical_config.empty()) {
4447
throw std::runtime_error("Error: empty hierarchical classification config");
@@ -55,7 +58,7 @@ class Classification {
5558
}
5659

5760
static void serialize(std::shared_ptr<ov::Model>& ov_model);
58-
static Classification load(const std::string& model_path);
61+
static Classification create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
5962

6063
ClassificationResult infer(cv::Mat image);
6164
std::vector<ClassificationResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/detection.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ class DetectionModel {
1919
public:
2020
std::unique_ptr<Pipeline<DetectionResult>> pipeline;
2121

22-
DetectionModel(std::unique_ptr<SSD> algorithm, const ov::AnyMap& configuration) : algorithm(std::move(algorithm)) {
22+
DetectionModel(std::unique_ptr<SSD> algorithm, const ov::AnyMap& user_config) : algorithm(std::move(algorithm)) {
2323
auto config = this->algorithm->adapter->getModelConfig();
24-
if (configuration.count("tiling") && configuration.at("tiling").as<bool>()) {
24+
if (user_config.count("tiling") && user_config.at("tiling").as<bool>()) {
2525
if (!utils::config_contains_tiling_info(config)) {
2626
throw std::runtime_error("Model config does not contain tiling properties.");
2727
}
@@ -67,7 +67,7 @@ class DetectionModel {
6767
const std::vector<cv::Rect>& tile_coords,
6868
const utils::TilingInfo& tiling_info);
6969

70-
static DetectionModel load(const std::string& model_path, const ov::AnyMap& configuration = {});
70+
static DetectionModel create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
7171

7272
DetectionResult infer(cv::Mat image);
7373
std::vector<DetectionResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/instance_segmentation.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class InstanceSegmentation {
1818
std::shared_ptr<InferenceAdapter> adapter;
1919
VisionPipeline<InstanceSegmentationResult> pipeline;
2020

21-
InstanceSegmentation(std::shared_ptr<InferenceAdapter> adapter) : adapter(adapter) {
21+
InstanceSegmentation(std::shared_ptr<InferenceAdapter> adapter, const ov::AnyMap& user_config) : adapter(adapter) {
2222
pipeline = VisionPipeline<InstanceSegmentationResult>(
2323
adapter,
2424
[&](cv::Mat image) {
@@ -28,15 +28,16 @@ class InstanceSegmentation {
2828
return postprocess(result);
2929
});
3030

31-
auto config = adapter->getModelConfig();
32-
labels = utils::get_from_any_maps("labels", config, {}, labels);
33-
confidence_threshold = utils::get_from_any_maps("confidence_threshold", config, {}, confidence_threshold);
34-
input_shape.width = utils::get_from_any_maps("orig_width", config, {}, input_shape.width);
35-
input_shape.height = utils::get_from_any_maps("orig_height", config, {}, input_shape.width);
31+
auto model_config = adapter->getModelConfig();
32+
labels = utils::get_from_any_maps("labels", user_config, model_config, labels);
33+
confidence_threshold =
34+
utils::get_from_any_maps("confidence_threshold", user_config, model_config, confidence_threshold);
35+
input_shape.width = utils::get_from_any_maps("orig_width", user_config, model_config, input_shape.width);
36+
input_shape.height = utils::get_from_any_maps("orig_height", user_config, model_config, input_shape.width);
3637
}
3738

3839
static void serialize(std::shared_ptr<ov::Model>& ov_model);
39-
static InstanceSegmentation load(const std::string& model_path);
40+
static InstanceSegmentation create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
4041

4142
InstanceSegmentationResult infer(cv::Mat image);
4243
std::vector<InstanceSegmentationResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/semantic_segmentation.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class SemanticSegmentation {
1717
public:
1818
VisionPipeline<SemanticSegmentationResult> pipeline;
1919
std::shared_ptr<InferenceAdapter> adapter;
20-
SemanticSegmentation(std::shared_ptr<InferenceAdapter> adapter) : adapter(adapter) {
20+
SemanticSegmentation(std::shared_ptr<InferenceAdapter> adapter, const ov::AnyMap& user_config) : adapter(adapter) {
2121
pipeline = VisionPipeline<SemanticSegmentationResult>(
2222
adapter,
2323
[&](cv::Mat image) {
@@ -27,14 +27,14 @@ class SemanticSegmentation {
2727
return postprocess(result);
2828
});
2929

30-
auto config = adapter->getModelConfig();
31-
labels = utils::get_from_any_maps("labels", config, {}, labels);
32-
soft_threshold = utils::get_from_any_maps("soft_threshold", config, {}, soft_threshold);
33-
blur_strength = utils::get_from_any_maps("blur_strength", config, {}, blur_strength);
30+
auto model_config = adapter->getModelConfig();
31+
labels = utils::get_from_any_maps("labels", user_config, model_config, labels);
32+
soft_threshold = utils::get_from_any_maps("soft_threshold", user_config, model_config, soft_threshold);
33+
blur_strength = utils::get_from_any_maps("blur_strength", user_config, model_config, blur_strength);
3434
}
3535

3636
static void serialize(std::shared_ptr<ov::Model>& ov_model);
37-
static SemanticSegmentation load(const std::string& model_path);
37+
static SemanticSegmentation create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
3838

3939
std::map<std::string, ov::Tensor> preprocess(cv::Mat);
4040
SemanticSegmentationResult postprocess(InferenceResult& infResult);

src/cpp/src/tasks/anomaly.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,12 @@ void Anomaly::serialize(std::shared_ptr<ov::Model>& ov_model) {
5151
ov_model->set_rt_info(input_shape[1], "model_info", "orig_height");
5252
}
5353

54-
Anomaly Anomaly::load(const std::string& model_path) {
54+
Anomaly Anomaly::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
5555
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
56-
adapter->loadModel(model_path, "", {}, false);
56+
adapter->loadModel(model_path, "", user_config, false);
5757

5858
std::string model_type;
59-
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), {}, model_type);
59+
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), user_config, model_type);
6060

6161
if (!model_type.empty()) {
6262
std::cout << "has model type in info: " << model_type << std::endl;
@@ -65,9 +65,9 @@ Anomaly Anomaly::load(const std::string& model_path) {
6565
}
6666

6767
adapter->applyModelTransform(Anomaly::serialize);
68-
adapter->compileModel("AUTO", {});
68+
adapter->compileModel("AUTO", user_config);
6969

70-
return Anomaly(adapter);
70+
return Anomaly(adapter, user_config);
7171
}
7272

7373
AnomalyResult Anomaly::infer(cv::Mat image) {

src/cpp/src/tasks/classification.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,20 +180,20 @@ void Classification::serialize(std::shared_ptr<ov::Model>& ov_model) {
180180
ov_model->set_rt_info(input_shape[1], "model_info", "orig_height");
181181
}
182182

183-
Classification Classification::load(const std::string& model_path) {
183+
Classification Classification::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
184184
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
185-
adapter->loadModel(model_path, "", {}, false);
185+
adapter->loadModel(model_path, "", user_config, false);
186186

187187
std::string model_type;
188-
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), {}, model_type);
188+
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), user_config, model_type);
189189

190190
if (model_type.empty() || model_type != "Classification") {
191191
throw std::runtime_error("Incorrect or unsupported model_type, expected: Classification");
192192
}
193193
adapter->applyModelTransform(Classification::serialize);
194-
adapter->compileModel("AUTO", {});
194+
adapter->compileModel("AUTO", user_config);
195195

196-
return Classification(adapter);
196+
return Classification(adapter, user_config);
197197
}
198198

199199
ClassificationResult Classification::infer(cv::Mat image) {

0 commit comments

Comments
 (0)