@@ -139,6 +139,7 @@ struct slot_params {
 
     json input_prefix;
     json input_suffix;
+    json extra_context;
 };
 
 struct server_slot {
@@ -170,6 +171,7 @@ struct server_slot {
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;
 
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
@@ -800,7 +802,7 @@ struct server_context {
                 int slot_prompt_len = slot_prompt.size();
 
                 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
+                int lcp_len = longest_common_prefix(slot_prompt, prompt);
 
                 // fraction of the common substring length compared to the current slot's prompt length
                 similarity = static_cast<float>(lcp_len) / slot_prompt_len;
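
Note: the diff only shows the rename from common_part to longest_common_prefix; the helper's body lives elsewhere in the tree. A minimal sketch of the assumed behavior (not the server's actual implementation), usable for both std::string and std::vector<llama_token>:

// sketch: number of leading elements shared by two sequences
template <typename T>
static size_t longest_common_prefix(const T & a, const T & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}
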
@@ -906,8 +908,26 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk["text"].is_string()) {
+                send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            // filename is optional
+            if (chunk.contains("filename") && !chunk["filename"].is_string()) {
+                send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
+        }
 
         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
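
For illustration, a client-side request body carrying extra_context chunks could be built like this. This is a hedged sketch using nlohmann::json: only the chunk fields ("text" required, "filename" optional) come from the validation above; the prefix/suffix contents and file names are invented for the example.

#include <nlohmann/json.hpp>

#include <iostream>

int main() {
    using json = nlohmann::ordered_json;

    // hypothetical infill request body with two extra_context chunks
    const json body = {
        {"input_prefix", "int add(int a, int b) {\n    return "},
        {"input_suffix", ";\n}\n"},
        {"extra_context", json::array({
            {{"filename", "src/types.h"}, {"text", "typedef int i32;\n"}},
            {{"filename", "src/utils.h"}, {"text", "i32 clamp(i32 x, i32 lo, i32 hi);\n"}},
        })},
    };

    std::cout << body.dump(2) << std::endl;

    return 0;
}
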
@@ -1934,13 +1954,66 @@ struct server_context {
                     } break;
                 case SERVER_TASK_CMPL_TYPE_INFILL:
                     {
+                        // use FIM repo-level pattern:
+                        // ref: https://arxiv.org/pdf/2409.12186
+                        //
+                        // [FIM_REP]myproject
+                        // [FIM_SEP]filename0
+                        // extra chunk 0
+                        // [FIM_SEP]filename1
+                        // extra chunk 1
+                        // ...
+                        // [FIM_SEP]filename
+                        // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                        //
                         auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                         auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
 
-                        // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                        const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                        slot.extra_tokens.clear();
+                        if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+                            static const auto k_fim_repo = tokenize("myproject\n", false, false);
+
+                            slot.extra_tokens.push_back(llama_token_fim_rep(model));
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+                        }
+
+                        for (const auto & chunk : slot.params.extra_context) {
+                            // { "text": string, "filename": string }
+                            const std::string text     = chunk.value("text", "");
+                            const std::string filename = chunk.value("filename", "tmp");
+
+                            if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                                const auto k_fim_file = tokenize(filename + "\n", false, false);
+
+                                slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                                slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                            } else {
+                                // chunk separator in binary form to avoid confusing the AI
+                                static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                                static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                                slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+                            }
+
+                            const auto chunk_tokens = tokenize(text, false, false);
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+                        }
+
+                        if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                            // TODO: current filename
+                            static const auto k_fim_file = tokenize("filename\n", false, false);
+
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                            slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                        }
+
+                        // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                        const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                         const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
 
+                        // fill the rest of the context with extra chunks
+                        const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
+
                         prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                         suffix_tokens.resize(n_suffix_take);
 
@@ -1954,6 +2027,11 @@ struct server_context {
                             embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                         }
 
+                        SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                        // put the extra context before the FIM prefix
+                        embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                         embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                         embd_inp.push_back(llama_token_fim_mid(model));
 
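
A rough worked example of the budget above, with purely illustrative numbers (none of them come from this diff): with n_batch = 2048, n_ctx = 8192 and n_predict = 128, the suffix is capped at 512 tokens, the prefix at 1533, and up to 5888 tokens of extra context fit in front of the FIM prefix.

#include <algorithm>
#include <cstdio>

int main() {
    // illustrative values only
    const int n_batch   = 2048;
    const int n_ctx     = 8192;
    const int n_predict = 128;

    const int n_suffix_tokens = 900;   // pretend tokenized suffix length
    const int n_prefix_tokens = 3000;  // pretend tokenized prefix length
    const int n_extra_tokens  = 10000; // pretend repo-level context length

    const int n_suffix_take = std::min(n_suffix_tokens, n_batch/4);                                 // 512
    const int n_prefix_take = std::min(n_prefix_tokens, (n_batch - 3) - n_suffix_take);             // 1533
    const int n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra_tokens); // 5888

    printf("take: suffix = %d, prefix = %d, extra = %d\n", n_suffix_take, n_prefix_take, n_extra_take);

    return 0;
}
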
@@ -2012,7 +2090,7 @@ struct server_context {
                 }
                 slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
 
-                // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+                // if input prompt is too big, truncate it
                 if (slot.n_prompt_tokens >= slot.n_ctx) {
                     const int n_left = slot.n_ctx - slot.params.n_keep;
 
@@ -2042,12 +2120,82 @@ struct server_context {
 
                 if (slot.params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                    slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
 
                     // push the prompt into the sampling context (do not apply grammar)
                     for (int i = 0; i < slot.n_past; ++i) {
                         common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                     }
+
+                    // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                    if (params.n_cache_reuse > 0) {
+                        size_t head_c = slot.n_past; // cache
+                        size_t head_p = slot.n_past; // current prompt
+
+                        SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+                        while (head_c < slot.cache_tokens.size() &&
+                               head_p < prompt_tokens.size()) {
+                            if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                                slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
+                                slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
+                                break;
+                            }
+
+                            if (llama_token_is_control(model, prompt_tokens[head_p]) &&
+                                prompt_tokens[head_p] != llama_token_fim_rep(model) &&
+                                prompt_tokens[head_p] != llama_token_fim_sep(model)) {
+                                break;
+                            }
+
+                            size_t n_match = 0;
+
+                            while (head_c + n_match < slot.cache_tokens.size() &&
+                                   head_p + n_match < prompt_tokens.size() &&
+                                   slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                                if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                                    slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
+                                    slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
+                                    break;
+                                }
+
+                                if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                                    prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
+                                    prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
+                                    break;
+                                }
+
+                                n_match++;
+                            }
+
+                            if (n_match >= (size_t) params.n_cache_reuse) {
+                                SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                //for (size_t i = head_p; i < head_p + n_match; i++) {
+                                //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                                //}
+
+                                const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                                llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                                llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                                for (size_t i = 0; i < n_match; i++) {
+                                    slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                    common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                                    slot.n_past++;
+                                }
+
+                                head_c += n_match;
+                                head_p += n_match;
+                            } else {
+                                head_c += 1;
+                            }
+                        }
+
+                        SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                    }
                 }
             }
 
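As a made-up illustration of the shift arithmetic above: if the common prefix leaves slot.n_past = 100 and the scan then finds a 64-token chunk that matches the new prompt at head_p = 100 but sits in the cache at head_c = 180, then kv_shift = 100 - 180 = -80. The stale KV cells in [100, 180) are removed, everything from position 180 onward is shifted back by 80 so the matched chunk lands in [100, 164), its tokens are copied into cache_tokens and re-accepted by the sampler, and slot.n_past advances to 164 before the scan resumes at head_c = 244, head_p = 164.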
@@ -3257,6 +3405,7 @@ int main(int argc, char ** argv) {
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
     ctx_server.queue_tasks.on_update_slots(std::bind(
         &server_context::update_slots, &ctx_server));
 