@@ -800,7 +800,7 @@ struct server_context {
                 int slot_prompt_len = slot_prompt.size();
 
                 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
+                int lcp_len = longest_common_prefix(slot_prompt, prompt);
 
                 // fraction of the common substring length compared to the current slot's prompt length
                 similarity = static_cast<float>(lcp_len) / slot_prompt_len;
@@ -2042,12 +2042,61 @@ struct server_context {
 
                 if (slot.params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                    slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
                         common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                     }
+
+                    // EXPERIMENTAL: reuse chunks from the cached prompt by shifting them in the new position
+                    if (1) {
+                        size_t head_c = slot.n_past; // cache
+                        size_t head_p = slot.n_past; // current prompt
+
+                        while (head_c < slot.cache_tokens.size() &&
+                               head_p < prompt_tokens.size() &&
+                               !llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                               !llama_token_is_control(model, prompt_tokens[head_p])) {
+
+                            size_t n_match = 0;
+                            while (head_c + n_match < slot.cache_tokens.size() &&
+                                   head_p + n_match < prompt_tokens.size() &&
+                                   !llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                                   !llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                                    slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                n_match++;
+                            }
+
+                            if (n_match > 32) {
+                                // shift the KV chunk [head_c, head_c + n_match) -> [head_p, head_p + n_match)
+                                SLT_DBG(slot, "shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", head_c, head_c + n_match, head_p, head_p + n_match);
+                                //for (size_t i = head_p; i < head_p + n_match; i++) {
+                                //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                //}
+
+                                const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                                llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                                for (size_t i = 0; i < n_match; i++) {
+                                    slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                    common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                                    slot.n_past++;
+                                }
+
+                                head_c += n_match;
+                                head_p += n_match;
+                            } else {
+                                head_c += 1;
+                            }
+                        }
+
+                        SLT_DBG(slot, "new slot.n_past = %d, cache_tokens.size() = %zu\n", slot.n_past, slot.cache_tokens.size());
+                    }
                 }
             }
 
@@ -3257,6 +3306,7 @@ int main(int argc, char ** argv) {
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
     ctx_server.queue_tasks.on_update_slots(std::bind(
         &server_context::update_slots, &ctx_server));
 