Skip to content

Commit 28eed62

Browse files
committed
Adjust adapters for OVMS
This is required to allow no-copy inference with Model API within Mediapipe graphs.
1 parent 513583f commit 28eed62

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed

model_api/cpp/adapters/include/adapters/inference_adapter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class InferenceAdapter
3838
virtual ~InferenceAdapter() = default;
3939

4040
virtual InferenceOutput infer(const InferenceInput& input) = 0;
41+
virtual void infer(const InferenceInput& input, InferenceOutput& output) = 0;
4142
virtual void setCallback(std::function<void(ov::InferRequest, CallbackData)> callback) = 0;
4243
virtual void inferAsync(const InferenceInput& input, CallbackData callback_args) = 0;
4344
virtual bool isReady() = 0;
@@ -48,6 +49,9 @@ class InferenceAdapter
4849
const std::string& device = "", const ov::AnyMap& compilationConfig = {},
4950
size_t max_num_requests = 0) = 0;
5051
virtual ov::PartialShape getInputShape(const std::string& inputName) const = 0;
52+
virtual ov::PartialShape getOutputShape(const std::string& inputName) const = 0;
53+
virtual ov::element::Type_t getInputDatatype(const std::string& inputName) const = 0;
54+
virtual ov::element::Type_t getOutputDatatype(const std::string& outputName) const = 0;
5155
virtual std::vector<std::string> getInputNames() const = 0;
5256
virtual std::vector<std::string> getOutputNames() const = 0;
5357
virtual const ov::AnyMap& getModelConfig() const = 0;

model_api/cpp/adapters/include/adapters/openvino_adapter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class OpenVINOInferenceAdapter :public InferenceAdapter
3232
OpenVINOInferenceAdapter() = default;
3333

3434
virtual InferenceOutput infer(const InferenceInput& input) override;
35+
virtual void infer(const InferenceInput& input, InferenceOutput& output) override;
3536
virtual void inferAsync(const InferenceInput& input, const CallbackData callback_args) override;
3637
virtual void setCallback(std::function<void(ov::InferRequest, const CallbackData)> callback);
3738
virtual bool isReady();
@@ -42,6 +43,9 @@ class OpenVINOInferenceAdapter :public InferenceAdapter
4243
size_t max_num_requests = 1) override;
4344
virtual size_t getNumAsyncExecutors() const;
4445
virtual ov::PartialShape getInputShape(const std::string& inputName) const override;
46+
virtual ov::PartialShape getOutputShape(const std::string& outputName) const override;
47+
virtual ov::element::Type_t getInputDatatype(const std::string& inputName) const override;
48+
virtual ov::element::Type_t getOutputDatatype(const std::string& outputName) const override;
4549
virtual std::vector<std::string> getInputNames() const override;
4650
virtual std::vector<std::string> getOutputNames() const override;
4751
virtual const ov::AnyMap& getModelConfig() const override;

model_api/cpp/adapters/src/openvino_adapter.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ void OpenVINOInferenceAdapter::loadModel(const std::shared_ptr<const ov::Model>&
4949
}
5050
}
5151

52+
void OpenVINOInferenceAdapter::infer(const InferenceInput&, InferenceOutput&) {
53+
throw std::runtime_error("Not implemented");
54+
}
55+
5256
InferenceOutput OpenVINOInferenceAdapter::infer(const InferenceInput& input) {
5357
auto request = asyncQueue->operator[](asyncQueue->get_idle_request_id());
5458
// Fill input blobs
@@ -95,6 +99,9 @@ size_t OpenVINOInferenceAdapter::getNumAsyncExecutors() const {
9599
ov::PartialShape OpenVINOInferenceAdapter::getInputShape(const std::string& inputName) const {
96100
return compiledModel.input(inputName).get_partial_shape();
97101
}
102+
ov::PartialShape OpenVINOInferenceAdapter::getOutputShape(const std::string& outputName) const {
103+
return compiledModel.output(outputName).get_shape();
104+
}
98105

99106
void OpenVINOInferenceAdapter::initInputsOutputs() {
100107
for (const auto& input : compiledModel.inputs()) {
@@ -105,6 +112,12 @@ void OpenVINOInferenceAdapter::initInputsOutputs() {
105112
outputNames.push_back(output.get_any_name());
106113
}
107114
}
115+
ov::element::Type_t OpenVINOInferenceAdapter::getInputDatatype(const std::string& inputName) const {
116+
throw std::runtime_error("Not implemented");
117+
}
118+
ov::element::Type_t OpenVINOInferenceAdapter::getOutputDatatype(const std::string& outputName) const {
119+
throw std::runtime_error("Not implemented");
120+
}
108121

109122
std::vector<std::string> OpenVINOInferenceAdapter::getInputNames() const {
110123
return inputNames;

0 commit comments

Comments
 (0)