Skip to content

Commit c6afab9

Browse files
committed
server : add "extra_context" for infill endpoint
1 parent 947035a commit c6afab9

File tree

1 file changed

+40
-5
lines changed

examples/server/server.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ struct slot_params {
139139

140140
json input_prefix;
141141
json input_suffix;
142+
143+
json extra_context;
142144
};
143145

144146
struct server_slot {
@@ -170,6 +172,7 @@ struct server_slot {
170172

171173
// when a task is submitted, we first tokenize the prompt and store it here
172174
std::vector<llama_token> prompt_tokens;
175+
std::vector<llama_token> extra_tokens;
173176

174177
std::string generated_text;
175178
std::vector<llama_token> cache_tokens;
@@ -906,8 +909,18 @@ struct server_context {
906909
}
907910

908911
// infill
909-
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
910-
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
912+
slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
913+
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
914+
slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
915+
916+
SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
917+
for (const auto & chunk : slot.params.extra_context) {
918+
if (chunk.is_string()) {
919+
SLT_DBG(slot, "chunk: \n%s\n", chunk.get<std::string>().c_str());
920+
} else {
921+
SLT_DBG(slot, "%s", "chunk is not a string - skipping\n");
922+
}
923+
}
911924

912925
// get prompt
913926
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
@@ -1937,9 +1950,27 @@ struct server_context {
19371950
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
19381951
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
19391952

1940-
// for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
1941-
const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
1942-
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);
1953+
slot.extra_tokens.clear();
1954+
for (const auto & e : slot.params.extra_context) {
1955+
if (e.is_string()) {
1956+
// chunk separator in binary form to avoid confusing the AI
1957+
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
1958+
static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
1959+
slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
1960+
1961+
const auto part = tokenize(e, false, false);
1962+
slot.extra_tokens.insert(slot.extra_tokens.end(), part.begin(), part.end());
1963+
} else {
1964+
SLT_WRN(slot, "%s", "extra context element is not a string\n");
1965+
}
1966+
}
1967+
1968+
// for now pick FIM context to fit in half batch (ratio prefix:suffix = 3:1, TODO: configurable?)
1969+
const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch/4)/2);
1970+
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch/2 - 3) - n_suffix_take);
1971+
1972+
// fill the rest of the context with extra chunks
1973+
const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - n_batch - 2*slot.n_predict), slot.extra_tokens.size());
19431974

19441975
prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
19451976
suffix_tokens.resize(n_suffix_take);
@@ -1954,6 +1985,10 @@ struct server_context {
19541985
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
19551986
}
19561987

1988+
SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
1989+
1990+
embd_inp.insert(embd_inp.begin() + 1, slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
1991+
19571992
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
19581993
embd_inp.push_back(llama_token_fim_mid(model));
19591994

0 commit comments

Comments (0)