@@ -800,7 +800,7 @@ struct server_context {
                 int slot_prompt_len = slot_prompt.size();
 
                 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
+                int lcp_len = longest_common_prefix(slot_prompt, prompt);
 
                 // fraction of the common substring length compared to the current slot's prompt length
                 similarity = static_cast<float>(lcp_len) / slot_prompt_len;
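
For orientation, the similarity above is just the shared leading portion of the two prompts divided by the cached prompt's length. The sketch below is an illustrative stand-in for the renamed helper, not the actual implementation in the server sources: the names `longest_common_prefix_sketch` and `prompt_similarity_sketch` and the use of plain integer sequences are assumptions made only for this example.

// Illustrative sketch only: assumed semantics of a longest-common-prefix helper,
// counting how many leading elements two sequences share.
#include <cstddef>
#include <cstdio>
#include <vector>

static size_t longest_common_prefix_sketch(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

// similarity as computed in the hunk above: fraction of the cached (slot) prompt
// covered by its shared prefix with the incoming prompt
static float prompt_similarity_sketch(const std::vector<int> & slot_prompt, const std::vector<int> & prompt) {
    if (slot_prompt.empty()) {
        return 0.0f;
    }
    return static_cast<float>(longest_common_prefix_sketch(slot_prompt, prompt)) / slot_prompt.size();
}

int main() {
    const std::vector<int> cached   = {10, 20, 30, 40};
    const std::vector<int> incoming = {10, 20, 99};
    // shared prefix {10, 20} -> similarity = 2 / 4 = 0.50
    printf("similarity: %.2f\n", prompt_similarity_sketch(cached, incoming));
    return 0;
}
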
@@ -2012,7 +2012,7 @@ struct server_context {
                 }
                 slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+                // if input prompt is too big, truncate it
                 if (slot.n_prompt_tokens >= slot.n_ctx) {
                     const int n_left = slot.n_ctx - slot.params.n_keep;
@@ -2042,12 +2042,74 @@ struct server_context {
 
                 if (slot.params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                    slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
                         common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                     }
+
+                    // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                    if (params.n_cache_reuse > 0) {
+                        size_t head_c = slot.n_past; // cache
+                        size_t head_p = slot.n_past; // current prompt
+
+                        SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+                        while (head_c < slot.cache_tokens.size() &&
+                               head_p < prompt_tokens.size()) {
+                            if (llama_token_is_control(model, slot.cache_tokens[head_c])) {
+                                break;
+                            }
+
+                            if (llama_token_is_control(model, prompt_tokens[head_p])) {
+                                break;
+                            }
+
+                            size_t n_match = 0;
+
+                            while (head_c + n_match < slot.cache_tokens.size() &&
+                                   head_p + n_match < prompt_tokens.size() &&
+                                   slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match])) {
+                                    break;
+                                }
+
+                                if (llama_token_is_control(model, prompt_tokens[head_p + n_match])) {
+                                    break;
+                                }
+
+                                n_match++;
+                            }
+
+                            if (n_match >= (size_t) params.n_cache_reuse) {
+                                SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                // for (size_t i = head_p; i < head_p + n_match; i++) {
+                                //     SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                // }
+
+                                const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                                llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                                for (size_t i = 0; i < n_match; i++) {
+                                    slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                    common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                                    slot.n_past++;
+                                }
+
+                                head_c += n_match;
+                                head_p += n_match;
+                            } else {
+                                head_c += 1;
+                            }
+                        }
+
+                        SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                    }
                 }
             }
 
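
The hunk above is the core of the cache-reuse path: after the common prefix, it scans the rest of the cached tokens for chunks of at least `n_cache_reuse` matching tokens that line up with the new prompt at its current head, and shifts their KV-cache cells to the new positions with `llama_kv_cache_seq_rm`/`llama_kv_cache_seq_add` instead of recomputing them. Note that `head_p` only advances when a chunk is reused, so the mechanism pays off when material is dropped from the middle of a previously cached context. The sketch below is a rough standalone illustration of that two-pointer scan on plain token vectors, not the server code itself: it only counts reusable tokens, omits the control-token checks and the real KV-cache shifting, and all names (`count_reusable_tokens`, the `token` alias, the sample values) are made up for the example.

// Standalone sketch of the chunk-reuse scan (illustrative names and types).
// Control-token checks and the actual KV-cache shifting are intentionally omitted.
#include <cstddef>
#include <cstdio>
#include <vector>

using token = int; // stand-in for llama_token

static size_t count_reusable_tokens(const std::vector<token> & cache,
                                    const std::vector<token> & prompt,
                                    size_t n_past, size_t n_cache_reuse) {
    size_t head_c = n_past; // head into the cached tokens
    size_t head_p = n_past; // head into the new prompt
    size_t reused = 0;

    while (head_c < cache.size() && head_p < prompt.size()) {
        // length of the identical run starting at the two heads
        size_t n_match = 0;
        while (head_c + n_match < cache.size() &&
               head_p + n_match < prompt.size() &&
               cache[head_c + n_match] == prompt[head_p + n_match]) {
            n_match++;
        }

        if (n_match >= n_cache_reuse) {
            // the server shifts the KV cells from [head_c, head_c + n_match)
            // to [head_p, head_p + n_match) at this point; here we only count them
            reused += n_match;
            head_c += n_match;
            head_p += n_match;
        } else {
            // no long-enough chunk at this cache position, try the next one
            head_c += 1;
        }
    }

    return reused;
}

int main() {
    // old context {1 2 3 4 5 6 7 8}; the new prompt drops the middle chunk {3 4 5}
    const std::vector<token> cache  = {1, 2, 3, 4, 5, 6, 7, 8};
    const std::vector<token> prompt = {1, 2, 6, 7, 8};

    // {1, 2} is the common prefix (n_past = 2); the chunk {6, 7, 8} is found at
    // cache positions [5, 8) and would be shifted to prompt positions [2, 5)
    printf("reusable tokens: %zu\n", count_reusable_tokens(cache, prompt, /*n_past=*/2, /*n_cache_reuse=*/2));
    return 0;
}
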
@@ -3257,6 +3319,7 @@ int main(int argc, char ** argv) {
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
     ctx_server.queue_tasks.on_update_slots(std::bind(
         &server_context::update_slots, &ctx_server));
 