@@ -139,6 +139,8 @@ struct slot_params {
 
     json input_prefix;
     json input_suffix;
+
+    json extra_context;
 };
 
 struct server_slot {
@@ -170,6 +172,7 @@ struct server_slot {
 
     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;
 
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
@@ -906,8 +909,18 @@ struct server_context {
         }
 
         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            if (chunk.is_string()) {
+                SLT_DBG(slot, "chunk:\n%s\n", chunk.get<std::string>().c_str());
+            } else {
+                SLT_DBG(slot, "%s", "chunk is not a string - skipping\n");
+            }
+        }
 
         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
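
The new `extra_context` field is read straight from the request body as a JSON array; only string elements are used, anything else is skipped with a debug log. For reference, a minimal client-side sketch of a request that exercises the field (the payload shape follows this patch; the program and the file contents in the chunks are invented for illustration):

// hypothetical client sketch - builds an /infill request body carrying the
// new "extra_context" array of plain string chunks
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json req = {
        {"input_prefix", "int main() {\n    int s = "},
        {"input_suffix", ";\n    return s;\n}\n"},
        {"extra_context", {
            "// utils.h\nint add(int a, int b);",
            "// utils.cpp\nint add(int a, int b) { return a + b; }"
        }}
    };
    std::cout << req.dump(2) << std::endl; // POST this body to /infill
    return 0;
}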
@@ -1937,9 +1950,27 @@ struct server_context {
                 auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                 auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
 
-                // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
-                const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+                slot.extra_tokens.clear();
+                for (const auto & e : slot.params.extra_context) {
+                    if (e.is_string()) {
+                        // chunk separator in binary form to avoid confusing the AI
+                        static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                        static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+                        slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+
+                        const auto part = tokenize(e, false, false);
+                        slot.extra_tokens.insert(slot.extra_tokens.end(), part.begin(), part.end());
+                    } else {
+                        SLT_WRN(slot, "%s", "extra context element is not a string\n");
+                    }
+                }
+
+                // for now pick FIM context to fit in half a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch/4)/2);
+                const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch/2 - 3) - n_suffix_take);
+
+                // fill the rest of the context with extra chunks
+                const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - n_batch - 2*slot.n_predict), slot.extra_tokens.size());
 
                 prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                 suffix_tokens.resize(n_suffix_take);
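
Two things in the hunk above are worth spelling out. The byte array is nothing exotic: it decodes to the text "\n\n--- snippet ---\n\n", kept in binary form so the separator never appears verbatim in the source (per the patch comment, to avoid confusing the model). And the budget arithmetic splits the context three ways: prefix and suffix together are capped at half a batch, while the extra chunks get whatever remains of the slot context after reserving one batch for the FIM parts and 2*n_predict tokens for generation. A standalone sketch with assumed sizes (the n_ctx, n_batch, n_predict and token-count values are invented):

// standalone sketch of the budget math above, with assumed sizes
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx     = 8192; // slot context size (assumed)
    const int n_batch   = 2048; // logical batch size (assumed)
    const int n_predict = 512;  // max tokens to generate (assumed)

    // the byte array from the patch is just this separator text in disguise
    const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
    std::printf("separator:%s", k_chunk_prefix_str); // prints "--- snippet ---" framed by blank lines

    // pretend the tokenized suffix/prefix/chunks have these lengths
    const int n_suffix = 1000, n_prefix = 5000, n_extra = 10000;

    const int n_suffix_take = std::min(n_suffix, (n_batch/4)/2);                   // = 256
    const int n_prefix_take = std::min(n_prefix, (n_batch/2 - 3) - n_suffix_take); // = 765
    const int n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra); // = 5120

    std::printf("take: prefix=%d suffix=%d extra=%d\n", n_prefix_take, n_suffix_take, n_extra_take);
    return 0;
}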
@@ -1954,6 +1985,10 @@ struct server_context {
                     embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                 }
 
+                SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                embd_inp.insert(embd_inp.begin() + 1, slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                 embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                 embd_inp.push_back(llama_token_fim_mid(model));
 
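Net effect on the final prompt: the extra chunks are spliced in at begin() + 1, i.e. right after the first token (the BOS token when one was inserted), before the FIM suffix and middle tokens are appended. Note also that the insert copies from the end of extra_tokens, so when the budget is exceeded it is the earliest chunks that get dropped. A toy illustration with strings standing in for tokens (the <FIM_*> placement is my reading of the surrounding code, which builds embd_inp from the prefix and embd_end from the suffix, and is not shown in this patch):

// toy model of the splice - strings stand in for llama_token values,
// and the FIM scaffolding shown here is assumed, not taken from the patch
#include <cstdio>
#include <string>
#include <vector>

int main() {
    std::vector<std::string> embd_inp = {"<BOS>", "<FIM_PRE>", "prefix..."};
    std::vector<std::string> extra    = {"--- snippet ---", "chunk..."};

    const int n_extra_take = (int) extra.size(); // pretend everything fits
    embd_inp.insert(embd_inp.begin() + 1, extra.end() - n_extra_take, extra.end());

    embd_inp.push_back("<FIM_SUF>");
    embd_inp.push_back("suffix...");
    embd_inp.push_back("<FIM_MID>");

    // prints: <BOS> --- snippet --- chunk... <FIM_PRE> prefix... <FIM_SUF> suffix... <FIM_MID>
    for (const auto & t : embd_inp) {
        std::printf("%s ", t.c_str());
    }
    std::printf("\n");
    return 0;
}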