Commit cefb32b

server : smarter cache reuse + "prompt" support in /infill
1 parent 26af4cd · commit cefb32b

File tree

1 file changed: +7 -27 lines

examples/server/server.cpp

Lines changed: 7 additions & 27 deletions
@@ -137,6 +137,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
 
+    // TODO: move to server_slot
     json input_prefix;
     json input_suffix;
     json extra_context;
@@ -930,7 +931,7 @@ struct server_context {
         }
 
         // get prompt
-        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
+        {
             const auto & prompt = data.find("prompt");
             if (prompt == data.end()) {
                 send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
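
With the cmpl_type guard removed above, the "prompt" field is parsed for /infill tasks exactly like for regular completions, and judging by the visible error path it is now expected to be present. Below is a minimal sketch of a client-side request body, built with nlohmann::json (server.cpp already uses its json type); the field names input_prefix, input_suffix and prompt come from this diff, while the concrete values and n_predict are purely illustrative.

// Sketch only: a possible /infill request body. Field names "input_prefix",
// "input_suffix" and "prompt" appear in this commit; the values are made up.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json req = {
        { "input_prefix", "int add(int a, int b) {\n    return " }, // code before the cursor
        { "input_suffix", ";\n}\n" },                               // code after the cursor
        { "prompt",       "a + " },                                 // partially typed middle the server now tokenizes too
        { "n_predict",    16 }                                      // illustrative generation limit
    };
    std::cout << req.dump(2) << std::endl; // POST this body to /infill
    return 0;
}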
@@ -1964,10 +1965,11 @@ struct server_context {
             // extra chunk 1
             // ...
             // [FIM_SEP]filename
-            // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+            // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
             //
             auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
             auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+            auto prompt_start  = tokenize(slot.prompt, false, false);
 
             slot.extra_tokens.clear();
             if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
@@ -2008,8 +2010,8 @@ struct server_context {
             }
 
             // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-            const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
-            const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
+            const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch/4));
+            const int n_prefix_take = std::min<int>(prefix_tokens.size(), 3*(n_batch/4) - 3);
 
             // fill the rest of the context with extra chunks
             const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
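
To make the new budget concrete: the suffix is capped at a quarter of the batch and the prefix at the remaining three quarters minus 3, which presumably reserves room for the three FIM special tokens inserted just below. A standalone arithmetic sketch (the n_batch value and token counts are assumptions):

// Not server code: a worked example of the prefix/suffix caps above.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_batch         = 2048; // assumed batch size
    const int n_prefix_tokens = 5000; // pretend the raw prefix is long
    const int n_suffix_tokens = 4000; // pretend the raw suffix is long

    const int n_suffix_take = std::min(n_suffix_tokens, n_batch/4);         // 512
    const int n_prefix_take = std::min(n_prefix_tokens, 3*(n_batch/4) - 3); // 1533

    // 1533 + 512 + 3 special tokens = 2048 = n_batch
    printf("prefix: %d, suffix: %d, total with 3 FIM tokens: %d\n",
           n_prefix_take, n_suffix_take, n_prefix_take + n_suffix_take + 3);
    return 0;
}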
@@ -2018,6 +2020,7 @@ struct server_context {
             suffix_tokens.resize(n_suffix_take);
 
             prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
+            prefix_tokens.insert(prefix_tokens.end(),   prompt_start.begin(), prompt_start.end());
             suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
 
             auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
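
Taken together with the prompt_start tokenization from the earlier hunk, the tokens of the new "prompt" field are appended to the [FIM_PRE]-prefixed prefix before the [FIM_SUF]-prefixed suffix is attached, and the spm_infill flag only decides which half comes first. The sketch below replays that assembly with plain ints standing in for token ids; the final [FIM_MID] append is not part of this hunk, so its placement here follows the surrounding comment and is an assumption.

// Assembly sketch only; not the server's actual token types.
#include <cstdio>
#include <vector>

int main() {
    const int FIM_PRE = 1, FIM_SUF = 2, FIM_MID = 3;   // placeholder special-token ids

    std::vector<int> prefix_tokens = {10, 11, 12};     // truncated prefix
    std::vector<int> suffix_tokens = {20, 21};         // truncated suffix
    std::vector<int> prompt_start  = {30, 31};         // tokens of the "prompt" field
    const bool spm_infill = false;                     // default prefix-suffix-middle order

    prefix_tokens.insert(prefix_tokens.begin(), FIM_PRE);
    prefix_tokens.insert(prefix_tokens.end(),   prompt_start.begin(), prompt_start.end());
    suffix_tokens.insert(suffix_tokens.begin(), FIM_SUF);

    auto embd_inp = spm_infill ? suffix_tokens : prefix_tokens;
    auto embd_end = spm_infill ? prefix_tokens : suffix_tokens; // assumption: the other half follows
    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
    embd_inp.push_back(FIM_MID);                       // assumption: appended last, as in the comment

    // prints: 1 10 11 12 30 31 2 20 21 3  ->  [FIM_PRE] prefix prompt [FIM_SUF] suffix [FIM_MID]
    for (const int t : embd_inp) {
        printf("%d ", t);
    }
    printf("\n");
    return 0;
}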
@@ -2136,34 +2139,11 @@ struct server_context {
 
             while (head_c < slot.cache_tokens.size() &&
                    head_p < prompt_tokens.size()) {
-                if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
-                    slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
-                    slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
-                    break;
-                }
-
-                if (llama_token_is_control(model, prompt_tokens[head_p]) &&
-                    prompt_tokens[head_p] != llama_token_fim_rep(model) &&
-                    prompt_tokens[head_p] != llama_token_fim_sep(model)) {
-                    break;
-                }
 
                 size_t n_match = 0;
-
                 while (head_c + n_match < slot.cache_tokens.size() &&
                        head_p + n_match < prompt_tokens.size() &&
                        slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
-                    if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
-                        slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
-                        slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
-                        break;
-                    }
-
-                    if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
-                        prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
-                        prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
-                        break;
-                    }
 
                     n_match++;
                 }
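
With the control-token early-exits deleted, the matching that remains is straightforward: starting at head_c in the cached tokens and head_p in the new prompt, count how many consecutive tokens agree. A self-contained sketch of that inner loop, with plain ints standing in for llama_token (what the server then does with each matched run lies outside this hunk):

// Standalone illustration of the simplified token-run matching.
#include <cstddef>
#include <cstdio>
#include <vector>

static size_t count_match(const std::vector<int> & cache_tokens,
                          const std::vector<int> & prompt_tokens,
                          size_t head_c, size_t head_p) {
    size_t n_match = 0;
    while (head_c + n_match < cache_tokens.size() &&
           head_p + n_match < prompt_tokens.size() &&
           cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
        n_match++;
    }
    return n_match;
}

int main() {
    const std::vector<int> cache  = {5, 6, 7, 8, 9};
    const std::vector<int> prompt = {1, 6, 7, 8, 2};

    // tokens 6,7,8 line up starting at cache[1] and prompt[1]
    printf("n_match = %zu\n", count_match(cache, prompt, 1, 1)); // prints: n_match = 3
    return 0;
}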
