Skip to content

Commit 8168c27

Browse files
committed
Adds preload tests back in.
Also fix the way the model loads Preload is basically calling the compileModel so dont pass it in initially unless we're sure that it doesnt need transforming
1 parent abf9dca commit 8168c27

File tree

13 files changed

+198
-23
lines changed

13 files changed

+198
-23
lines changed

src/cpp/include/tasks/anomaly.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class Anomaly {
4040
}
4141

4242
static void serialize(std::shared_ptr<ov::Model>& ov_model);
43-
static Anomaly create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
43+
static Anomaly create_model(const std::string& model_path, const ov::AnyMap& user_config = {}, bool preload = true, const std::string& device = "AUTO");
4444

4545
AnomalyResult infer(cv::Mat image);
4646
std::vector<AnomalyResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/classification.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class Classification {
5858
}
5959

6060
static void serialize(std::shared_ptr<ov::Model>& ov_model);
61-
static Classification create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
61+
static Classification create_model(const std::string& model_path, const ov::AnyMap& user_config = {}, bool preload = true, const std::string& device = "AUTO");
6262

6363
ClassificationResult infer(cv::Mat image);
6464
std::vector<ClassificationResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/detection.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class DetectionModel {
6767
const std::vector<cv::Rect>& tile_coords,
6868
const utils::TilingInfo& tiling_info);
6969

70-
static DetectionModel create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
70+
static DetectionModel create_model(const std::string& model_path, const ov::AnyMap& user_config = {}, bool preload = true, const std::string& device = "AUTO");
7171

7272
DetectionResult infer(cv::Mat image);
7373
std::vector<DetectionResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/instance_segmentation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class InstanceSegmentation {
3737
}
3838

3939
static void serialize(std::shared_ptr<ov::Model>& ov_model);
40-
static InstanceSegmentation create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
40+
static InstanceSegmentation create_model(const std::string& model_path, const ov::AnyMap& user_config = {}, bool preload = true, const std::string& device = "AUTO");
4141

4242
InstanceSegmentationResult infer(cv::Mat image);
4343
std::vector<InstanceSegmentationResult> inferBatch(std::vector<cv::Mat> image);

src/cpp/include/tasks/semantic_segmentation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class SemanticSegmentation {
3434
}
3535

3636
static void serialize(std::shared_ptr<ov::Model>& ov_model);
37-
static SemanticSegmentation create_model(const std::string& model_path, const ov::AnyMap& user_config = {});
37+
static SemanticSegmentation create_model(const std::string& model_path, const ov::AnyMap& user_config = {}, bool preload = true, const std::string& device = "AUTO");
3838

3939
std::map<std::string, ov::Tensor> preprocess(cv::Mat);
4040
SemanticSegmentationResult postprocess(InferenceResult& infResult);

src/cpp/src/tasks/anomaly.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ void Anomaly::serialize(std::shared_ptr<ov::Model>& ov_model) {
5151
ov_model->set_rt_info(input_shape[1], "model_info", "orig_height");
5252
}
5353

54-
Anomaly Anomaly::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
54+
Anomaly Anomaly::create_model(const std::string& model_path, const ov::AnyMap& user_config, bool preload, const std::string& device) {
5555
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
56-
adapter->loadModel(model_path, "", user_config, false);
56+
adapter->loadModel(model_path, device, user_config, false);
5757

5858
std::string model_type;
5959
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), user_config, model_type);
@@ -65,7 +65,9 @@ Anomaly Anomaly::create_model(const std::string& model_path, const ov::AnyMap& u
6565
}
6666

6767
adapter->applyModelTransform(Anomaly::serialize);
68-
adapter->compileModel("AUTO", user_config);
68+
if (preload) {
69+
adapter->compileModel(device, user_config);
70+
}
6971

7072
return Anomaly(adapter, user_config);
7173
}

src/cpp/src/tasks/classification.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,9 @@ void Classification::serialize(std::shared_ptr<ov::Model>& ov_model) {
180180
ov_model->set_rt_info(input_shape[1], "model_info", "orig_height");
181181
}
182182

