Skip to content

Commit 6ef53ed

Browse files
committed
Merge ModelBase into ImageModel
Example works, but probably some missing parts.
1 parent 9efe90c commit 6ef53ed

File tree

6 files changed

+224
-314
lines changed

6 files changed

+224
-314
lines changed

src/cpp/models/include/models/image_model.h

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
#include <memory>
1010
#include <string>
1111

12+
#include "adapters/inference_adapter.h"
1213
#include "models/input_data.h"
13-
#include "models/model_base.h"
14+
#include "models/results.h"
1415
#include "utils/image_utils.h"
16+
#include "utils/ocv_common.hpp"
17+
#include "utils/args_helper.hpp"
1518

1619
namespace ov {
1720
class InferRequest;
@@ -20,7 +23,7 @@ struct InputData;
2023
struct InternalModelData;
2124

2225
// ImageModel implements preprocess(), ImageModel's direct or indirect children are expected to implement prostprocess()
23-
class ImageModel : public ModelBase {
26+
class ImageModel {
2427
public:
2528
/// Constructor
2629
/// @param modelFile name of model to load
@@ -33,9 +36,24 @@ class ImageModel : public ModelBase {
3336

3437
ImageModel(std::shared_ptr<ov::Model>& model, const ov::AnyMap& configuration);
3538
ImageModel(std::shared_ptr<InferenceAdapter>& adapter, const ov::AnyMap& configuration = {});
36-
using ModelBase::ModelBase;
3739

38-
std::shared_ptr<InternalModelData> preprocess(const InputData& inputData, InferenceInput& input) override;
40+
virtual std::shared_ptr<InternalModelData> preprocess(const InputData& inputData, InferenceInput& input);
41+
virtual std::unique_ptr<ResultBase> postprocess(InferenceResult& infResult) = 0;
42+
43+
void load(ov::Core& core, const std::string& device, size_t num_infer_requests = 1);
44+
45+
std::shared_ptr<ov::Model> prepare();
46+
47+
virtual size_t getNumAsyncExecutors() const;
48+
virtual bool isReady();
49+
virtual void awaitAll();
50+
virtual void awaitAny();
51+
virtual void setCallback(
52+
std::function<void(std::unique_ptr<ResultBase>, const ov::AnyMap& callback_args)> callback);
53+
54+
std::shared_ptr<ov::Model> getModel();
55+
std::shared_ptr<InferenceAdapter> getInferenceAdapter();
56+
3957
static std::vector<std::string> loadLabels(const std::string& labelFilename);
4058
std::shared_ptr<ov::Model> embedProcessing(std::shared_ptr<ov::Model>& model,
4159
const std::string& inputName,
@@ -54,7 +72,7 @@ class ImageModel : public ModelBase {
5472

5573
protected:
5674
RESIZE_MODE selectResizeMode(const std::string& resize_type);
57-
void updateModelInfo() override;
75+
virtual void updateModelInfo();
5876
void init_from_config(const ov::AnyMap& top_priority, const ov::AnyMap& mid_priority);
5977

6078
std::string getLabelName(size_t labelID) {
@@ -73,4 +91,18 @@ class ImageModel : public ModelBase {
7391
bool reverse_input_channels = false;
7492
std::vector<float> scale_values;
7593
std::vector<float> mean_values;
94+
95+
protected:
96+
virtual void prepareInputsOutputs(std::shared_ptr<ov::Model>& model) = 0;
97+
98+
InputTransform inputTransform = InputTransform();
99+
100+
std::shared_ptr<ov::Model> model;
101+
std::vector<std::string> inputNames;
102+
std::vector<std::string> outputNames;
103+
std::string modelFile;
104+
std::shared_ptr<InferenceAdapter> inferenceAdapter;
105+
std::map<std::string, ov::Layout> inputsLayouts;
106+
ov::Layout getInputLayout(const ov::Output<ov::Node>& input);
107+
std::function<void(std::unique_ptr<ResultBase>, const ov::AnyMap&)> lastCallback;
76108
};

src/cpp/models/include/models/model_base.h

Lines changed: 0 additions & 73 deletions
This file was deleted.

src/cpp/models/src/image_model.cpp

Lines changed: 184 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,137 @@
1515
#include <utils/ocv_common.hpp>
1616
#include <vector>
1717

18+
#include "adapters/openvino_adapter.h"
1819
#include "models/input_data.h"
1920
#include "models/internal_model_data.h"
2021
#include "models/results.h"
22+
#include "utils/common.hpp"
23+
24+
namespace {
25+
class TmpCallbackSetter {
26+
public:
27+
ImageModel* model;
28+
std::function<void(std::unique_ptr<ResultBase>, const ov::AnyMap&)> last_callback;
29+
TmpCallbackSetter(ImageModel* model_,
30+
std::function<void(std::unique_ptr<ResultBase>, const ov::AnyMap&)> tmp_callback,
31+
std::function<void(std::unique_ptr<ResultBase>, const ov::AnyMap&)> last_callback_)
32+
: model(model_),
33+
last_callback(last_callback_) {
34+
model->setCallback(tmp_callback);
35+
}
36+
~TmpCallbackSetter() {
37+
if (last_callback) {
38+
model->setCallback(last_callback);
39+
} else {
40+
model->setCallback([](std::unique_ptr<ResultBase>, const ov::AnyMap&) {});
41+
}
42+
}
43+
};
44+
} // namespace
2145

2246
ImageModel::ImageModel(const std::string& modelFile,
2347
const std::string& resize_type,
2448
bool useAutoResize,
2549
const std::string& layout)
26-
: ModelBase(modelFile, layout),
27-
useAutoResize(useAutoResize),
28-
resizeMode(selectResizeMode(resize_type)) {}
50+
: useAutoResize(useAutoResize),
51+
resizeMode(selectResizeMode(resize_type)),
52+
modelFile(modelFile),
53+
inputsLayouts(parseLayoutString(layout)) {
54+
auto core = ov::Core();
55+
model = core.read_model(modelFile);
56+
}
57+
58+
59+
void ImageModel::load(ov::Core& core, const std::string& device, size_t num_infer_requests) {
60+
if (!inferenceAdapter) {
61+
inferenceAdapter = std::make_shared<OpenVINOInferenceAdapter>();
62+
}
63+
64+
// Update model_info erased by pre/postprocessing
65+
updateModelInfo();
66+
67+
inferenceAdapter->loadModel(model, core, device, {}, num_infer_requests);
68+
}
69+
70+
std::shared_ptr<ov::Model> ImageModel::prepare() {
71+
prepareInputsOutputs(model);
72+
logBasicModelInfo(model);
73+
ov::set_batch(model, 1);
74+
75+
return model;
76+
}
77+
78+
ov::Layout ImageModel::getInputLayout(const ov::Output<ov::Node>& input) {
79+
ov::Layout layout = ov::layout::get_layout(input);
80+
if (layout.empty()) {
81+
if (inputsLayouts.empty()) {
82+
layout = getLayoutFromShape(input.get_partial_shape());
83+
slog::warn << "Automatically detected layout '" << layout.to_string() << "' for input '"
84+
<< input.get_any_name() << "' will be used." << slog::endl;
85+
} else if (inputsLayouts.size() == 1) {
86+
layout = inputsLayouts.begin()->second;
87+
} else {
88+
layout = inputsLayouts[input.get_any_name()];
89+
}
90+
}
91+
92+
return layout;
93+
}
94+
95+
size_t ImageModel::getNumAsyncExecutors() const {
96+
return inferenceAdapter->getNumAsyncExecutors();
97+
}
98+
99+
bool ImageModel::isReady() {
100+
return inferenceAdapter->isReady();
101+
}
102+
void ImageModel::awaitAll() {
103+
inferenceAdapter->awaitAll();
104+
}
105+
void ImageModel::awaitAny() {
106+
inferenceAdapter->awaitAny();
107+
}
108+
109+
void ImageModel::setCallback(
110+
std::function<void(std::unique_ptr<ResultBase>, const ov::AnyMap& callback_args)> callback) {
111+
lastCallback = callback;
112+
inferenceAdapter->setCallback([this, callback](ov::InferRequest request, CallbackData args) {
113+
InferenceResult result;
114+
115+
InferenceOutput output;
116+
for (const auto& item : this->getInferenceAdapter()->getOutputNames()) {
117+
output.emplace(item, request.get_tensor(item));
118+
}
119+
120+
result.outputsData = output;
121+
auto model_data_iter = args->find("internalModelData");
122+
if (model_data_iter != args->end()) {
123+
result.internalModelData = std::move(model_data_iter->second.as<std::shared_ptr<InternalModelData>>());
124+
}
125+
auto retVal = this->postprocess(result);
126+
*retVal = static_cast<ResultBase&>(result);
127+
callback(std::move(retVal), args ? *args : ov::AnyMap());
128+
});
129+
}
130+
131+
std::shared_ptr<ov::Model> ImageModel::getModel() {
132+
if (!model) {
133+
throw std::runtime_error(std::string("ov::Model is not accessible for the current model adapter: ") +
134+
typeid(inferenceAdapter).name());
135+
}
136+
137+
updateModelInfo();
138+
return model;
139+
}
140+
141+
std::shared_ptr<InferenceAdapter> ImageModel::getInferenceAdapter() {
142+
if (!inferenceAdapter) {
143+
throw std::runtime_error(std::string("Model wasn't loaded"));
144+
}
145+
146+
return inferenceAdapter;
147+
}
148+
29149

30150
RESIZE_MODE ImageModel::selectResizeMode(const std::string& resize_type) {
31151
RESIZE_MODE resize = RESIZE_FILL;
@@ -68,36 +188,88 @@ void ImageModel::init_from_config(const ov::AnyMap& top_priority, const ov::AnyM
68188
}
69189

70190
ImageModel::ImageModel(std::shared_ptr<ov::Model>& model, const ov::AnyMap& configuration)
71-
: ModelBase(model, configuration) {
191+
: model(model) {
192+
auto layout_iter = configuration.find("layout");
193+
std::string layout = "";
194+
195+
if (layout_iter != configuration.end()) {
196+
layout = layout_iter->second.as<std::string>();
197+
} else {
198+
if (model->has_rt_info("model_info", "layout")) {
199+
layout = model->get_rt_info<std::string>("model_info", "layout");
200+
}
201+
}
202+
inputsLayouts = parseLayoutString(layout);
72203
init_from_config(configuration,
73204
model->has_rt_info("model_info") ? model->get_rt_info<ov::AnyMap>("model_info") : ov::AnyMap{});
74205
}
75206

76207
ImageModel::ImageModel(std::shared_ptr<InferenceAdapter>& adapter, const ov::AnyMap& configuration)
77-
: ModelBase(adapter, configuration) {
208+
: inferenceAdapter(adapter) {
209+
const ov::AnyMap& adapter_configuration = adapter->getModelConfig();
210+
211+
std::string layout = "";
212+
layout = get_from_any_maps("layout", configuration, adapter_configuration, layout);
213+
inputsLayouts = parseLayoutString(layout);
214+
215+
inputNames = adapter->getInputNames();
216+
outputNames = adapter->getOutputNames();
217+
78218
init_from_config(configuration, adapter->getModelConfig());
79219
}
80220

81221
std::unique_ptr<ResultBase> ImageModel::inferImage(const ImageInputData& inputData) {
82-
return ModelBase::infer(static_cast<const InputData&>(inputData));
83-
;
222+
InferenceInput inputs;
223+
InferenceResult result;
224+
auto internalModelData = this->preprocess(inputData, inputs);
225+
226+
result.outputsData = inferenceAdapter->infer(inputs);
227+
result.internalModelData = std::move(internalModelData);
228+
229+
auto retVal = this->postprocess(result);
230+
*retVal = static_cast<ResultBase&>(result);
231+
return retVal;
84232
}
85233

86234
std::vector<std::unique_ptr<ResultBase>> ImageModel::inferBatchImage(const std::vector<ImageInputData>& inputImgs) {
87-
std::vector<std::reference_wrapper<const InputData>> inputData;
235+
std::vector<std::reference_wrapper<const ImageInputData>> inputData;
88236
inputData.reserve(inputImgs.size());
89237
for (const auto& img : inputImgs) {
90-
inputData.push_back(static_cast<const InputData&>(img));
238+
inputData.push_back(img);
239+
}
240+
auto results = std::vector<std::unique_ptr<ResultBase>>(inputData.size());
241+
auto setter = TmpCallbackSetter(
242+
this,
243+
[&](std::unique_ptr<ResultBase> result, const ov::AnyMap& callback_args) {
244+
size_t id = callback_args.find("id")->second.as<size_t>();
245+
results[id] = std::move(result);
246+
},
247+
lastCallback);
248+
size_t req_id = 0;
249+
for (const auto& data : inputData) {
250+
inferAsync(data, {{"id", req_id++}});
91251
}
92-
return ModelBase::inferBatch(inputData);
252+
awaitAll();
253+
return results;
93254
}
94255

95256
void ImageModel::inferAsync(const ImageInputData& inputData, const ov::AnyMap& callback_args) {
96-
ModelBase::inferAsync(static_cast<const InputData&>(inputData), callback_args);
257+
InferenceInput inputs;
258+
auto internalModelData = this->preprocess(inputData, inputs);
259+
auto callback_args_ptr = std::make_shared<ov::AnyMap>(callback_args);
260+
(*callback_args_ptr)["internalModelData"] = std::move(internalModelData);
261+
inferenceAdapter->inferAsync(inputs, callback_args_ptr);
97262
}
98263

99264
void ImageModel::updateModelInfo() {
100-
ModelBase::updateModelInfo();
265+
if (!model) {
266+
throw std::runtime_error("The ov::Model object is not accessible");
267+
}
268+
269+
if (!inputsLayouts.empty()) {
270+
auto layouts = formatLayouts(inputsLayouts);
271+
model->set_rt_info(layouts, "model_info", "layout");
272+
}
101273

102274
model->set_rt_info(useAutoResize, "model_info", "auto_resize");
103275
model->set_rt_info(formatResizeMode(resizeMode), "model_info", "resize_type");

0 commit comments

Comments
 (0)