@@ -136,10 +136,6 @@ struct slot_params {
     int64_t  t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-    json extra_context;
 };
 
 struct server_slot {
@@ -169,6 +165,10 @@ struct server_slot {
 
     json prompt; // can be either a string, array of strings or array of token ids
 
+    json input_prefix;
+    json input_suffix;
+    json input_extra;
+
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
@@ -910,12 +910,12 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
-        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
-        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+        slot.input_prefix = json_value(data, "input_prefix", json());
+        slot.input_suffix = json_value(data, "input_suffix", json());
+        slot.input_extra  = json_value(data, "input_extra",  json());
 
-        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
-        for (const auto & chunk : slot.params.extra_context) {
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
+        for (const auto & chunk : slot.input_extra) {
             // { "text": string, "filename": string }
             if (!chunk.contains("text") || !chunk["text"].is_string()) {
                 send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
@@ -932,7 +932,7 @@ struct server_context {
         }
 
         // get prompt
-        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
+        {
            const auto & prompt = data.find("prompt");
            if (prompt == data.end()) {
                send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
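For reference, a minimal sketch of a request body that the parsing above would accept, built with nlohmann::json (the `json` alias used throughout server.cpp). The field names ("prompt", "input_prefix", "input_suffix", "input_extra" with {"text", "filename"} chunks) come from this diff; the concrete values and the standalone program around them are illustrative only.

    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        using json = nlohmann::json;

        // Example infill request body -- values are made up, field names are from the diff above.
        const json req = {
            {"prompt",       ""}, // "prompt" is now read for infill requests as well
            {"input_prefix", "static int sub(int a, int b) {\n    return "},
            {"input_suffix", ";\n}\n"},
            {"input_extra",  json::array({
                json{{"text", "int add(int a, int b);\n"}, {"filename", "math.h"}},
            })},
        };

        std::cout << req.dump(2) << std::endl;
        return 0;
    }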
@@ -1958,6 +1958,8 @@ struct server_context {
                                } break;
                            case SERVER_TASK_CMPL_TYPE_INFILL:
                                {
+                                    // TODO: optimize this block by reducing memory allocations and movement
+
                                    // use FIM repo-level pattern:
                                    // ref: https://arxiv.org/pdf/2409.12186
                                    //
@@ -1968,10 +1970,11 @@ struct server_context {
                                    // extra chunk 1
                                    // ...
                                    // [FIM_SEP]filename
-                                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                                    // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
                                    //
-                                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+                                    auto tokens_prefix = tokenize(slot.input_prefix, false, false);
+                                    auto tokens_suffix = tokenize(slot.input_suffix, false, false);
+                                    auto tokens_prompt = tokenize(slot.prompt,       false, false);
 
                                    slot.extra_tokens.clear();
                                    if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
@@ -1981,7 +1984,7 @@ struct server_context {
                                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
                                    }
 
-                                    for (const auto & chunk : slot.params.extra_context) {
+                                    for (const auto & chunk : slot.input_extra) {
                                        // { "text": string, "filename": string }
                                        const std::string text     = chunk.value("text", "");
                                        const std::string filename = chunk.value("filename", "tmp");
@@ -2012,20 +2015,21 @@ struct server_context {
                                    }
 
                                    // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                                    const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
-                                    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+                                    const int n_suffix_take = std::min<int>(tokens_suffix.size(),   (n_batch/4));
+                                    const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
 
                                    // fill the rest of the context with extra chunks
                                    const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
 
-                                    prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
-                                    suffix_tokens.resize(n_suffix_take);
+                                    tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+                                    tokens_suffix.resize(n_suffix_take);
 
-                                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                                    suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+                                    tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+                                    tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
+                                    tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
 
-                                    auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                                    auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+                                    auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
+                                    auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
 
                                    if (llama_add_bos_token(model)) {
                                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
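As a rough sanity check of the new 3:1 budget (n_batch = 2048 is an assumed example, not a value from the commit; the formulas are the ones in the hunk above):

    n_suffix_take <= 2048/4          =  512
    n_prefix_take <= 3*(2048/4) - 3  = 1533
    512 + 1533 + 3                   = 2048

so the trimmed suffix, the trimmed prefix and (presumably) the three FIM sentinel tokens fill exactly one batch, while the prompt tokens appended after the prefix and the extra chunks capped by n_extra_take come on top of this budget.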
@@ -2140,40 +2144,17 @@ struct server_context {
 
                                    while (head_c < slot.cache_tokens.size() &&
                                           head_p < prompt_tokens.size()) {
-                                        if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
-                                            slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
-                                            slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
-                                            break;
-                                        }
-
-                                        if (llama_token_is_control(model, prompt_tokens[head_p]) &&
-                                            prompt_tokens[head_p] != llama_token_fim_rep(model) &&
-                                            prompt_tokens[head_p] != llama_token_fim_sep(model)) {
-                                            break;
-                                        }
 
                                        size_t n_match = 0;
-
                                        while (head_c + n_match < slot.cache_tokens.size() &&
                                               head_p + n_match < prompt_tokens.size()     &&
                                               slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                                            if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
-                                                slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
-                                                slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
-                                                break;
-                                            }
-
-                                            if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
-                                                prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
-                                                prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
-                                                break;
-                                            }
 
                                            n_match++;
                                        }
 
                                        if (n_match >= (size_t) params.n_cache_reuse) {
-                                            SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                            SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                            // for (size_t i = head_p; i < head_p + n_match; i++) {
                                            //     SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
                                            // }
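Putting the hunks together, a request like the sketch above would be laid out following the repo-level FIM pattern in the updated comment. For a model exposing Qwen2.5-Coder-style FIM tokens (an assumption here; the actual spellings of FIM_REP/FIM_SEP/FIM_PRE/FIM_SUF/FIM_MID depend on the model's vocabulary, and {repo} / {filename} stand for values not shown in this diff), the assembled prompt would look roughly like:

    <|repo_name|>{repo}
    <|file_sep|>math.h
    int add(int a, int b);
    <|file_sep|>{filename}
    <|fim_prefix|>static int sub(int a, int b) {
        return <|fim_suffix|>;
    }
    <|fim_middle|>

with any non-empty "prompt" text now following the FIM_MID token, as the revised layout comment states.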