@@ -29,7 +29,6 @@ std::pair<ov::Tensor, std::optional<int64_t>> InputsEmbedder::IInputsEmbedder::g
 
 void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_message) {
     m_is_chat_conversation = true;
-    m_kv_history_trim_manager.reset();
     if (!m_kv_cache_state.get_state().empty()) {
         m_history.clear();
         m_kv_cache_state.reset_state();
@@ -40,17 +39,26 @@ void InputsEmbedder::IInputsEmbedder::start_chat(const std::string& system_messa
     m_history = {{{"role", "system"}, {"content", system_message}}};
 }
 
-void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    // Tail of chat template is missing in KV cache.
-    // Find the tail to concatenate it with the next input prompt.
-    m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
-    m_kv_history_trim_manager.reset();
+void InputsEmbedder::IInputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    m_kv_cache_state.num_tokens_to_trim = 0;
+    if (generation_finish_status == ov::genai::GenerationStatus::CANCEL) {
+        // If the chat generation was cancelled by the user, roll history back to its previous state
+        m_history.pop_back();
+
+        std::vector<int64_t>& state = m_kv_cache_state.get_state();
+
+        m_kv_cache_state.num_tokens_to_trim = state.size() - m_prev_hist_length;
+        state.resize(m_prev_hist_length);
+        m_kv_cache_state.reset_mem_state = state.empty();
+    } else {
+        // Tail of chat template is missing in KV cache.
+        // Find the tail to concatenate it with the next input prompt.
+        m_history.push_back({{"role", "assistant"}, {"content", decoded_results}});
+    }
 }
 
 void InputsEmbedder::IInputsEmbedder::finish_chat() {
     m_is_chat_conversation = false;
-    m_kv_history_trim_manager.reset();
-
     m_history.clear();
     m_kv_cache_state.reset_state();
 }
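
Note: the CANCEL branch above rolls the KV-cache bookkeeping back to the token count recorded before the cancelled turn. A minimal sketch of that arithmetic, using a hypothetical CacheState struct in place of ov::genai::utils::KVCacheState (names are illustrative, not from this diff):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ov::genai::utils::KVCacheState.
struct CacheState {
    std::vector<int64_t> tokens;    // tokens currently covered by the KV cache
    size_t num_tokens_to_trim = 0;  // tail tokens the pipeline must drop from the cache
    bool reset_mem_state = false;   // true when the whole cache must be rebuilt
};

// Roll back to the token count recorded before the cancelled turn.
void rollback(CacheState& state, size_t prev_hist_length) {
    state.num_tokens_to_trim = state.tokens.size() - prev_hist_length;
    state.tokens.resize(prev_hist_length);
    state.reset_mem_state = state.tokens.empty();
}
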
@@ -123,7 +131,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::apply_chat_template_tokenize(const s
 ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new_chat_tokens) {
     ov::Tensor encoded_inputs;
     if (m_is_chat_conversation) {
-        ov::genai::align_kv_cache_and_history(m_kv_history_trim_manager, new_chat_tokens, m_kv_cache_state);
+        ov::genai::align_kv_cache_and_history(new_chat_tokens, m_kv_cache_state);
         encoded_inputs = get_chat_encoded_input(new_chat_tokens, m_kv_cache_state).input_ids;
     } else {
         encoded_inputs = new_chat_tokens;
@@ -135,6 +143,7 @@ ov::Tensor InputsEmbedder::IInputsEmbedder::update_history(const ov::Tensor& new
 ov::Tensor InputsEmbedder::IInputsEmbedder::get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics) {
     const auto new_chat_tokens = apply_chat_template_tokenize(prompt, metrics);
     auto new_input_ids = update_history(new_chat_tokens);
+    m_prev_hist_length = m_kv_cache_state.get_state().size();
     m_kv_cache_state.add_inputs(new_input_ids);
 
     return new_input_ids;
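
Note: the ordering here matters. m_prev_hist_length is captured after update_history() has aligned the state but before add_inputs() appends the new prompt tokens, so a later CANCEL trims exactly the tokens of the cancelled turn. A self-contained sketch of that invariant (the token values are made up for illustration):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    std::vector<int64_t> state = {1, 2, 3};        // tokens from completed turns
    const size_t prev_hist_length = state.size();  // snapshot taken before add_inputs

    const std::vector<int64_t> new_input_ids = {4, 5};  // current turn's prompt tokens
    state.insert(state.end(), new_input_ids.begin(), new_input_ids.end());

    // A CANCEL rollback trims exactly the tokens added by the current turn.
    assert(state.size() - prev_hist_length == new_input_ids.size());
    state.resize(prev_hist_length);
    assert((state == std::vector<int64_t>{1, 2, 3}));
    return 0;
}
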
@@ -225,14 +234,10 @@ EmbeddingsModel InputsEmbedder::get_embedding_model() const {
     return m_impl->get_embedding_model();
 }
 
-KVCacheState& InputsEmbedder::get_kv_cache_state() {
+ov::genai::utils::KVCacheState& InputsEmbedder::get_kv_cache_state() {
     return m_impl->get_kv_cache_state();
 }
 
-size_t InputsEmbedder::get_num_tokens_to_remove_from_hist() const {
-    return m_impl->get_num_tokens_to_remove_from_hist();
-}
-
 Tokenizer InputsEmbedder::get_tokenizer() const {
     return m_impl->get_tokenizer();
 }
@@ -241,8 +246,8 @@ void InputsEmbedder::start_chat(const std::string& system_message) {
     return m_impl->start_chat(system_message);
 }
 
-void InputsEmbedder::update_chat_history(const std::string& decoded_results) {
-    return m_impl->update_chat_history(decoded_results);
+void InputsEmbedder::update_chat_history(const std::string& decoded_results, const ov::genai::GenerationStatus generation_finish_status) {
+    return m_impl->update_chat_history(decoded_results, generation_finish_status);
 }
 
 void InputsEmbedder::set_apply_chat_template_status(bool apply_chat_template) {
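
Note: callers now forward the generation outcome to the embedder. A hedged caller-side sketch; the finish_turn helper and the header path are assumptions, only the two-argument update_chat_history(decoded_results, status) comes from this diff:

#include <string>
#include "visual_language/inputs_embedder.hpp"  // internal header; path assumed

// Hypothetical glue: forward the turn's outcome to the embedder.
// On CANCEL the embedder pops the last history entry and schedules a
// KV-cache trim; otherwise it records the assistant reply in the history.
void finish_turn(ov::genai::InputsEmbedder& embedder,
                 const std::string& decoded_results,
                 ov::genai::GenerationStatus status) {
    embedder.update_chat_history(decoded_results, status);
}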