183-
Classification Classification::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
183+
Classification Classification::create_model(const std::string& model_path, const ov::AnyMap& user_config, bool preload, const std::string& device) {
184184
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
185-
adapter->loadModel(model_path, "", user_config, false);
185+
adapter->loadModel(model_path, device, user_config, false);
186186

187187
std::string model_type;
188188
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), user_config, model_type);
@@ -191,7 +191,9 @@ Classification Classification::create_model(const std::string& model_path, const
191191
throw std::runtime_error("Incorrect or unsupported model_type, expected: Classification");
192192
}
193193
adapter->applyModelTransform(Classification::serialize);
194-
adapter->compileModel("AUTO", user_config);
194+
if (preload) {
195+
adapter->compileModel(device, user_config);
196+
}
195197

196198
return Classification(adapter, user_config);
197199
}

src/cpp/src/tasks/detection.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
#include "utils/nms.h"
1414
#include "utils/tensor.h"
1515

16-
DetectionModel DetectionModel::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
16+
DetectionModel DetectionModel::create_model(const std::string& model_path, const ov::AnyMap& user_config, bool preload, const std::string& device) {
1717
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
18-
adapter->loadModel(model_path, "", user_config, false);
18+
adapter->loadModel(model_path, device, user_config, false);
1919

2020
std::string model_type;
2121
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), user_config, model_type);
@@ -25,7 +25,9 @@ DetectionModel DetectionModel::create_model(const std::string& model_path, const
2525
throw std::runtime_error("Incorrect or unsupported model_type, expected: ssd");
2626
}
2727
adapter->applyModelTransform(SSD::serialize);
28-
adapter->compileModel("AUTO", user_config);
28+
if (preload) {
29+
adapter->compileModel(device, user_config);
30+
}
2931

3032
return DetectionModel(std::make_unique<SSD>(adapter), user_config);
3133
}

src/cpp/src/tasks/instance_segmentation.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,9 @@ void InstanceSegmentation::serialize(std::shared_ptr<ov::Model>& ov_model) {
190190
ov_model->set_rt_info(input_shape.height, "model_info", "orig_height");
191191
}
192192

193-
InstanceSegmentation InstanceSegmentation::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
193+
InstanceSegmentation InstanceSegmentation::create_model(const std::string& model_path, const ov::AnyMap& user_config, bool preload, const std::string& device) {
194194
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
195-
adapter->loadModel(model_path, "", user_config, false);
195+
adapter->loadModel(model_path, device, user_config, false);
196196

197197
std::string model_type;
198198
model_type = utils::get_from_any_maps("model_type", user_config, adapter->getModelConfig(), model_type);
@@ -201,7 +201,9 @@ InstanceSegmentation InstanceSegmentation::create_model(const std::string& model
201201
throw std::runtime_error("Incorrect or unsupported model_type, expected: MaskRCNN");
202202
}
203203
adapter->applyModelTransform(InstanceSegmentation::serialize);
204-
adapter->compileModel("AUTO", user_config);
204+
if (preload) {
205+
adapter->compileModel(device, user_config);
206+
}
205207

206208
return InstanceSegmentation(adapter, user_config);
207209
}

src/cpp/src/tasks/semantic_segmentation.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ cv::Mat get_activation_map(const cv::Mat& features) {
2020
return int_act_map;
2121
}
2222

23-
SemanticSegmentation SemanticSegmentation::create_model(const std::string& model_path, const ov::AnyMap& user_config) {
23+
SemanticSegmentation SemanticSegmentation::create_model(const std::string& model_path, const ov::AnyMap& user_config, bool preload, const std::string& device) {
2424
auto adapter = std::make_shared<OpenVINOInferenceAdapter>();
25-
adapter->loadModel(model_path, "", user_config, false);
25+
adapter->loadModel(model_path, device, user_config, false);
2626

2727
std::string model_type;
2828
model_type = utils::get_from_any_maps("model_type", user_config, adapter->getModelConfig(), model_type);
@@ -31,7 +31,9 @@ SemanticSegmentation SemanticSegmentation::create_model(const std::string& model
3131
throw std::runtime_error("Incorrect or unsupported model_type, expected: Segmentation");
3232
}
3333
adapter->applyModelTransform(SemanticSegmentation::serialize);
34-
adapter->compileModel("AUTO", user_config);
34+
if (preload) {
35+
adapter->compileModel(device, user_config);
36+
}
3537

3638
return SemanticSegmentation(adapter, user_config);
3739
}

0 commit comments

Comments
 (0)