
Commit 67de6ab

server : reuse context chunks
ggml-ci
1 parent e5f74fe commit 67de6ab

File tree: 2 files changed (+100, -10 lines)


examples/server/server.cpp

Lines changed: 98 additions & 8 deletions
@@ -128,9 +128,12 @@ struct slot_params {
     bool stream = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

-    int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
-    int32_t n_predict = -1; // new tokens to predict
+    int32_t n_keep = 0; // number of tokens to keep from initial prompt
+    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+    int32_t n_predict = -1; // new tokens to predict
+
+    int64_t t_max_prompt_ms = -1; // TODO: not implemented
+    int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

     std::vector<std::string> antiprompt;

@@ -175,6 +178,7 @@ struct server_slot {
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;

     bool has_next_token = true;
+    bool has_new_line   = false;
     bool truncated      = false;
     bool stopped_eos    = false;
     bool stopped_word   = false;

@@ -210,6 +214,7 @@ struct server_slot {

         n_prompt_tokens = 0;
         generated_text  = "";
+        has_new_line    = false;
         truncated       = false;
         stopped_eos     = false;
         stopped_word    = false;

@@ -795,7 +800,7 @@ struct server_context {
             int slot_prompt_len = slot_prompt.size();

             // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-            int lcp_len = common_part(slot_prompt, prompt);
+            int lcp_len = longest_common_prefix(slot_prompt, prompt);

             // fraction of the common substring length compared to the current slot's prompt length
             similarity = static_cast<float>(lcp_len) / slot_prompt_len;
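
The hunk above only renames common_part to longest_common_prefix at the call site that scores how well a slot's cached prompt matches an incoming request. A minimal standalone sketch of that scoring (not part of the diff), reusing the string overload from examples/server/utils.hpp shown further below; the prompt strings are made up:

#include <cstdio>
#include <string>

// copied from examples/server/utils.hpp (string overload, after the rename)
static size_t longest_common_prefix(const std::string & a, const std::string & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
    return i;
}

int main() {
    // hypothetical prompts: one already cached in a slot, one from a new request
    const std::string slot_prompt = "You are a helpful assistant.\nUser: hello";
    const std::string prompt      = "You are a helpful assistant.\nUser: help me with C++";

    const int   lcp_len    = longest_common_prefix(slot_prompt, prompt);
    const float similarity = static_cast<float>(lcp_len) / slot_prompt.size();

    // the server prefers the slot with the highest similarity when routing a request
    printf("lcp_len = %d, similarity = %.2f\n", lcp_len, similarity);
    return 0;
}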
@@ -931,6 +936,10 @@ struct server_context {
             }
         }

+        // time limits
+        slot.params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  default_params.t_max_prompt_ms);
+        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+
         {
             slot.sparams.logit_bias.clear();

@@ -1101,6 +1110,20 @@ struct server_context {
            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
        }

+        // if we have already seen a new line, we stop after a certain time limit
+        if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
+                (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+            slot.stopped_limit  = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+        }
+
+        // check if there is a new line in the generated text
+        if (result.text_to_send.find('\n') != std::string::npos) {
+            slot.has_new_line = true;
+        }
+
        // if context shift is disabled, we stop when it reaches the context limit
        if (slot.n_decoded >= slot.n_ctx) {
            slot.truncated = true;
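
The hunk above adds a per-request generation time budget: once a newline has been emitted, the slot stops as soon as the elapsed generation time exceeds t_max_predict_ms. A standalone sketch of that stopping rule (not part of the diff), using std::chrono in place of ggml_time_us(); the predict_state struct and check_limits helper are hypothetical stand-ins for the server_slot fields:

#include <chrono>
#include <cstdint>
#include <cstdio>
#include <string>
#include <thread>

// hypothetical stand-in for the server_slot fields touched by the hunk above
struct predict_state {
    bool    has_new_line   = false;
    bool    has_next_token = true;
    bool    stopped_limit  = false;
    int64_t t_start_generation = 0;  // microseconds
    int64_t t_max_predict_ms   = -1; // <= 0 disables the limit
};

static int64_t time_us() {
    using namespace std::chrono;
    return duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count();
}

// called once per generated token with the text piece produced for it
static void check_limits(predict_state & st, const std::string & text_to_send) {
    // if we have already seen a new line, stop after the configured time budget
    if (st.has_new_line && st.t_max_predict_ms > 0 &&
        (time_us() - st.t_start_generation > 1000 * st.t_max_predict_ms)) {
        st.stopped_limit  = true;
        st.has_next_token = false;
    }

    // remember that a newline appeared in the generated text
    if (text_to_send.find('\n') != std::string::npos) {
        st.has_new_line = true;
    }
}

int main() {
    predict_state st;
    st.t_max_predict_ms   = 50;
    st.t_start_generation = time_us();

    check_limits(st, "first line\n");                          // sets has_new_line
    std::this_thread::sleep_for(std::chrono::milliseconds(60));
    check_limits(st, "more text");                             // now past the budget

    printf("stopped_limit = %d, has_next_token = %d\n", st.stopped_limit, st.has_next_token);
    return 0;
}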
@@ -1249,6 +1272,7 @@ struct server_context {
            {"tokens_evaluated", slot.n_prompt_tokens},
            {"generation_settings", get_formated_generation(slot)},
            {"prompt", slot.prompt},
+           {"has_new_line", slot.has_new_line},
            {"truncated", slot.truncated},
            {"stopped_eos", slot.stopped_eos},
            {"stopped_word", slot.stopped_word},

@@ -1575,6 +1599,7 @@ struct server_context {
            slot_data["prompt"] = slot.prompt;
            slot_data["next_token"] = {
                {"has_next_token", slot.has_next_token},
+               {"has_new_line",   slot.has_new_line},
                {"n_remain",       slot.n_remaining},
                {"n_decoded",      slot.n_decoded},
                {"stopped_eos",    slot.stopped_eos},

@@ -1913,6 +1938,13 @@ struct server_context {
                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

+                   // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                   const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                   const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+
+                   prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
+                   suffix_tokens.resize(n_suffix_take);
+
                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
                    suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

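
The new infill code above trims the prompt so it fits into a single batch, keeping roughly a 3:1 prefix:suffix split: the suffix keeps its first n_suffix_take tokens, the prefix keeps its last n_prefix_take tokens, and a few positions are left for the FIM special tokens. A standalone sketch of just that arithmetic on plain token ids (not part of the diff); the n_batch value and token counts are made up:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_batch = 2048;                // illustrative batch size

    std::vector<int> prefix_tokens(3000, 1); // pretend prefix token ids
    std::vector<int> suffix_tokens(900,  2); // pretend suffix token ids

    // ratio prefix:suffix ~= 3:1, leaving a few slots for the FIM special tokens
    const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);

    // keep the *last* n_prefix_take prefix tokens and the *first* n_suffix_take suffix tokens
    prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
    suffix_tokens.resize(n_suffix_take);

    printf("prefix kept: %zu, suffix kept: %zu\n", prefix_tokens.size(), suffix_tokens.size());
    return 0;
}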

@@ -1935,9 +1967,17 @@ struct server_context {

    SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

-   // print prompt tokens:
-   for (int i = 0; i < (int) prompt_tokens.size(); i++) {
-       SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+   // print prompt tokens (for debugging)
+   if (1) {
+       // first 16 tokens (avoid flooding logs)
+       for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+           SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+       }
+   } else {
+       // all
+       for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+           SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+       }
    }

    // empty prompt passed -> release the slot and send empty response

@@ -2001,12 +2041,61 @@ struct server_context {

    if (slot.params.cache_prompt) {
        // reuse any previously computed tokens that are common with the new prompt
-       slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+       slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

        // push the prompt into the sampling context (do not apply grammar)
        for (int i = 0; i < slot.n_past; ++i) {
            common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
        }
+
+       // EXPERIMENTAL: reuse chunks from the cached prompt by shifting them in the new position
+       if (1) {
+           size_t head_c = slot.n_past; // cache
+           size_t head_p = slot.n_past; // current prompt
+
+           while (head_c < slot.cache_tokens.size() &&
+                  head_p < prompt_tokens.size() &&
+                  !llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                  !llama_token_is_control(model, prompt_tokens[head_p])) {
+
+               size_t n_match = 0;
+               while (head_c + n_match < slot.cache_tokens.size() &&
+                      head_p + n_match < prompt_tokens.size() &&
+                      !llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                      !llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                       slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                   n_match++;
+               }
+
+               if (n_match > 32) {
+                   // shift the KV chunk [head_c, head_c + n_match) -> [head_p, head_p + n_match)
+                   SLT_DBG(slot, "shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", head_c, head_c + n_match, head_p, head_p + n_match);
+                   //for (size_t i = head_p; i < head_p + n_match; i++) {
+                   //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                   //}
+
+                   const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                   llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                   llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                   for (size_t i = 0; i < n_match; i++) {
+                       slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                       common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                       slot.n_past++;
+                   }
+
+                   head_c += n_match;
+                   head_p += n_match;
+               } else {
+                   head_c += 1;
+               }
+           }
+
+           SLT_DBG(slot, "new slot.n_past = %d, cache_tokens.size() = %zu\n", slot.n_past, slot.cache_tokens.size());
+       }
    }
}

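
The EXPERIMENTAL block above is the heart of this commit: after the longest common prefix, it scans the old cached prompt and the new prompt for further runs of identical tokens (longer than 32 tokens, containing no control tokens) and, instead of re-evaluating them, shifts the corresponding KV-cache cells to their new positions via llama_kv_cache_seq_rm / llama_kv_cache_seq_add. Below is a standalone sketch of only the matching and bookkeeping on plain token ids (not part of the diff), with the KV-cache calls replaced by a printf; the token values and the minimum match length are made up for the toy data:

#include <cstdio>
#include <vector>

int main() {
    // pretend token ids; in the server these are slot.cache_tokens and prompt_tokens,
    // and the new prompt here has dropped a chunk that the cache still holds
    std::vector<int> cache_tokens  = {1, 2, 3, 50, 51, 52, 10, 11, 12, 13, 14};
    std::vector<int> prompt_tokens = {1, 2, 3, 10, 11, 12, 13, 14};

    // tokens up to the longest common prefix are reused as-is
    size_t n_past = 0;
    while (n_past < cache_tokens.size() && n_past < prompt_tokens.size() &&
           cache_tokens[n_past] == prompt_tokens[n_past]) {
        n_past++;
    }

    const size_t n_min_match = 2; // the commit requires > 32; tiny here for the toy data

    size_t head_c = n_past; // position in the old cache
    size_t head_p = n_past; // position in the new prompt

    // the real code also stops matching at control tokens; omitted here
    while (head_c < cache_tokens.size() && head_p < prompt_tokens.size()) {
        // length of the matching run starting at (head_c, head_p)
        size_t n_match = 0;
        while (head_c + n_match < cache_tokens.size() &&
               head_p + n_match < prompt_tokens.size() &&
               cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
            n_match++;
        }

        if (n_match > n_min_match) {
            // here the server shifts the KV cells [head_c, head_c + n_match) to [head_p, ...)
            printf("reuse chunk: cache [%zu, %zu) -> prompt [%zu, %zu)\n",
                   head_c, head_c + n_match, head_p, head_p + n_match);

            for (size_t i = 0; i < n_match; i++) {
                cache_tokens[head_p + i] = cache_tokens[head_c + i];
                n_past++;
            }

            head_c += n_match;
            head_p += n_match;
        } else {
            head_c += 1; // try matching the current prompt position further along the cache
        }
    }

    printf("n_past = %zu of %zu prompt tokens\n", n_past, prompt_tokens.size());
    return 0;
}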

@@ -3216,6 +3305,7 @@ int main(int argc, char ** argv) {

    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));

examples/server/utils.hpp

Lines changed: 2 additions & 2 deletions
@@ -195,14 +195,14 @@ static std::string gen_chatcmplid() {
 // other common utils
 //

-static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

     return i;
 }

-static size_t common_part(const std::string & a, const std::string & b) {
+static size_t longest_common_prefix(const std::string & a, const std::string & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
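
The utils.hpp change above is a pure rename of common_part. A tiny standalone usage sketch of the token-vector overload (not part of the diff), as it is used in server.cpp to seed slot.n_past; llama_token is stubbed as int32_t so the example does not need llama.h, and the token ids are made up:

#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t; // stand-in for the typedef from llama.h

static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
    return i;
}

int main() {
    const std::vector<llama_token> cache_tokens  = {1, 4444, 5555, 6666, 7777};
    const std::vector<llama_token> prompt_tokens = {1, 4444, 5555, 6666, 8888};

    // number of cached tokens that can be reused without re-evaluation
    const size_t n_past = longest_common_prefix(cache_tokens, prompt_tokens);
    printf("n_past = %zu\n", n_past); // prints 4
    return 0;
}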
