Skip to content

Commit d9cec5c

Browse files
authored
Support ONNX models via OV (#310)
* Add ONNX rt to cmake * Support onnx via OV * Update tests * Merge onnx config to OV model right after loading * Revert changes in the OV adapter * Cleanup in cmake * Link onnxrt publicly
1 parent 5428028 commit d9cec5c

File tree

9 files changed

+141
-5
lines changed

9 files changed

+141
-5
lines changed

src/cpp/CMakeLists.txt

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,39 @@ include(FetchContent)
1212
FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git
1313
GIT_TAG d41ca94fa85d5119852e2f7a3f94335cc7cb0486 # PR #4709, fixes cmake deprecation warnings
1414
)
15-
1615
FetchContent_MakeAvailable(json)
1716

17+
if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
18+
FetchContent_Declare(
19+
onnxrt
20+
URL https://github.com/microsoft/onnxruntime/releases/download/v1.22.0/onnxruntime-linux-x64-1.22.0.tgz
21+
URL_HASH MD5=2b3a728a057226a3d6577335604db9bc
22+
)
23+
FetchContent_MakeAvailable(onnxrt)
24+
25+
# fixes onnxrt cmake disabilities
26+
set(ONNXRUNTIME_ROOTDIR ${CMAKE_BINARY_DIR}/_deps/onnxrt-src)
27+
file(CREATE_LINK "${ONNXRUNTIME_ROOTDIR}/lib" "${ONNXRUNTIME_ROOTDIR}/lib64" SYMBOLIC)
28+
file(CREATE_LINK "${ONNXRUNTIME_ROOTDIR}/include" "${ONNXRUNTIME_ROOTDIR}/include/onnxruntime" SYMBOLIC)
29+
list(APPEND CMAKE_PREFIX_PATH "${ONNXRUNTIME_ROOTDIR}/lib/cmake")
30+
else()
31+
message(FATAL_ERROR "Unsupported platform: ${CMAKE_SYSTEM_NAME}. Only Linux is supported at the moment.")
32+
endif()
33+
1834

1935
find_package(OpenCV REQUIRED COMPONENTS core imgproc)
2036

2137
find_package(OpenVINO REQUIRED
2238
COMPONENTS Runtime Threading)
2339

