@@ -128,9 +128,12 @@ struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

-    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
-    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
+    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_predict = -1; // new tokens to predict
+
+    int64_t t_max_prompt_ms  = -1; // TODO: not implemented
+    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

     std::vector<std::string> antiprompt;

@@ -175,6 +178,7 @@ struct server_slot {
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;

     bool has_next_token = true;
+    bool has_new_line   = false;
     bool truncated      = false;
     bool stopped_eos    = false;
     bool stopped_word   = false;
@@ -210,6 +214,7 @@ struct server_slot {

         n_prompt_tokens = 0;
         generated_text  = "";
+        has_new_line    = false;
         truncated       = false;
         stopped_eos     = false;
         stopped_word    = false;
@@ -795,7 +800,7 @@ struct server_context {
                 int slot_prompt_len = slot_prompt.size();

                 // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-                int lcp_len = common_part(slot_prompt, prompt);
+                int lcp_len = longest_common_prefix(slot_prompt, prompt);

                 // fraction of the common substring length compared to the current slot's prompt length
                 similarity = static_cast<float>(lcp_len) / slot_prompt_len;
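
The renamed helper itself is not part of this hunk. As a rough sketch of the behavior the similarity check relies on, a longest common prefix over two token vectors could look like the following; the `llama_tokens` alias and the exact signature are assumptions, not taken from the patch:

    // Sketch only: count how many leading tokens two sequences share.
    // The real helper lives elsewhere in the server sources; this mirrors
    // its assumed contract for the similarity computation above.
    #include <cstddef>
    #include <vector>

    using llama_tokens = std::vector<int>; // assumption: tokens as plain ints

    static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }
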
@@ -931,6 +936,10 @@ struct server_context {
             }
         }

+        // time limits
+        slot.params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  default_params.t_max_prompt_ms);
+        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+
         {
             slot.sparams.logit_bias.clear();

@@ -1101,6 +1110,20 @@ struct server_context {
            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
        }

+       // if we have already seen a new line, we stop after a certain time limit
+       if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
+               (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+           slot.stopped_limit  = true;
+           slot.has_next_token = false;
+
+           SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+       }
+
+       // check if there is a new line in the generated text
+       if (result.text_to_send.find('\n') != std::string::npos) {
+           slot.has_new_line = true;
+       }
+
        // if context shift is disabled, we stop when it reaches the context limit
        if (slot.n_decoded >= slot.n_ctx) {
            slot.truncated      = true;
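
The time-limit check above compares against ggml_time_us(), which reports microseconds, so the millisecond budget is multiplied by 1000 before the comparison. A hypothetical standalone helper showing the same arithmetic; the helper name and parameters are illustrative, not part of the patch:

    #include <cstdint>
    #include "ggml.h" // for ggml_time_us()

    // Illustrative only: true once t_max_ms milliseconds have elapsed since t_start_us.
    // A non-positive budget disables the limit, matching the check in the diff.
    static bool time_budget_exceeded(int64_t t_start_us, int64_t t_max_ms) {
        return t_max_ms > 0 && (ggml_time_us() - t_start_us) > 1000 * t_max_ms;
    }
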
@@ -1249,6 +1272,7 @@ struct server_context {
             {"tokens_evaluated",    slot.n_prompt_tokens},
             {"generation_settings", get_formated_generation(slot)},
             {"prompt",              slot.prompt},
+            {"has_new_line",        slot.has_new_line},
             {"truncated",           slot.truncated},
             {"stopped_eos",         slot.stopped_eos},
             {"stopped_word",        slot.stopped_word},
@@ -1575,6 +1599,7 @@ struct server_context {
             slot_data["prompt"]     = slot.prompt;
             slot_data["next_token"] = {
                 {"has_next_token", slot.has_next_token},
+                {"has_new_line",   slot.has_new_line},
                 {"n_remain",       slot.n_remaining},
                 {"n_decoded",      slot.n_decoded},
                 {"stopped_eos",    slot.stopped_eos},
@@ -1913,6 +1938,13 @@ struct server_context {
                auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

+               // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+               const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+               const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+
+               prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
+               suffix_tokens.resize(n_suffix_take);
+
                prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
                suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

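
The 3:1 split above caps the infill suffix at a quarter of the batch and gives the remainder, minus a few slots presumably reserved for the FIM special tokens, to the tail of the prefix. A small standalone example of the same arithmetic with made-up sizes:

    #include <algorithm>
    #include <cstdio>

    int main() {
        // hypothetical sizes, chosen only to show the split
        const int n_batch  = 2048;
        const int n_prefix = 4000; // tokens available before the cursor
        const int n_suffix = 1000; // tokens available after the cursor

        const int n_suffix_take = std::min(n_suffix, n_batch/4);                     // 512
        const int n_prefix_take = std::min(n_prefix, (n_batch - 3) - n_suffix_take); // 1533

        // the diff keeps the last n_prefix_take prefix tokens and the first n_suffix_take suffix tokens
        printf("prefix: keep last %d, suffix: keep first %d\n", n_prefix_take, n_suffix_take);
        return 0;
    }
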
@@ -1935,9 +1967,17 @@ struct server_context {

            SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

-           // print prompt tokens:
-           for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-               SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+           // print prompt tokens (for debugging)
+           if (1) {
+               // first 16 tokens (avoid flooding logs)
+               for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+                   SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+               }
+           } else {
+               // all
+               for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                   SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+               }
            }

            // empty prompt passed -> release the slot and send empty response
@@ -2001,12 +2041,61 @@ struct server_context {

                if (slot.params.cache_prompt) {
                    // reuse any previously computed tokens that are common with the new prompt
-                   slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                   slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

                    // push the prompt into the sampling context (do not apply grammar)
                    for (int i = 0; i < slot.n_past; ++i) {
                        common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                    }
+
+                   // EXPERIMENTAL: reuse chunks from the cached prompt by shifting them in the new position
+                   if (1) {
+                       size_t head_c = slot.n_past; // cache
+                       size_t head_p = slot.n_past; // current prompt
+
+                       while (head_c < slot.cache_tokens.size() &&
+                              head_p < prompt_tokens.size() &&
+                              !llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                              !llama_token_is_control(model, prompt_tokens[head_p])) {
+
+                           size_t n_match = 0;
+                           while (head_c + n_match < slot.cache_tokens.size() &&
+                                  head_p + n_match < prompt_tokens.size() &&
+                                  !llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                                  !llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                                   slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                               n_match++;
+                           }
+
+                           if (n_match > 32) {
+                               // shift the KV chunk [head_c, head_c + n_match) -> [head_p, head_p + n_match)
+                               SLT_DBG(slot, "shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", head_c, head_c + n_match, head_p, head_p + n_match);
+                               //for (size_t i = head_p; i < head_p + n_match; i++) {
+                               //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                               //}
+
+                               const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                               llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                               llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                               for (size_t i = 0; i < n_match; i++) {
+                                   slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                   common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                                   slot.n_past++;
+                               }
+
+                               head_c += n_match;
+                               head_p += n_match;
+                           } else {
+                               head_c += 1;
+                           }
+                       }
+
+                       SLT_DBG(slot, "new slot.n_past = %d, cache_tokens.size() = %zu\n", slot.n_past, slot.cache_tokens.size());
+                   }
                }
            }

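
Beyond the common prefix, the experimental block above scans for later chunks of the cached prompt that still match the new prompt and appears to move their KV entries into place: llama_kv_cache_seq_rm clears the gap [head_p, head_c), and llama_kv_cache_seq_add shifts the matched range by kv_shift. A toy, self-contained sketch of just the matching loop, with a much smaller threshold than the 32-token minimum used in the diff:

    // Toy illustration of the chunk-matching idea (no KV cache involved):
    // the new prompt dropped one token relative to the cache, and the scan
    // finds the later cached chunk that can be reused at an earlier position.
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<int> cache  = {1, 2, 3, 99, 10, 11, 12, 13};
        std::vector<int> prompt = {1, 2, 3,     10, 11, 12, 13};

        size_t n_past = 3;      // longest common prefix
        size_t head_c = n_past; // cursor into the cache
        size_t head_p = n_past; // cursor into the new prompt

        while (head_c < cache.size() && head_p < prompt.size()) {
            size_t n_match = 0;
            while (head_c + n_match < cache.size() && head_p + n_match < prompt.size() &&
                   cache[head_c + n_match] == prompt[head_p + n_match]) {
                n_match++;
            }
            if (n_match > 0) { // the diff requires n_match > 32 before shifting
                printf("reusable chunk: cache[%zu, %zu) -> prompt[%zu, %zu)\n",
                       head_c, head_c + n_match, head_p, head_p + n_match);
                head_c += n_match;
                head_p += n_match;
            } else {
                head_c += 1; // skip a stale cached token and keep scanning
            }
        }
        return 0;
    }
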
@@ -3216,6 +3305,7 @@ int main(int argc, char ** argv) {

    ctx_server.queue_tasks.on_new_task(std::bind(
                &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
    ctx_server.queue_tasks.on_update_slots(std::bind(
                &server_context::update_slots, &ctx_server));
