@@ -105,7 +105,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
 
   if (hw_target.find("NPU") != std::string::npos) {
     KVDesc kv_desc;
-    kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(1024u);
+    kv_desc.max_prompt_len = PopIntAndCast(config, "MAX_PROMPT_LEN").value_or(3072u);
     kv_desc.min_response_len = PopIntAndCast(config, "MIN_RESPONSE_LEN").value_or(128u);
 
     if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) {
@@ -125,7 +125,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
   std::cout << "Stateful OV Model Compilation Complete" << std::endl;
 
-  OVExeNetwork exe(compiled_model);
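+  // hw_target and 'true' record the target device and mark this compiled model as a stateful
+  // LLM, so OVExeNetwork::CreateInferRequest() hands back a StatefulOVInferRequest for it.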
+  OVExeNetwork exe(compiled_model, hw_target, true);
   return exe;
 }
 
@@ -134,19 +134,18 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_netwo
                                   ov::AnyMap& device_config,
                                   bool enable_causallm,
                                   const std::string& name) {
-  ov::CompiledModel obj;
+  OVExeNetwork exe;
   try {
     if (enable_causallm) {
       auto mutable_model = ie_cnn_network->clone();
-      auto compiled_model = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
-      obj = compiled_model.Get();
+      exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
     } else {
-      obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      auto obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     }
 #ifndef NDEBUG
-    printDebugInfo(obj);
+    printDebugInfo(exe.Get());
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -165,7 +164,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
+    OVExeNetwork exe(obj, hw_target);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -180,7 +179,7 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  bool enable_causallm,
                                  std::string name) {
   try {
-    ov::CompiledModel obj;
+    OVExeNetwork exe;
 
     // Check if it's XML
     std::streampos originalPos = model_stream.tellg();
@@ -194,7 +193,8 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
     model_stream.seekg(originalPos);
 
     if (header != "<?xml") {
-      obj = core.import_model(model_stream, hw_target, device_config);
+      auto obj = core.import_model(model_stream, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     } else {
       // Get path to bin file
       std::string bin_file;
@@ -232,17 +232,16 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
       std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
 
       if (enable_causallm) {
-        auto compiled_model = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
-        obj = compiled_model.Get();
+        exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
       } else {
-        obj = core.compile_model(model, hw_target, device_config);
+        auto obj = core.compile_model(model, hw_target, device_config);
+        exe = OVExeNetwork(obj, hw_target);
       }
     }
 
 #ifndef NDEBUG
-    printDebugInfo(obj);
+    printDebugInfo(exe.Get());
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -330,11 +329,16 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) {
   core.set_property(device_type, {ov::num_streams(num_streams)});
 }
 
-OVInferRequest OVExeNetwork::CreateInferRequest() {
+std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   try {
     auto infReq = obj.create_infer_request();
-    OVInferRequest inf_obj(std::move(infReq));
-    return inf_obj;
+    std::shared_ptr<OVInferRequest> ovInfReq;
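+    // Stateful (causal LM) compiled models get the derived request type that patches beam_idx
+    // and caches the prompt for NPU chat mode; everything else keeps the plain OVInferRequest.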
+    if (_stateful_llm) {
+      ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), _device);
+    } else {
+      ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
+    }
+    return ovInfReq;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while creating InferRequest object: " + e.what());
   } catch (...) {
@@ -368,16 +372,6 @@ std::string OVInferRequest::GetInputTensorName(uint32_t index) {
 void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
-
-    if (name == "input_ids") {
-      // Since we can't seem to set at ORT GenAI layer right now, we just set it here
-      // as a workaround.
-      // TODO: Fix this.
-      ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
-      std::fill_n(beam_idx.data<int32_t>(), 1, 0);
-      ovInfReq.set_tensor("beam_idx", beam_idx);
-    }
-
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
@@ -423,5 +417,76 @@ void OVInferRequest::QueryStatus() {
   std::cout << "ovInfReq.query_state()"
             << " ";
 }
+
+void StatefulOVInferRequest::_pre_infer() {
+  // Since we can't seem to set at ORT GenAI layer right now, we just set it here
+  // as a workaround.
+  // TODO: Fix this.
+  ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
+  std::fill_n(beam_idx.data<int32_t>(), 1, 0);
+  ovInfReq.set_tensor("beam_idx", beam_idx);
+
+  // For NPU, we need to cache input_ids and position_ids for
+  // chat-mode support.
+  if (device.find("NPU") != std::string::npos) {
+    auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
+
+    // add input_ids to our cache
+    {
+      auto* pData = input_ids_tensor.data<int64_t>();
+      for (size_t i = 0; i < input_ids_tensor.get_size(); i++) {
+        cached_input_ids.push_back(pData[i]);
+      }
+    }
+
+    // add position_ids to our cache
+    {
+      auto position_ids = ovInfReq.get_tensor("position_ids");
+      auto* pData = position_ids.data<int64_t>();
+      for (size_t i = 0; i < position_ids.get_size(); i++) {
+        cached_position_ids.push_back(pData[i]);
+      }
+    }
+
+    // if we're about to run prefill model
+    if (input_ids_tensor.get_size() > 1) {
+      // if the input_ids size doesn't equal cached size of the input_ids
+      // then it means that we're running 2nd (or later) prompt.
+      if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
+        // set a new input_ids tensor with the content of our cached input_ids
+        {
+          auto new_shape = input_ids_tensor.get_shape();
+          new_shape[1] = cached_input_ids.size();
+          auto new_input_ids = ov::Tensor(input_ids_tensor.get_element_type(), new_shape);
+          auto* pNewInputIds = new_input_ids.data<int64_t>();
+          std::memcpy(pNewInputIds, cached_input_ids.data(), cached_input_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("input_ids", new_input_ids);
+        }
+
+        // set a new position_ids tensor with the content of our cached position_ids
+        {
+          auto position_ids_tensor = ovInfReq.get_tensor("position_ids");
+          auto new_shape = position_ids_tensor.get_shape();
+          new_shape[1] = cached_position_ids.size();
+          auto new_position_ids = ov::Tensor(position_ids_tensor.get_element_type(), new_shape);
+          auto* pNewPositionIds = new_position_ids.data<int64_t>();
+          std::memcpy(pNewPositionIds, cached_position_ids.data(), cached_position_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("position_ids", new_position_ids);
+        }
+      }
+    }
+  }
+}
+
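+// Both entry points apply the beam_idx / cached-prompt fixups before delegating to the base request.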
+void StatefulOVInferRequest::StartAsync() {
+  _pre_infer();
+  OVInferRequest::StartAsync();
+}
+
+void StatefulOVInferRequest::Infer() {
+  _pre_infer();
+  OVInferRequest::Infer();
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime