Skip to content

Commit 327dbe4

Browse files
committed
Add constraint for IO Buffer feature with new compile model API
1 parent acd8444 commit 327dbe4

File tree

1 file changed

+35
-24
lines changed

1 file changed

+35
-24
lines changed

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,33 +46,44 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto,
4646
model_proto.SerializeToOstream(outfile);
4747
}
4848
#endif
49-
50-
if ((subgraph_context.precision == InferenceEngine::Precision::FP16)||
51-
(!global_context.is_wholly_supported_graph)){
52-
try {
53-
ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
54-
49+
try{
50+
if (global_context.is_wholly_supported_graph){
51+
#if defined(IO_BUFFER_ENABLED)
5552
if ((global_context.device_type.find("GPU") != std::string::npos) &&
56-
(global_context_.context != nullptr) &&
57-
(openvino_ep::BackendManager::GetGlobalContext().is_wholly_supported_graph)) {
58-
#if defined(IO_BUFFER_ENABLED)
59-
LOGS_DEFAULT(INFO) << log_tag << "IO Buffering Enabled";
60-
cl_context ctx = static_cast<cl_context>(global_context_.context);
61-
remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx);
62-
exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name);
63-
#endif
64-
} else {
53+
(global_context_.context != nullptr)){
54+
LOGS_DEFAULT(INFO) << log_tag << "IO Buffering Enabled";
55+
cl_context ctx = static_cast<cl_context>(global_context_.context);
56+
remote_context_ = new ov::intel_gpu::ocl::ClContext(global_context_.ie_core.Get(), ctx);
57+
ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
58+
exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, remote_context_, subgraph_context_.subgraph_name);
59+
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
60+
} else if (subgraph_context.precision == InferenceEngine::Precision::FP16){
61+
ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
6562
exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, config, device_config, subgraph_context_.subgraph_name);
63+
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
64+
} else {
65+
const std::string model = model_proto.SerializeAsString();
66+
exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, config, device_config, subgraph_context_.subgraph_name);
67+
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
68+
}
69+
#else
70+
if (subgraph_context.precision == InferenceEngine::Precision::FP16){
71+
ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
72+
exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, config, device_config, subgraph_context_.subgraph_name);
73+
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
74+
} else {
75+
const std::string model = model_proto.SerializeAsString();
76+
exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, config, device_config, subgraph_context_.subgraph_name);
77+
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
6678
}
67-
}catch (const char* msg) {
68-
throw(msg);
69-
}
70-
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
71-
} else {
72-
std::string model;
73-
model_proto.SerializeToString(model);
74-
exe_network_ = global_context_.ie_core.LoadNetwork(model, hw_target, config, device_config, subgraph_context_.subgraph_name);
75-
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
79+
#endif
80+
} else{
81+
ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_);
82+
exe_network_ = global_context_.ie_core.LoadNetwork(ie_cnn_network_, hw_target, config, device_config, subgraph_context_.subgraph_name);
83+
LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin";
84+
}
85+
}catch (const char* msg) {
86+
throw(msg);
7687
}
7788

7889

0 commit comments

Comments (0)