40+
find_package(onnxruntime REQUIRED)
41+
2442
file(GLOB TASK_SOURCES src/tasks/**/*.cpp)
2543
file(GLOB TASKS_SOURCES src/tasks/*.cpp)
2644
file(GLOB UTILS_SOURCES src/utils/*.cpp)
2745
file(GLOB ADAPTERS_SOURCES src/adapters/*.cpp)
2846

2947
add_library(model_api STATIC ${TASK_SOURCES} ${TASKS_SOURCES} ${UTILS_SOURCES} ${ADAPTERS_SOURCES} ${TILERS_SOURCES})
3048

31-
target_link_libraries(model_api PUBLIC openvino::runtime opencv_core opencv_imgproc PRIVATE nlohmann_json::nlohmann_json)
49+
target_link_libraries(model_api PUBLIC openvino::runtime opencv_core opencv_imgproc onnxruntime::onnxruntime PRIVATE nlohmann_json::nlohmann_json)
3250
target_include_directories(model_api PUBLIC ${PROJECT_SOURCE_DIR}/include)

src/cpp/include/tasks/detection/ssd.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class SSD {
3838
std::map<std::string, ov::Tensor> preprocess(cv::Mat);
3939
DetectionResult postprocess(InferenceResult& infResult);
4040

41-
static void serialize(std::shared_ptr<ov::Model> ov_model);
41+
static void serialize(std::shared_ptr<ov::Model>& ov_model);
4242

4343
SSDOutputMode output_mode;
4444

src/cpp/include/utils/config.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ inline bool get_from_any_maps(const std::string& key,
4242
return low_priority;
4343
}
4444

45+
ov::AnyMap get_config_from_onnx(const std::string& model_path);
46+
47+
void add_ov_model_info(std::shared_ptr<ov::Model> model, const ov::AnyMap& config);
48+
4549
inline bool model_has_embedded_processing(std::shared_ptr<ov::Model> model) {
4650
if (model->has_rt_info("model_info")) {
4751
auto model_info = model->get_rt_info<ov::AnyMap>("model_info");

src/cpp/src/adapters/openvino_adapter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ void OpenVINOInferenceAdapter::loadModel(const std::string& modelPath,
4848
model = core.read_model(modelPath);
4949
if (model->has_rt_info({"model_info"})) {
5050
modelConfig = model->get_rt_info<ov::AnyMap>("model_info");
51+
} else if (modelPath.find("onnx") != std::string::npos || modelPath.find("ONNX") != std::string::npos) {
52+
modelConfig = utils::get_config_from_onnx(modelPath);
53+
utils::add_ov_model_info(model, modelConfig);
5154
}
5255
if (preCompile) {
5356
compileModel(device, adapterConfig);

src/cpp/src/tasks/detection.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
#include "tasks/detection.h"
77

8+
#include <algorithm>
9+
810
#include "adapters/openvino_adapter.h"
911
#include "tasks/detection/ssd.h"
1012
#include "utils/config.h"
@@ -17,7 +19,7 @@ DetectionModel DetectionModel::load(const std::string& model_path, const ov::Any
1719

1820
std::string model_type;
1921
model_type = utils::get_from_any_maps("model_type", adapter->getModelConfig(), {}, model_type);
20-
transform(model_type.begin(), model_type.end(), model_type.begin(), ::tolower);
22+
std::transform(model_type.begin(), model_type.end(), model_type.begin(), ::tolower);
2123

2224
if (model_type.empty() || model_type != "ssd") {
2325
throw std::runtime_error("Incorrect or unsupported model_type, expected: ssd");

src/cpp/src/tasks/detection/ssd.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ std::map<std::string, ov::Tensor> SSD::preprocess(cv::Mat image) {
6767
return input;
6868
}
6969

70-
void SSD::serialize(std::shared_ptr<ov::Model> ov_model) {
70+
void SSD::serialize(std::shared_ptr<ov::Model>& ov_model) {
7171
if (utils::model_has_embedded_processing(ov_model)) {
7272
std::cout << "model already was serialized" << std::endl;
7373
return;

src/cpp/src/utils/config.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,57 @@
44
*/
55

66
#include "utils/config.h"
7+
8+
#include <onnxruntime_cxx_api.h>
9+
10+
namespace {
11+
std::vector<std::string> split(const std::string& str, const std::string& delimiter) {
12+
std::vector<std::string> output;
13+
size_t start = 0;
14+
size_t end = 0;
15+
while ((end = str.find(delimiter, start)) != std::string::npos) {
16+
output.push_back(str.substr(start, end - start));
17+
start = end + delimiter.length();
18+
}
19+
output.push_back(str.substr(start));
20+
return output;
21+
}
22+
} // namespace
23+
24+
ov::AnyMap utils::get_config_from_onnx(const std::string& model_path) {
25+
ov::AnyMap config;
26+
if (model_path.find("onnx") != std::string::npos || model_path.find("ONNX") != std::string::npos) {
27+
Ort::Env env;
28+
Ort::SessionOptions ort_session_options;
29+
Ort::Session session = Ort::Session(env, model_path.c_str(), ort_session_options);
30+
Ort::AllocatorWithDefaultOptions ort_alloc;
31+
32+
Ort::ModelMetadata model_metadata = session.GetModelMetadata();
33+
std::vector<Ort::AllocatedStringPtr> keys = model_metadata.GetCustomMetadataMapKeysAllocated(ort_alloc);
34+
for (const auto& key : keys) {
35+
std::vector<std::string> attr_names;
36+
if (key != nullptr) {
37+
const std::array<const char*, 1> list_names = {key.get()};
38+
39+
Ort::AllocatedStringPtr values_search =
40+
model_metadata.LookupCustomMetadataMapAllocated(list_names[0], ort_alloc);
41+
if (values_search != nullptr) {
42+
const std::array<const char*, 1> value = {values_search.get()};
43+
attr_names = split(std::string(list_names[0]), " ");
44+
// only flat metadata is supported
45+
if (attr_names.size() == 2 && attr_names[0] == "model_info")
46+
config[attr_names[1]] = std::string(value[0]);
47+
}
48+
}
49+
}
50+
} else {
51+
throw std::runtime_error("Model is not ONNX, can't get config from it");
52+
}
53+
return config;
54+
}
55+
56+
void utils::add_ov_model_info(std::shared_ptr<ov::Model> model, const ov::AnyMap& config) {
57+
for (const auto& k : config) {
58+
model->set_rt_info(k.second, "model_info", k.first);
59+
}
60+
}

tests/cpp/test_accuracy.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ TEST_P(ModelParameterizedTest, SerializedAccuracyTest) {
146146

147147
const std::string& basename = data.name.substr(data.name.find_last_of("/\\") + 1);
148148
auto model_path = DATA_DIR + "/serialized/" + basename;
149+
150+
if (model_path.find(".onnx") != std::string::npos) {
151+
GTEST_SKIP() << "ONNX models are not serializable";
152+
}
153+
149154
if (data.type == "DetectionModel") {
150155
auto use_tiling = !data.input_res.empty();
151156
auto model = DetectionModel::load(model_path, {{"tiling", use_tiling}});
@@ -207,6 +212,10 @@ TEST_P(ModelParameterizedTest, AccuracyTestBatch) {
207212
const std::string& basename = data.name.substr(data.name.find_last_of("/\\") + 1);
208213
auto model_path = DATA_DIR + "/serialized/" + basename;
209214

215+
if (model_path.find(".onnx") != std::string::npos) {
216+
GTEST_SKIP() << "ONNX models are not serializable";
217+
}
218+
210219
if (data.type == "DetectionModel") {
211220
auto use_tiling = !data.input_res.empty();
212221
auto model = DetectionModel::load(model_path, {{"tiling", use_tiling}});

tests/python/accuracy/public_scope.json

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,18 @@
2323
}
2424
]
2525
},
26+
{
27+
"name": "otx_models/Lite-hrnet-s_mod2.onnx",
28+
"type": "SegmentationModel",
29+
"test_data": [
30+
{
31+
"image": "coco128/images/train2017/000000000074.jpg",
32+
"reference": [
33+
"0: 0.563, 1: 0.437, [426,640,2], [0], [0]; object: 0.520, 26, object: 0.530, 42, object: 0.501, 4, object: 0.507, 27, object: 0.503, 8, object: 0.502, 6, object: 0.505, 18, object: 0.504, 13, object: 0.524, 87, object: 0.521, 89, object: 0.757, 2706, "
34+
]
35+
}
36+
]
37+
},
2638
{
2739
"name": "otx_models/segmentation_model_with_xai_head.xml",
2840
"type": "SegmentationModel",
@@ -47,6 +59,18 @@
4759
}
4860
]
4961
},
62+
{
63+
"name": "otx_models/is_efficientnetb2b_maskrcnn_coco_reduced_onnx.onnx",
64+
"type": "MaskRCNNModel",
65+
"test_data": [
66+
{
67+
"image": "coco128/images/train2017/000000000074.jpg",
68+
"reference": [
69+
"458, 106, 495, 150, 1 (person): 0.818, 852, RotatedRect: 478.119 130.332 28.677 46.408 46.637; 0, 30, 178, 323, 2 (bicycle): 0.753, 26728, RotatedRect: 79.739 177.262 251.785 156.656 87.397; 0; [0]; person: 0.818, 139; bicycle: 0.753, 622; "
70+
]
71+
}
72+
]
73+
},
5074
{
5175
"name": "otx_models/is_resnet50_maskrcnn_coco_reduced.xml",
5276
"type": "MaskRCNNModel",
@@ -71,6 +95,18 @@
7195
}
7296
]
7397
},
98+
{
99+
"name": "otx_models/det_mobilenetv2_atss_bccd_onnx.onnx",
100+
"type": "DetectionModel",
101+
"test_data": [
102+
{
103+
"image": "BloodImage_00007.jpg",
104+
"reference": [
105+
"494, 159, 637, 308, 2 (WBC): 0.697; 28, 139, 135, 228, 1 (RBC): 0.628; 535, 375, 638, 479, 1 (RBC): 0.524; 513, 8, 633, 152, 1 (RBC): 0.430; 21, 291, 143, 399, 1 (RBC): 0.422; 196, 86, 410, 286, 1 (RBC): 0.422; [0]; [0]"
106+
]
107+
}
108+
]
109+
},
74110
{
75111
"name": "otx_models/detection_model_with_xai_head.xml",
76112
"type": "DetectionModel",
@@ -142,6 +178,16 @@
142178
}
143179
]
144180
},
181+
{
182+
"name": "otx_models/cls_mobilenetv3_large_cars.onnx",
183+
"type": "ClassificationModel",
184+
"test_data": [
185+
{
186+
"image": "coco128/images/train2017/000000000471.jpg",
187+
"reference": ["105 (194): 0.456, [0], [0], [196]"]
188+
}
189+
]
190+
},
145191
{
146192
"name": "otx_models/cls_efficient_b0_cars.xml",
147193
"type": "ClassificationModel",

0 commit comments

Comments
 (0)