@@ -139,6 +139,7 @@ struct slot_params {
 
     json input_prefix;
     json input_suffix;
+    json extra_context;
 };
 
 struct server_slot {
@@ -170,6 +171,7 @@ struct server_slot {
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;
 
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
@@ -800,7 +802,7 @@ struct server_context {
                 int slot_prompt_len = slot_prompt.size();
 
                 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
+                int lcp_len = longest_common_prefix(slot_prompt, prompt);
 
                 // fraction of the common substring length compared to the current slot's prompt length
                 similarity = static_cast<float>(lcp_len) / slot_prompt_len;
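
For context, the renamed helper computes the length of the shared leading run of tokens in two sequences. A minimal sketch of the idea (the real implementation lives elsewhere in the tree; this standalone signature is an assumption for illustration):

    // Sketch only: number of leading tokens two token sequences share.
    // Assumed signature; the actual helper is defined elsewhere in the sources.
    static size_t longest_common_prefix(const std::vector<llama_token> & a,
                                        const std::vector<llama_token> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

The rename is a clarity fix: the old name common_part could be read as any common substring, while the similarity heuristic above measures a prefix.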
@@ -908,8 +910,26 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk["text"].is_string()) {
+                send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            // filename is optional
+            if (chunk.contains("filename") && !chunk["filename"].is_string()) {
+                send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
+        }
 
         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
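
To make the accepted format concrete, here is a hypothetical request payload that would pass the validation above, built with nlohmann::json (the same json type the server uses); all file names and contents are invented:

    // Hypothetical /infill request body exercising extra_context.
    json data = {
        {"input_prefix", "def add(a, b):\n    "},
        {"input_suffix", "\n\nprint(add(1, 2))\n"},
        {"extra_context", json::array({
            {{"text", "def mul(a, b):\n    return a * b\n"}, {"filename", "math_utils.py"}},
            {{"text", "MAX_RETRIES = 3\n"}}  // "filename" is optional; the prompt build step below defaults it to "tmp"
        })}
    };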
@@ -1938,13 +1958,66 @@ struct server_context {
                 } break;
             case SERVER_TASK_CMPL_TYPE_INFILL:
                 {
+                    // use FIM repo-level pattern:
+                    // ref: https://arxiv.org/pdf/2409.12186
+                    //
+                    // [FIM_REP]myproject
+                    // [FIM_SEP]filename0
+                    // extra chunk 0
+                    // [FIM_SEP]filename1
+                    // extra chunk 1
+                    // ...
+                    // [FIM_SEP]filename
+                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                    //
                     auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                     auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
 
-                    // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                    slot.extra_tokens.clear();
+                    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+                        static const auto k_fim_repo = tokenize("myproject\n", false, false);
+
+                        slot.extra_tokens.push_back(llama_token_fim_rep(model));
+                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+                    }
+
+                    for (const auto & chunk : slot.params.extra_context) {
+                        // { "text": string, "filename": string }
+                        const std::string text     = chunk.value("text",     "");
+                        const std::string filename = chunk.value("filename", "tmp");
+
+                        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                            const auto k_fim_file = tokenize(filename + "\n", false, false);
+
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                        } else {
+                            // chunk separator in binary form to avoid confusing the AI
+                            static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                            static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+                        }
+
+                        const auto chunk_tokens = tokenize(text, false, false);
+                        slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+                    }
+
+                    if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                        // TODO: current filename
+                        static const auto k_fim_file = tokenize("filename\n", false, false);
+
+                        slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                    }
+
+                    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                     const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
 
+                    // fill the rest of the context with extra chunks
+                    const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
+
                     prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                     suffix_tokens.resize(n_suffix_take);
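
A worked example of the budget split, with assumed values rather than anything taken from the patch: with n_batch = 2048, n_ctx = 8192 and n_predict = 1024, the suffix is capped at n_batch/4 = 512 tokens and the prefix at (n_batch - 3) - 512 = 1533 tokens, the 3 reserved positions presumably covering the FIM prefix/suffix/middle markers. The extra chunks may then occupy up to n_ctx - n_batch - 2*n_predict = 8192 - 2048 - 2048 = 4096 positions, and because the later insertion takes tokens from the end of slot.extra_tokens, it is the most recently added chunks that survive when the budget is exceeded.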
@@ -1958,6 +2031,11 @@ struct server_context {
                         embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                     }
 
+                    SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                    // put the extra context before the FIM prefix
+                    embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                     embd_inp.push_back(llama_token_fim_mid(model));
 
@@ -2016,7 +2094,7 @@ struct server_context {
             }
             slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-            // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+            // if input prompt is too big, truncate it
             if (slot.n_prompt_tokens >= slot.n_ctx) {
                 const int n_left = slot.n_ctx - slot.params.n_keep;
 
@@ -2046,12 +2124,82 @@ struct server_context {
 
                 if (slot.params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                    slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
                         common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                     }
+
+                    // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                    if (params.n_cache_reuse > 0) {
+                        size_t head_c = slot.n_past; // cache
+                        size_t head_p = slot.n_past; // current prompt
+
+                        SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+                        while (head_c < slot.cache_tokens.size() &&
+                               head_p < prompt_tokens.size()) {
+                            if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                                slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
+                                slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
+                                break;
+                            }
+
+                            if (llama_token_is_control(model, prompt_tokens[head_p]) &&
+                                prompt_tokens[head_p] != llama_token_fim_rep(model) &&
+                                prompt_tokens[head_p] != llama_token_fim_sep(model)) {
+                                break;
+                            }
+
+                            size_t n_match = 0;
+
+                            while (head_c + n_match < slot.cache_tokens.size() &&
+                                   head_p + n_match < prompt_tokens.size() &&
+                                   slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                                    slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
+                                    slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
+                                    break;
+                                }
+
+                                if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                                    prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
+                                    prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
+                                    break;
+                                }
+
+                                n_match++;
+                            }
+
+                            if (n_match >= (size_t) params.n_cache_reuse) {
+                                SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                //for (size_t i = head_p; i < head_p + n_match; i++) {
+                                //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                //}
+
+                                const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                                llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                                for (size_t i = 0; i < n_match; i++) {
+                                    slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                    common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                                    slot.n_past++;
+                                }
+
+                                head_c += n_match;
+                                head_p += n_match;
+                            } else {
+                                head_c += 1;
+                            }
+                        }
+
+                        SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                    }
                 }
             }
 
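
To trace what the two KV-cache calls do, take an assumed concrete case: head_p = 80, head_c = 100, and a matched run of n_match = 50 tokens, giving kv_shift = 80 - 100 = -20. llama_kv_cache_seq_rm drops the stale cells in [80, 100), and llama_kv_cache_seq_add then shifts every cell from position 100 onward by -20, which moves the matched chunk from [100, 150) to [80, 130). The matched tokens thus keep their already-computed KV entries at their new positions and are never re-evaluated, while the control-token checks stop the scan at any special token other than FIM_REP and FIM_SEP, so reuse cannot cross other FIM structure boundaries. The threshold comes from params.n_cache_reuse, which upstream llama.cpp exposes as the server's --cache-reuse flag.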
@@ -3261,6 +3409,7 @@ int main(int argc, char ** argv) {
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
     ctx_server.queue_tasks.on_update_slots(std::bind(
         &server_context::update_slots, &ctx_server));
 