
Commit c2791bc

update: Refactor, fix & enable EPContext Import for XML & BIN

1 parent 665f4c2 commit c2791bc

File tree

7 files changed: +154 -70 lines

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 5 additions & 4 deletions
@@ -85,13 +85,14 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
   auto auto_unified_compile = ((hw_target.find("AUTO") == std::string::npos) ||
                                (session_context_.OpenVINO_Version.at(0) >= 2024 &&
                                 session_context_.OpenVINO_Version.at(1) > 2));
-  if (subgraph_context_.is_ep_ctx_graph) {
+  if (subgraph_context_.is_ep_ctx_graph && enable_causallm) {
     // If the blob is held in an EPContext node, then skip FE+Compile
     // and directly move on to creating a backend with the executable blob
     exe_network_ = OVCore::Get()->ImportModel(*model_stream,
                                               hw_target,
                                               device_config,
-                                              subgraph_context_.subgraph_name);
+                                              enable_causallm,
+                                              session_context_.onnx_model_path_name.string());
     model_stream.reset();  // Delete stream after it is no longer needed
   } else if (!session_context_.has_external_weights &&
              !subgraph_context_.has_dynamic_input_shape &&

@@ -285,7 +286,7 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) {
   //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO")
   std::unordered_set<std::string> supported_mode = {"AUTO", "HETERO", "MULTI"};
   auto device_mode = find_device_type_mode(session_context_.device_type);
-  ORT_ENFORCE(supported_mode.find(device_mode)!=supported_mode.end(), " Invalid device mode is passed : " , session_context_.device_type);
+  ORT_ENFORCE(supported_mode.find(device_mode) != supported_mode.end(), " Invalid device mode is passed : ", session_context_.device_type);
   // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"])
   auto individual_devices = parse_individual_devices(session_context_.device_type);
   if (!device_mode.empty()) individual_devices.emplace_back(device_mode);

@@ -379,7 +380,7 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque
   // for the stateful PoC, the ONNX model will have KV cache (past/present) tensors, but
   // we internally converted it to stateful, which removed these. So, we just continue here
   // to avoid runtime exception.
-  if (input_name.empty()) continue;
+  if (input_name.empty() || input_name == "beam_idx") continue;

   ORT_ENFORCE(!input_name.empty(), log_tag,
               "Input names mismatch between OpenVINO and ONNX. ", onnx_input_name,

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ OpenVINOExecutionProvider::OpenVINOExecutionProvider(const ProviderInfo& info, s
       // If idx is 0, maybe index is not set (e.g. GPU)
       // Then the device is found if we have at least one device of the type
       if (device_idx == 0 && available_devices.size() >= 1) {
-          device_found = true;
+        device_found = true;
       } else {
         // Find full device (e.g GPU.1) in the list
         if (std::find(std::begin(available_devices), std::end(available_devices), device) != std::end(available_devices))

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ void ParseConfigOptions(ProviderInfo& pi) {
     map["NPU_COMPILATION_MODE_PARAMS"] = "enable-wd-blockarg-input=true compute-layers-with-higher-precision=Sqrt,Power,ReduceSum";
     pi.load_config["NPU"] = std::move(map);
   }
-
 }

 void* ParseUint64(const ProviderOptions& provider_options, std::string option_name) {

onnxruntime/core/providers/openvino/ov_interface.cc

Lines changed: 115 additions & 46 deletions
@@ -74,61 +74,72 @@ std::shared_ptr<OVNetwork> OVCore::ReadModel(std::string&& model, const std::str
   }
 }

-OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_network,
-                                  std::string& hw_target,
-                                  ov::AnyMap& device_config,
-                                  bool enable_causallm,
-                                  const std::string& name) {
-  ov::CompiledModel obj;
-  try {
-    if (enable_causallm) {
-      ov::AnyMap config;
+OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
+                                          std::string& hw_target,
+                                          const ov::AnyMap& device_config) {
+  ov::CompiledModel compiled_model;
+  ov::AnyMap config = device_config;

-      // Create a clone of ie_cnn_network, since it's a const ov::Model, and we need to patch it..
-      // Note! With this default path, the model runs but produces garbage (for NPUW). For CPU it's fine.
-      auto mutable_model = ie_cnn_network->clone();
+  if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
+    std::cout << "Stateless OV Model Statistic:" << std::endl;
+    LogBasicModelInfo(model);
+  }

-      if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
-        std::cout << "Stateless OV Model Statistic" << std::endl;
-        LogBasicModelInfo(mutable_model);
-      }
-      LogBasicModelInfo(mutable_model);
+  LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
+  bool status = IsStateful(model);
+  std::cout << "IsStateful Status:\t" << status << std::endl;
+  if (!status) {
+    PatchStatefulDecoder(model);
+  }

-      LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl;
-      PatchStatefulDecoder(mutable_model);
+  if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
+    std::cout << "Stateful OV Model Statistic:" << std::endl;
+    LogBasicModelInfo(model);
+  }

-      if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
-        std::cout << "Stateful OV Model Statistic" << std::endl;
-        LogBasicModelInfo(mutable_model);
-      }
+  auto kv_pos = GetKVAxesPos(model);
+  if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
+    std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
+    std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
+  }

-      // This patches the model so that it only produces the logits required for sampling.
-      // Actually either way that happens within NPUW::LLMCompiledModel creation, but this is
-      // here mostly to align this behavior for other devices (CPU, GPU).
-      ApplySliceBeforeMatmulTransformation(mutable_model);
+  if (hw_target.find("NPU") != std::string::npos) {
+    KVDesc kv_desc;
+    kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(1024u);
+    kv_desc.min_response_len = PopIntAndCast(config, "MIN_RESPONSE_LEN").value_or(128u);

-      auto kv_pos = GetKVAxesPos(mutable_model);
-      if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
-        std::cout << "kv_pos.batch = " << kv_pos.batch << std::endl;
-        std::cout << "kv_pos.seq_len = " << kv_pos.seq_len << std::endl;
-      }
+    if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
+      std::cout << "kv_desc.max_prompt_len:\t" << kv_desc.max_prompt_len << std::endl;
+      std::cout << "kv_desc.min_response_len:\t" << kv_desc.min_response_len << std::endl;
+    }

-      if (hw_target.find("NPU") != std::string::npos) {
-        KVDesc kv_desc;
-        kv_desc.max_prompt_len = PopIntAndCast(device_config, "MAX_PROMPT_LEN").value_or(1024u);
-        kv_desc.min_response_len = PopIntAndCast(device_config, "MIN_RESPONSE_LEN").value_or(128u);
+    UpdateNPUConfig(config, kv_pos, kv_desc);
+  } else {
+    // This patches the model so that it only produces the logits required for sampling.
+    // Actually either way that happens within NPUW::LLMCompiledModel creation, but this is
+    // here mostly to align this behavior for other devices (CPU, GPU).
+    ApplySliceBeforeMatmulTransformation(model);
+  }

-      if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
-        std::cout << "kv_desc.max_prompt_len = " << kv_desc.max_prompt_len << std::endl;
-        std::cout << "kv_desc.min_response_len = " << kv_desc.min_response_len << std::endl;
-      }
+  std::cout << "Compiling Stateful OV Model ..." << std::endl;
+  compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
+  std::cout << "Stateful OV Model Compilation Complete" << std::endl;

-      UpdateNPUConfig(config, kv_pos, kv_desc);
-    }
+  OVExeNetwork exe(compiled_model);
+  return exe;
+}

-    std::cout << "Compiling Stateful OV Model..." << std::endl;
-    obj = core.compile_model(mutable_model, hw_target, config);
-    std::cout << "Stateful OV Model Compilation Complete" << std::endl;
+OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_network,
+                                  std::string& hw_target,
+                                  ov::AnyMap& device_config,
+                                  bool enable_causallm,
+                                  const std::string& name) {
+  ov::CompiledModel obj;
+  try {
+    if (enable_causallm) {
+      auto mutable_model = ie_cnn_network->clone();
+      auto compiled_model = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
+      obj = compiled_model.Get();
     } else {
       obj = core.compile_model(ie_cnn_network, hw_target, device_config);
     }
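Note: StatefulCompileModel above pops MAX_PROMPT_LEN / MIN_RESPONSE_LEN out of its config copy before compiling. A rough standalone sketch of that pop-and-cast pattern, assuming only the public ov::AnyMap / ov::Any API (PopUint is a hypothetical stand-in for the existing PopIntAndCast helper):

#include <cstdint>
#include <optional>
#include <string>
#include "openvino/core/any.hpp"

// Remove `key` from the map if present and return it as uint32_t.
// ov::Any::as<T>() performs the conversion when the stored type allows it.
std::optional<uint32_t> PopUint(ov::AnyMap& config, const std::string& key) {
  auto it = config.find(key);
  if (it == config.end()) return std::nullopt;
  uint32_t value = it->second.as<uint32_t>();
  config.erase(it);
  return value;
}

// Usage mirroring the hunk above:
//   kv_desc.max_prompt_len = PopUint(config, "MAX_PROMPT_LEN").value_or(1024u);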
@@ -166,10 +177,68 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
 OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  std::string hw_target,
                                  const ov::AnyMap& device_config,
+                                 bool enable_causallm,
                                  std::string name) {
   try {
     ov::CompiledModel obj;
-    obj = core.import_model(model_stream, hw_target, device_config);
+
+    // Check if it's XML
+    std::streampos originalPos = model_stream.tellg();
+    // Allocate space for "<?xml"
+    std::string header(5, '\0');
+    model_stream.read(&header[0], 5);
+
+    // Clear any read errors
+    model_stream.clear();
+    // Restore the stream position (important for reusing the stream)
+    model_stream.seekg(originalPos);
+
+    if (header != "<?xml") {
+      obj = core.import_model(model_stream, hw_target, device_config);
+    } else {
+      // Get path to bin file
+      std::string bin_file;
+      if (name.size() >= 5 && name.substr(name.size() - 5) == ".onnx") {
+        bin_file = name;
+        bin_file.replace(name.size() - 5, 5, ".bin");
+      } else {
+        throw std::runtime_error("Invalid model name. Make sure *.onnx, *.xml, and *.bin carry the same name.");
+      }
+
+      // Read the model XML into a string
+      std::stringstream xml_stream;
+      xml_stream << model_stream.rdbuf();
+      std::string xml_content = xml_stream.str();
+
+      // Read model.bin into a vector
+      std::ifstream bin_stream;
+      bin_stream.open(bin_file, std::ios::binary);
+      if (!bin_stream.is_open()) {
+        throw std::runtime_error("Failed to open " + bin_file);
+      }
+
+      bin_stream.seekg(0, std::ios::end);
+      std::streamsize size = bin_stream.tellg();
+      bin_stream.seekg(0, std::ios::beg);
+      std::vector<uint8_t> bin_data(size);
+      if (!bin_stream.read(reinterpret_cast<char*>(bin_data.data()), size)) {
+        throw std::runtime_error("Failed to read binary data from " + bin_file);
+      }
+
+      // Create an ov::Tensor for weights
+      ov::Tensor weights_tensor(ov::element::u8, {bin_data.size()}, bin_data.data());
+
+      // Load the model explicitly with XML content and weights
+      std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
+
+      if (enable_causallm) {
+        auto compiled_model = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
+        obj = compiled_model.Get();
+      } else {
+        obj = core.compile_model(model, hw_target, device_config);
+      }
+    }
+
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
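Note: the heart of the XML/BIN fix is a peek-and-rewind on the payload stream: read the first five bytes, compare against "<?xml", and restore the position so core.import_model() can still consume the stream from the start. A self-contained sketch of just that technique, using only standard iostreams (LooksLikeXml is a hypothetical name, not part of the commit):

#include <istream>
#include <string>

// Returns true when the stream starts with the IR XML prologue "<?xml".
// The stream position and error state are restored before returning, so the
// caller can still hand the stream to an import/compile call afterwards.
bool LooksLikeXml(std::istream& is) {
  const std::streampos start = is.tellg();
  std::string header(5, '\0');
  is.read(&header[0], 5);
  is.clear();       // a payload shorter than 5 bytes would set failbit/eofbit
  is.seekg(start);  // rewind for reuse
  return header == "<?xml";
}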

onnxruntime/core/providers/openvino/ov_interface.h

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,9 @@ struct OVCore : WeakSingleton<OVCore> {
   // OV Interface For Reading Model
   std::shared_ptr<OVNetwork> ReadModel(std::string&& model_stream, const std::string& model_path);

+  OVExeNetwork StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
+                                    std::string& hw_target,
+                                    const ov::AnyMap& device_config);
   // OV Interface for Compiling OV Model Type
   OVExeNetwork CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_network,
                             std::string& hw_target,

@@ -84,6 +87,7 @@ struct OVCore : WeakSingleton<OVCore> {
   OVExeNetwork ImportModel(std::istream& model_stream,
                            std::string hw_target,
                            const ov::AnyMap& device_config,
+                           bool enable_causallm,
                            std::string name);
 #ifdef IO_BUFFER_ENABLED
   OVExeNetwork CompileModel(std::shared_ptr<const OVNetwork>& model,

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc

Lines changed: 27 additions & 18 deletions
@@ -6,8 +6,7 @@
 namespace onnxruntime {
 namespace openvino_ep {

-
-  void LogBasicModelInfo(const std::shared_ptr<const ov::Model>& model) {
+void LogBasicModelInfo(const std::shared_ptr<const ov::Model>& model) {
   std::cout << "Model Name: " << model->get_friendly_name() << std::endl;

   // Dump information about model inputs/outputs

@@ -37,7 +36,7 @@ namespace openvino_ep {
   return;
 }

-  bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::string& name_to_match) {
+bool ModelHasInputOutputNames(std::shared_ptr<ov::Model> model, const std::string& name_to_match) {
   for (const ov::Output<ov::Node>& input : model->inputs()) {
     auto& names = input.get_names();

@@ -60,10 +59,10 @@ namespace openvino_ep {
   return false;
 }

-  void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
-                        std::vector<std::string>& not_kv_inputs,
-                        const std::vector<std::string>& key_value_input_names,
-                        int gather_dim) {
+void FuseCacheReorder(std::shared_ptr<ov::Model> ov_model,
+                      std::vector<std::string>& not_kv_inputs,
+                      const std::vector<std::string>& key_value_input_names,
+                      int gather_dim) {
   if (ModelHasInputOutputNames(ov_model, "beam_idx")) {
     throw std::runtime_error("Model already has fused cache");
   }

@@ -101,9 +100,9 @@ namespace openvino_ep {
   ov_model->validate_nodes_and_infer_types();
 }

-  void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
-                    const std::vector<std::string>& key_value_input_names,
-                    const std::vector<std::string>& key_value_output_names) {
+void MakeStateful(std::shared_ptr<ov::Model>& ov_model,
+                  const std::vector<std::string>& key_value_input_names,
+                  const std::vector<std::string>& key_value_output_names) {
   std::map<std::string, std::string> input_output_map;

   // Create mapping for KV-cache inputs and outputs

@@ -119,7 +118,7 @@ namespace openvino_ep {

 // Converted to C++ from here:
 // https://github.com/huggingface/optimum-intel/blob/main/optimum/exporters/openvino/stateful.py#L281
-  void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
+void PatchStatefulDecoder(std::shared_ptr<ov::Model> model) {
   std::vector<std::string> key_value_input_names;
   std::vector<std::string> not_kv_inputs;
   for (const ov::Output<ov::Node>& input : model->inputs()) {

@@ -166,7 +165,7 @@ namespace openvino_ep {
 }

 // Some other utility functions copied from OpenVINO GenAI
-  bool HasOpWithType(const std::shared_ptr<const ov::Model>& function, const std::string& type_name) {
+bool HasOpWithType(const std::shared_ptr<const ov::Model>& function, const std::string& type_name) {
   for (const auto& op : function->get_ops()) {
     if (op->get_type_name() == type_name) {
       return true;

@@ -175,7 +174,7 @@ namespace openvino_ep {
   return false;
 }

-  std::tuple<std::shared_ptr<ov::Node>, int64_t> FindLLMMatmul(const std::shared_ptr<ov::Model>& model) {
+std::tuple<std::shared_ptr<ov::Node>, int64_t> FindLLMMatmul(const std::shared_ptr<ov::Model>& model) {
   auto last_node = model->output(0).get_node()->input_value(0).get_node_shared_ptr();
   std::shared_ptr<ov::Node> matmul = ov::as_type_ptr<ov::op::v0::MatMul>(last_node);

@@ -206,7 +205,7 @@ namespace openvino_ep {
   return std::make_tuple(matmul, slice_gather_dim);
 }

-  void ApplySliceBeforeMatmulTransformation(std::shared_ptr<ov::Model> model) {
+void ApplySliceBeforeMatmulTransformation(std::shared_ptr<ov::Model> model) {
   std::shared_ptr<ov::Node> matmul = nullptr;
   int64_t slice_gather_dim = -1;
   std::tie(matmul, slice_gather_dim) = FindLLMMatmul(model);

@@ -221,13 +220,13 @@ namespace openvino_ep {
   }
 }

-  void UpdateConfig(ov::AnyMap& config, const std::pair<std::string, ov::Any>& pair) {
+void UpdateConfig(ov::AnyMap& config, const std::pair<std::string, ov::Any>& pair) {
   if (config.count(pair.first) == 0) {
     config.insert(pair);
   }
 }

-  std::optional<ov::Any> PopOption(ov::AnyMap& config, const std::string& option_name) {
+std::optional<ov::Any> PopOption(ov::AnyMap& config, const std::string& option_name) {
   if (auto it = config.find(option_name); it != config.end()) {
     std::optional<ov::Any> found = std::make_optional(it->second);
     config.erase(it);

@@ -236,14 +235,13 @@ namespace openvino_ep {
   return std::nullopt;
 }

-  void RenameKey(ov::AnyMap& config, const std::string& old_key, const std::string& new_key) {
+void RenameKey(ov::AnyMap& config, const std::string& old_key, const std::string& new_key) {
   if (config.count(old_key) != 0) {
     auto opt_value = PopOption(config, old_key);
     config[new_key] = opt_value.value();
   }
 }

-
 KVAxesPosition GetKVAxesPos(std::shared_ptr<const ov::Model> model) {
   // sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
   // therefore usually seq_length_axis = 2 and batch = 0

@@ -324,6 +322,17 @@ std::optional<uint32_t> PopIntAndCast(ov::AnyMap& config, const std::string& key
   return std::nullopt;
 }

+bool IsStateful(const std::shared_ptr<ov::Model>& model) {
+  for (auto&& ptr : model->get_ordered_ops()) {
+    if (ov::is_type<ov::op::v3::ReadValue>(ptr) ||
+        ov::is_type<ov::op::v6::ReadValue>(ptr) ||
+        ov::is_type<ov::op::v3::Assign>(ptr) ||
+        ov::is_type<ov::op::v6::Assign>(ptr)) {
+      return true;
+    }
+  }
+  return false;
+}

 }  // namespace openvino_ep
 }  // namespace onnxruntime
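Note: IsStateful is a plain op-type scan: a model that already contains ReadValue/Assign state ops has been made stateful, so StatefulCompileModel skips PatchStatefulDecoder for it. A compilable standalone sketch of the same check (HasStateOps is a hypothetical name, not part of the commit):

#include <memory>
#include "openvino/core/model.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/assign.hpp"
#include "openvino/op/read_value.hpp"

// True if the graph already carries state: ReadValue/Assign (v3 or v6) are
// exactly the ops that MakeStateful inserts for the KV-cache tensors.
bool HasStateOps(const std::shared_ptr<ov::Model>& model) {
  for (const auto& op : model->get_ordered_ops()) {
    if (ov::is_type<ov::op::v3::ReadValue>(op) ||
        ov::is_type<ov::op::v6::ReadValue>(op) ||
        ov::is_type<ov::op::v3::Assign>(op) ||
        ov::is_type<ov::op::v6::Assign>(op)) {
      return true;
    }
  }
  return false;
}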

onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h

Lines changed: 2 additions & 0 deletions
@@ -62,5 +62,7 @@ void UpdateNPUConfig(ov::AnyMap& config, const KVAxesPosition& kv_pos, const KVD
 std::optional<ov::Any> PopOptionNew(ov::AnyMap& config, const std::string& option_name);
 std::optional<uint32_t> PopIntAndCast(ov::AnyMap& config, const std::string& key);

+bool IsStateful(const std::shared_ptr<ov::Model>& model);
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime
