@@ -136,10 +136,6 @@ struct slot_params {
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
     std::vector<std::string> antiprompt;
-
-    json input_prefix;
-    json input_suffix;
-    json extra_context;
 };
 
 struct server_slot {
@@ -169,6 +165,10 @@ struct server_slot {
 
     json prompt; // can be either a string, array of strings or array of token ids
 
+    json input_prefix;
+    json input_suffix;
+    json input_extra;
+
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
@@ -910,12 +910,12 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
-        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
-        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+        slot.input_prefix = json_value(data, "input_prefix", json());
+        slot.input_suffix = json_value(data, "input_suffix", json());
+        slot.input_extra  = json_value(data, "input_extra",  json());
 
-        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
-        for (const auto & chunk : slot.params.extra_context) {
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.input_extra.size());
+        for (const auto & chunk : slot.input_extra) {
             // { "text": string, "filename": string }
             if (!chunk.contains("text") || !chunk["text"].is_string()) {
                 send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
@@ -932,7 +932,7 @@ struct server_context {
         }
 
         // get prompt
-        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
+        {
            const auto & prompt = data.find("prompt");
            if (prompt == data.end()) {
                send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -1958,6 +1958,8 @@ struct server_context {
                     } break;
                 case SERVER_TASK_CMPL_TYPE_INFILL:
                     {
+                        // TODO: optimize this block by reducing memory allocations and movement
+
                         // use FIM repo-level pattern:
                         // ref: https://arxiv.org/pdf/2409.12186
                         //
@@ -1968,10 +1970,11 @@ struct server_context {
                         // extra chunk 1
                         // ...
                         // [FIM_SEP]filename
-                        // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                        // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
                         //
-                        auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                        auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+                        auto tokens_prefix = tokenize(slot.input_prefix, false, false);
+                        auto tokens_suffix = tokenize(slot.input_suffix, false, false);
+                        auto tokens_prompt = tokenize(slot.prompt,       false, false);
 
                         slot.extra_tokens.clear();
                         if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
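Note: the repo-level layout in the comment above is easiest to see on a concrete request. The sketch below is illustrative only and not part of the patch; the file names, snippet text, and the literal "[FIM_*]" strings are placeholders for the model's actual special tokens.

#include <cstdio>
#include <string>

int main() {
    // hypothetical request fields
    const std::string input_prefix = "int add(int a, int b) {\n    return ";
    const std::string input_suffix = ";\n}\n";
    const std::string prompt       = "";                         // optional text placed after [FIM_MID]
    const std::string extra_text   = "int sub(int a, int b);\n"; // one input_extra chunk

    std::string p;
    p += "[FIM_REP]myproject\n";
    p += "[FIM_SEP]utils.h\n" + extra_text; // extra chunk 0
    p += "[FIM_SEP]main.cpp\n";             // filename of the file being completed
    p += "[FIM_PRE]" + input_prefix + "[FIM_SUF]" + input_suffix + "[FIM_MID]" + prompt;

    printf("%s\n", p.c_str());
    return 0;
}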
@@ -1981,7 +1984,7 @@ struct server_context {
                             slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
                         }
 
-                        for (const auto & chunk : slot.params.extra_context) {
+                        for (const auto & chunk : slot.input_extra) {
                             // { "text": string, "filename": string }
                             const std::string text     = chunk.value("text",     "");
                             const std::string filename = chunk.value("filename", "tmp");
@@ -2012,20 +2015,21 @@ struct server_context {
                         }
 
                         // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                        const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
-                        const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+                        const int n_suffix_take = std::min<int>(tokens_suffix.size(),   (n_batch/4));
+                        const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4) - 3);
 
                         // fill the rest of the context with extra chunks
                         const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
 
-                        prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
-                        suffix_tokens.resize(n_suffix_take);
+                        tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
+                        tokens_suffix.resize(n_suffix_take);
 
-                        prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                        suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+                        tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model));
+                        tokens_prefix.insert(tokens_prefix.end(),   tokens_prompt.begin(), tokens_prompt.end());
+                        tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model));
 
-                        auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                        auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+                        auto embd_inp = params.spm_infill ? tokens_suffix : tokens_prefix;
+                        auto embd_end = params.spm_infill ? tokens_prefix : tokens_suffix;
 
                         if (llama_add_bos_token(model)) {
                             embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
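Note: a worked example of the new 3:1 budget split, as a standalone sketch that is not part of the patch. It assumes the common default n_batch = 2048 and hypothetical prefix/suffix lengths; with those numbers the suffix keeps at most 512 tokens, the prefix at most 1533, and the remaining 3 batch slots cover the FIM_PRE/FIM_SUF/FIM_MID special tokens.

#include <algorithm>
#include <cstdio>

int main() {
    const int n_batch = 2048; // assumed default logical batch size

    const int n_prefix_all = 4000; // hypothetical prefix length in tokens
    const int n_suffix_all =  900; // hypothetical suffix length in tokens

    const int n_suffix_take = std::min(n_suffix_all,    n_batch/4);      // -> 512
    const int n_prefix_take = std::min(n_prefix_all, 3*(n_batch/4) - 3); // -> 1533

    // 512 + 1533 + 3 == n_batch, so the truncated prefix and suffix plus the three
    // FIM special tokens always fit in a single batch at a 3:1 prefix:suffix ratio.
    printf("take %d prefix and %d suffix tokens\n", n_prefix_take, n_suffix_take);
    return 0;
}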
@@ -2140,40 +2144,17 @@ struct server_context {
 
                             while (head_c < slot.cache_tokens.size() &&
                                    head_p < prompt_tokens.size()) {
-                                if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
-                                    slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
-                                    slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
-                                    break;
-                                }
-
-                                if (llama_token_is_control(model, prompt_tokens[head_p]) &&
-                                    prompt_tokens[head_p] != llama_token_fim_rep(model) &&
-                                    prompt_tokens[head_p] != llama_token_fim_sep(model)) {
-                                    break;
-                                }
 
                                 size_t n_match = 0;
-
                                 while (head_c + n_match < slot.cache_tokens.size() &&
                                        head_p + n_match < prompt_tokens.size() &&
                                        slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                                    if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
-                                        slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
-                                        slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
-                                        break;
-                                    }
-
-                                    if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
-                                        prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
-                                        prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
-                                        break;
-                                    }
 
                                     n_match++;
                                 }
 
                                 if (n_match >= (size_t) params.n_cache_reuse) {
-                                    SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                                    SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                     //for (size_t i = head_p; i < head_p + n_match; i++) {
                                     //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
                                     //}
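Note: with the control-token early-exits removed, the reuse scan reduces to a plain greedy search for matching runs. Below is a minimal standalone sketch of that idea, not part of the patch; the token values and the n_cache_reuse threshold are made up, and the advance-by-one fallback only approximates the server's bookkeeping.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> cache  = {1, 2, 3, 4, 5, 9, 9}; // tokens already in the KV cache
    const std::vector<int> prompt = {1, 2, 3, 4, 5, 7, 8}; // tokens of the new prompt
    const size_t n_cache_reuse = 3;                        // hypothetical reuse threshold

    size_t head_c = 0; // cursor into the cache
    size_t head_p = 0; // cursor into the prompt

    while (head_c < cache.size() && head_p < prompt.size()) {
        // length of the identical run starting at (head_c, head_p)
        size_t n_match = 0;
        while (head_c + n_match < cache.size() &&
               head_p + n_match < prompt.size() &&
               cache[head_c + n_match] == prompt[head_p + n_match]) {
            n_match++;
        }

        if (n_match >= n_cache_reuse) {
            // the server would shift this KV range instead of re-evaluating it
            printf("reusable run of %zu tokens: cache[%zu..) -> prompt[%zu..)\n", n_match, head_c, head_p);
            head_c += n_match;
            head_p += n_match;
        } else {
            head_c++; // skip one cached token and retry (approximation of the real bookkeeping)
        }
    }
    return 0;
}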