Skip to content

Commit 25c88f8

Browse files
authored
Adjust cpp inference adapters for OVMS (#212)
* Adjust adapters for OVMS. This is required to allow no-copy inference with Model API within MediaPipe graphs. * Fix test build * Fixes
1 parent d2c52d5 commit 25c88f8

File tree

3 files changed

+31
-0
lines changed

3 files changed

+31
-0
lines changed

model_api/cpp/adapters/include/adapters/inference_adapter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class InferenceAdapter {
2424
virtual ~InferenceAdapter() = default;
2525

2626
virtual InferenceOutput infer(const InferenceInput& input) = 0;
27+
virtual void infer(const InferenceInput& input, InferenceOutput& output) = 0;
2728
virtual void setCallback(std::function<void(ov::InferRequest, CallbackData)> callback) = 0;
2829
virtual void inferAsync(const InferenceInput& input, CallbackData callback_args) = 0;
2930
virtual bool isReady() = 0;
@@ -36,6 +37,9 @@ class InferenceAdapter {
3637
const ov::AnyMap& compilationConfig = {},
3738
size_t max_num_requests = 0) = 0;
3839
virtual ov::PartialShape getInputShape(const std::string& inputName) const = 0;
40+
virtual ov::PartialShape getOutputShape(const std::string& inputName) const = 0;
41+
virtual ov::element::Type_t getInputDatatype(const std::string& inputName) const = 0;
42+
virtual ov::element::Type_t getOutputDatatype(const std::string& outputName) const = 0;
3943
virtual std::vector<std::string> getInputNames() const = 0;
4044
virtual std::vector<std::string> getOutputNames() const = 0;
4145
virtual const ov::AnyMap& getModelConfig() const = 0;

model_api/cpp/adapters/include/adapters/openvino_adapter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class OpenVINOInferenceAdapter : public InferenceAdapter {
1919
OpenVINOInferenceAdapter() = default;
2020

2121
virtual InferenceOutput infer(const InferenceInput& input) override;
22+
virtual void infer(const InferenceInput& input, InferenceOutput& output) override;
2223
virtual void inferAsync(const InferenceInput& input, const CallbackData callback_args) override;
2324
virtual void setCallback(std::function<void(ov::InferRequest, const CallbackData)> callback);
2425
virtual bool isReady();
@@ -31,6 +32,9 @@ class OpenVINOInferenceAdapter : public InferenceAdapter {
3132
size_t max_num_requests = 1) override;
3233
virtual size_t getNumAsyncExecutors() const;
3334
virtual ov::PartialShape getInputShape(const std::string& inputName) const override;
35+
virtual ov::PartialShape getOutputShape(const std::string& outputName) const override;
36+
virtual ov::element::Type_t getInputDatatype(const std::string& inputName) const override;
37+
virtual ov::element::Type_t getOutputDatatype(const std::string& outputName) const override;
3438
virtual std::vector<std::string> getInputNames() const override;
3539
virtual std::vector<std::string> getOutputNames() const override;
3640
virtual const ov::AnyMap& getModelConfig() const override;

model_api/cpp/adapters/src/openvino_adapter.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,20 @@ void OpenVINOInferenceAdapter::loadModel(const std::shared_ptr<const ov::Model>&
4141
}
4242
}
4343

44+
void OpenVINOInferenceAdapter::infer(const InferenceInput& input, InferenceOutput& output) {
45+
auto request = asyncQueue->operator[](asyncQueue->get_idle_request_id());
46+
for (const auto& [name, tensor] : input) {
47+
request.set_tensor(name, tensor);
48+
}
49+
for (const auto& [name, tensor] : output) {
50+
request.set_tensor(name, tensor);
51+
}
52+
request.infer();
53+
for (const auto& name : outputNames) {
54+
output[name] = request.get_tensor(name);
55+
}
56+
}
57+
4458
InferenceOutput OpenVINOInferenceAdapter::infer(const InferenceInput& input) {
4559
auto request = asyncQueue->operator[](asyncQueue->get_idle_request_id());
4660
// Fill input blobs
@@ -87,6 +101,9 @@ size_t OpenVINOInferenceAdapter::getNumAsyncExecutors() const {
87101
ov::PartialShape OpenVINOInferenceAdapter::getInputShape(const std::string& inputName) const {
88102
return compiledModel.input(inputName).get_partial_shape();
89103
}
104+
// Returns the (possibly dynamic) shape of the named model output, as
// reported by the compiled model.
ov::PartialShape OpenVINOInferenceAdapter::getOutputShape(const std::string& outputName) const {
    const auto& outputPort = compiledModel.output(outputName);
    return outputPort.get_partial_shape();
}
90107

91108
void OpenVINOInferenceAdapter::initInputsOutputs() {
92109
for (const auto& input : compiledModel.inputs()) {
@@ -97,6 +114,12 @@ void OpenVINOInferenceAdapter::initInputsOutputs() {
97114
outputNames.push_back(output.get_any_name());
98115
}
99116
}
117+
// Returns the element type of the named model input.
//
// Previously this stub threw std::runtime_error("Not implemented"), even
// though the information is readily available from the compiled model —
// exactly the way getInputShape() already queries it. ov::element::Type
// converts implicitly to ov::element::Type_t, matching the declared
// return type in the InferenceAdapter interface.
ov::element::Type_t OpenVINOInferenceAdapter::getInputDatatype(const std::string& inputName) const {
    return compiledModel.input(inputName).get_element_type();
}
120+
// Returns the element type of the named model output.
//
// Previously this stub threw std::runtime_error("Not implemented"), even
// though the information is readily available from the compiled model —
// exactly the way getOutputShape() already queries it. ov::element::Type
// converts implicitly to ov::element::Type_t, matching the declared
// return type in the InferenceAdapter interface.
ov::element::Type_t OpenVINOInferenceAdapter::getOutputDatatype(const std::string& outputName) const {
    return compiledModel.output(outputName).get_element_type();
}
100123

101124
std::vector<std::string> OpenVINOInferenceAdapter::getInputNames() const {
102125
return inputNames;

0 commit comments

Comments
 (0)