@@ -800,7 +800,7 @@ struct server_context {
800800                int  slot_prompt_len = slot_prompt.size ();
801801
802802                //  length of the Longest Common Prefix between the current slot's prompt and the input prompt
803-                 int  lcp_len = common_part (slot_prompt, prompt);
803+                 int  lcp_len = longest_common_prefix (slot_prompt, prompt);
804804
805805                //  fraction of the common substring length compared to the current slot's prompt length
806806                similarity = static_cast <float >(lcp_len) / slot_prompt_len;
@@ -2012,7 +2012,7 @@ struct server_context {
20122012                            }
20132013                            slot.params .n_keep  = std::min (slot.n_ctx  - 4 , slot.params .n_keep );
20142014
2015-                             //  if input prompt is too big, truncate it (if group attention self-extend is disabled) 
2015+                             //  if input prompt is too big, truncate it
20162016                            if  (slot.n_prompt_tokens  >= slot.n_ctx ) {
20172017                                const  int  n_left = slot.n_ctx  - slot.params .n_keep ;
20182018
@@ -2042,12 +2042,74 @@ struct server_context {
20422042
20432043                            if  (slot.params .cache_prompt ) {
20442044                                //  reuse any previously computed tokens that are common with the new prompt
2045-                                 slot.n_past  = common_part (slot.cache_tokens , prompt_tokens);
2045+                                 slot.n_past  = longest_common_prefix (slot.cache_tokens , prompt_tokens);
20462046
20472047                                //  push the prompt into the sampling context (do not apply grammar)
20482048                                for  (int  i = 0 ; i < slot.n_past ; ++i) {
20492049                                    common_sampler_accept (slot.smpl , slot.cache_tokens [i], false );
20502050                                }
2051+ 
2052+                                 //  reuse chunks from the cached prompt by shifting their KV cache in the new position
2053+                                 if  (params.n_cache_reuse  > 0 ) {
2054+                                     size_t  head_c = slot.n_past ; //  cache
2055+                                     size_t  head_p = slot.n_past ; //  current prompt
2056+ 
2057+                                     SLT_DBG (slot, " trying to reuse chunks with size > %d, slot.n_past = %d\n "  , params.n_cache_reuse , slot.n_past );
2058+ 
2059+                                     while  (head_c < slot.cache_tokens .size () &&
2060+                                            head_p < prompt_tokens.size ()) {
2061+                                         if  (llama_token_is_control (model, slot.cache_tokens [head_c])) {
2062+                                             break ;
2063+                                         }
2064+ 
2065+                                         if  (llama_token_is_control (model, prompt_tokens[head_p])) {
2066+                                             break ;
2067+                                         }
2068+ 
2069+                                         size_t  n_match = 0 ;
2070+ 
2071+                                         while  (head_c + n_match < slot.cache_tokens .size () &&
2072+                                                head_p + n_match < prompt_tokens.size ()     &&
2073+                                                slot.cache_tokens [head_c + n_match] == prompt_tokens[head_p + n_match]) {
2074+                                             if  (llama_token_is_control (model, slot.cache_tokens [head_c + n_match])) {
2075+                                                 break ;
2076+                                             }
2077+ 
2078+                                             if  (llama_token_is_control (model, prompt_tokens[head_p + n_match])) {
2079+                                                 break ;
2080+                                             }
2081+ 
2082+                                             n_match++;
2083+                                         }
2084+ 
2085+                                         if  (n_match >= (size_t ) params.n_cache_reuse ) {
2086+                                             SLT_DBG (slot, " reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n "  , n_match, head_c, head_c + n_match, head_p, head_p + n_match);
2087+                                             // for (size_t i = head_p; i < head_p + n_match; i++) {
2088+                                             //     SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2089+                                             // }
2090+ 
2091+                                             const  int64_t  kv_shift = (int64_t ) head_p - (int64_t ) head_c;
2092+ 
2093+                                             llama_kv_cache_seq_rm  (ctx, slot.id  + 1 , head_p, head_c);
2094+                                             llama_kv_cache_seq_add (ctx, slot.id  + 1 , head_c, -1 ,     kv_shift);
2095+ 
2096+                                             for  (size_t  i = 0 ; i < n_match; i++) {
2097+                                                 slot.cache_tokens [head_p + i] = slot.cache_tokens [head_c + i];
2098+ 
2099+                                                 common_sampler_accept (slot.smpl , slot.cache_tokens [head_p + i], false );
2100+ 
2101+                                                 slot.n_past ++;
2102+                                             }
2103+ 
2104+                                             head_c += n_match;
2105+                                             head_p += n_match;
2106+                                         } else  {
2107+                                             head_c += 1 ;
2108+                                         }
2109+                                     }
2110+ 
2111+                                     SLT_DBG (slot, " after context reuse, new slot.n_past = %d\n "  , slot.n_past );
2112+                                 }
20512113                            }
20522114                        }
20532115
@@ -3257,6 +3319,7 @@ int main(int argc, char ** argv) {
32573319
32583320    ctx_server.queue_tasks .on_new_task (std::bind (
32593321                &server_context::process_single_task, &ctx_server, std::placeholders::_1));
3322+ 
32603323    ctx_server.queue_tasks .on_update_slots (std::bind (
32613324                &server_context::update_slots, &ctx_server));
32623325
0 commit comments