@@ -125,7 +125,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr<OVNetwork>& model,
   compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config);
   std::cout << "Stateful OV Model Compilation Complete" << std::endl;
 
-  OVExeNetwork exe(compiled_model);
+  OVExeNetwork exe(compiled_model, hw_target, true);
   return exe;
 }
 
@@ -134,19 +134,18 @@ OVExeNetwork OVCore::CompileModel(std::shared_ptr<const OVNetwork>& ie_cnn_netwo
                                   ov::AnyMap& device_config,
                                   bool enable_causallm,
                                   const std::string& name) {
-  ov::CompiledModel obj;
+  OVExeNetwork exe;
   try {
     if (enable_causallm) {
       auto mutable_model = ie_cnn_network->clone();
-      auto compiled_model = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
-      obj = compiled_model.Get();
+      exe = OVCore::Get()->StatefulCompileModel(mutable_model, hw_target, device_config);
     } else {
-      obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      auto obj = core.compile_model(ie_cnn_network, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     }
 #ifndef NDEBUG
-    printDebugInfo(obj);
+    printDebugInfo(exe.Get());
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -165,7 +164,7 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model,
 #ifndef NDEBUG
     printDebugInfo(obj);
 #endif
-    OVExeNetwork exe(obj);
+    OVExeNetwork exe(obj, hw_target);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -180,7 +179,7 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
                                  bool enable_causallm,
                                  std::string name) {
   try {
-    ov::CompiledModel obj;
+    OVExeNetwork exe;
 
     // Check if it's XML
     std::streampos originalPos = model_stream.tellg();
@@ -194,7 +193,8 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
     model_stream.seekg(originalPos);
 
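+    // A stream that does not begin with "<?xml" is treated as a precompiled blob and imported
+    // directly; otherwise the IR text is read back into an ov::Model and compiled below
+    // (through the stateful path when enable_causallm is set).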
     if (header != "<?xml") {
-      obj = core.import_model(model_stream, hw_target, device_config);
+      auto obj = core.import_model(model_stream, hw_target, device_config);
+      exe = OVExeNetwork(obj, hw_target);
     } else {
       // Get path to bin file
       std::string bin_file;
@@ -232,17 +232,16 @@ OVExeNetwork OVCore::ImportModel(std::istream& model_stream,
       std::shared_ptr<ov::Model> model = core.read_model(xml_content, weights_tensor);
 
       if (enable_causallm) {
-        auto compiled_model = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
-        obj = compiled_model.Get();
+        exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config);
       } else {
-        obj = core.compile_model(model, hw_target, device_config);
+        auto obj = core.compile_model(model, hw_target, device_config);
+        exe = OVExeNetwork(obj, hw_target);
       }
     }
 
 #ifndef NDEBUG
-    printDebugInfo(obj);
+    printDebugInfo(exe.Get());
 #endif
-    OVExeNetwork exe(obj);
     return exe;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while Loading Network for graph: " + name + e.what());
@@ -330,11 +329,16 @@ void OVCore::SetStreams(const std::string& device_type, int num_streams) {
   core.set_property(device_type, {ov::num_streams(num_streams)});
 }
 
-OVInferRequest OVExeNetwork::CreateInferRequest() {
+std::shared_ptr<OVInferRequest> OVExeNetwork::CreateInferRequest() {
   try {
     auto infReq = obj.create_infer_request();
-    OVInferRequest inf_obj(std::move(infReq));
-    return inf_obj;
+    std::shared_ptr<OVInferRequest> ovInfReq;
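+    // Stateful (causal LM) compilations get a StatefulOVInferRequest wrapper, which handles
+    // beam_idx and KV-cache bookkeeping before each inference (see StatefulOVInferRequest below).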
+    if (_stateful_llm) {
+      ovInfReq = std::make_shared<StatefulOVInferRequest>(std::move(infReq), _device);
+    } else {
+      ovInfReq = std::make_shared<OVInferRequest>(std::move(infReq));
+    }
+    return ovInfReq;
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Exception while creating InferRequest object: " + e.what());
   } catch (...) {
@@ -368,16 +372,6 @@ std::string OVInferRequest::GetInputTensorName(uint32_t index) {
 void OVInferRequest::SetTensor(const std::string& name, OVTensorPtr& blob) {
   try {
     ovInfReq.set_tensor(name, *(blob.get()));
-
-    if (name == "input_ids") {
-      // Since we can't seem to set at ORT GenAI layer right now, we just set it here
-      // as a workaround.
-      // TODO: Fix this.
-      ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
-      std::fill_n(beam_idx.data<int32_t>(), 1, 0);
-      ovInfReq.set_tensor("beam_idx", beam_idx);
-    }
-
   } catch (const Exception& e) {
     ORT_THROW(log_tag + " Cannot set Remote Blob for output: " + name + e.what());
   } catch (...) {
@@ -423,5 +417,121 @@ void OVInferRequest::QueryStatus() {
   std::cout << "ovInfReq.query_state()"
             << " ";
 }
+
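+// Runs right before every Infer()/StartAsync() on a stateful request: it pins beam_idx to 0
+// and, on NPU, accumulates input_ids / position_ids so that a later chat-turn prefill can be
+// replayed with the full conversation history.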
+void StatefulOVInferRequest::_pre_infer() {
+  // Since we can't seem to set at ORT GenAI layer right now, we just set it here
+  // as a workaround.
+  // TODO: Fix this.
+  ov::Tensor beam_idx = ov::Tensor(ov::element::i32, {1});
+  std::fill_n(beam_idx.data<int32_t>(), 1, 0);
+  ovInfReq.set_tensor("beam_idx", beam_idx);
+
+  // For NPU, we need to cache input_ids and position_ids for
+  // chat-mode support.
+  if (device.find("NPU") != std::string::npos) {
+    auto input_ids_tensor = ovInfReq.get_tensor("input_ids");
+
+    // add input_ids to our cache
+    {
+      auto* pData = input_ids_tensor.data<int64_t>();
+      for (size_t i = 0; i < input_ids_tensor.get_size(); i++) {
+        cached_input_ids.push_back(pData[i]);
+      }
+    }
+
+    // add position_ids to our cache
+    {
+      auto position_ids = ovInfReq.get_tensor("position_ids");
+      auto* pData = position_ids.data<int64_t>();
+      for (size_t i = 0; i < position_ids.get_size(); i++) {
+        cached_position_ids.push_back(pData[i]);
+      }
+    }
+
+    // if we're about to run the prefill model
+    if (input_ids_tensor.get_size() > 1) {
+      // if the input_ids size doesn't equal the cached size of the input_ids,
+      // then it means that we're running the 2nd (or later) prompt.
+      if (input_ids_tensor.get_shape()[1] != cached_input_ids.size()) {
+        // set a new input_ids tensor with the content of our cached input_ids
+        {
+          auto new_shape = input_ids_tensor.get_shape();
+          new_shape[1] = cached_input_ids.size();
+          auto new_input_ids = ov::Tensor(input_ids_tensor.get_element_type(), new_shape);
+          auto* pNewInputIds = new_input_ids.data<int64_t>();
+          std::memcpy(pNewInputIds, cached_input_ids.data(), cached_input_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("input_ids", new_input_ids);
+        }
+
+        // set a new position_ids tensor with the content of our cached position_ids
+        {
+          auto position_ids_tensor = ovInfReq.get_tensor("position_ids");
+          auto new_shape = position_ids_tensor.get_shape();
+          new_shape[1] = cached_position_ids.size();
+          auto new_position_ids = ov::Tensor(position_ids_tensor.get_element_type(), new_shape);
+          auto* pNewPositionIds = new_position_ids.data<int64_t>();
+          std::memcpy(pNewPositionIds, cached_position_ids.data(), cached_position_ids.size() * sizeof(int64_t));
+          ovInfReq.set_tensor("position_ids", new_position_ids);
+        }
+      }
+    }
+  }
+}
+
+void StatefulOVInferRequest::StartAsync() {
+  _pre_infer();
+  OVInferRequest::StartAsync();
+}
+
+void StatefulOVInferRequest::Infer() {
+  _pre_infer();
+  OVInferRequest::Infer();
+}
+
+void StatefulOVInferRequest::RewindKVCache(size_t index) {
+  if (device == "NPU") {
+    std::cout << "RewindKVCache on NPU: Trimming cached input_ids / position_ids to length "
+              << index << std::endl;
+    if (cached_input_ids.size() > index) {
+      cached_input_ids.resize(index);
+    }
+
+    if (cached_position_ids.size() > index) {
+      cached_position_ids.resize(index);
+    }
+  } else {
+    std::cout << "OVInferRequest::RewindKVCache: Trimming internal states to length = "
+              << index << std::endl;
+    if (index == 0) {
+      // in this case, since we're trimming *all* of the KVCache, just reset the state.
+      ovInfReq.reset_state();
+    } else {
+      // retrieve kvcache states, and trim...
+      // Most of this code was grabbed from here:
+      // https://github.com/openvinotoolkit/openvino.genai/blob/releases/2025/1/src/cpp/src/utils.cpp#L329
+      auto states = ovInfReq.query_state();
+      for (auto& state : states) {
+        ov::Tensor old_tensor = state.get_state();
+        // [BATCH_SIZE, num_kv_heads, seq_len, head_size]
+        auto shape = old_tensor.get_shape();
+
+        if (shape[2] > index) {
+          shape[2] = index;
+
+          ov::Coordinate new_shape_begin{0, 0, 0, 0};
+          ov::Coordinate new_shape_end{shape};
+
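+          // ov::Tensor(old_tensor, begin, end) yields a zero-copy region-of-interest view over
+          // the first 'index' positions of the sequence axis; it is copied into a fresh tensor
+          // so the restored state owns contiguous memory of the trimmed size.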
+          auto trimmed_tensor = ov::Tensor(old_tensor, new_shape_begin, new_shape_end);
+
+          ov::Tensor new_tensor(old_tensor.get_element_type(), shape);
+          trimmed_tensor.copy_to(new_tensor);
+
+          state.set_state(new_tensor);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace openvino_ep
 }  // namespace onnxruntime