
Commit dfef2c4

Merge branch 'ggerganov:master' into master
2 parents a3e6522 + 13dca2a commit dfef2c4

7 files changed: +252 −37 lines changed


common/arg.cpp

Lines changed: 7 additions & 0 deletions

@@ -1802,6 +1802,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_threads_http = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_THREADS_HTTP"));
+    add_opt(common_arg(
+        {"--cache-reuse"}, "N",
+        string_format("min chunk size to attempt reusing from the cache via KV shifting (default: %d)", params.n_cache_reuse),
+        [](common_params & params, int value) {
+            params.n_cache_reuse = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_CACHE_REUSE"));
     add_opt(common_arg(
         {"--metrics"},
         string_format("enable prometheus compatible metrics endpoint (default: %s)", params.endpoint_metrics ? "enabled" : "disabled"),
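For illustration only (not part of the commit): the new option is a plain integer threshold. A matched run of cached tokens is reused only when it is at least `n_cache_reuse` tokens long, and the default of 0 disables reuse. A minimal sketch of that gate, with names mirroring the diff:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical standalone sketch of the --cache-reuse threshold semantics.
struct cache_reuse_params {
    int32_t n_cache_reuse = 0; // mirrors common_params::n_cache_reuse; 0 disables reuse
};

// a matched chunk of n_match cached tokens qualifies for KV shifting only if it
// reaches the configured minimum size
static bool should_reuse_chunk(const cache_reuse_params & params, size_t n_match) {
    return params.n_cache_reuse > 0 && n_match >= (size_t) params.n_cache_reuse;
}
```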

common/common.h

Lines changed: 2 additions & 1 deletion

@@ -283,7 +283,8 @@ struct common_params {
     int32_t port           = 8080;         // server listens on this network port
     int32_t timeout_read   = 600;          // http read timeout in seconds
     int32_t timeout_write  = timeout_read; // http write timeout in seconds
-    int     n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_threads_http = -1;           // number of threads to process HTTP requests (TODO: support threadpool)
+    int32_t n_cache_reuse  = 0;            // min chunk size to reuse from the cache via KV shifting

     std::string hostname      = "127.0.0.1";
     std::string public_path   = "";        // NOLINT

examples/server/README.md

Lines changed: 22 additions & 0 deletions

@@ -147,6 +147,7 @@ The project is under active development, and we are [looking for feedback and co
 | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate<br/>(env: LLAMA_ARG_SSL_CERT_FILE) |
 | `-to, --timeout N` | server read/write timeout in seconds (default: 600)<br/>(env: LLAMA_ARG_TIMEOUT) |
 | `--threads-http N` | number of threads used to process HTTP requests (default: -1)<br/>(env: LLAMA_ARG_THREADS_HTTP) |
+| `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting (default: 0)<br/>(env: LLAMA_ARG_CACHE_REUSE) |
 | `--metrics` | enable prometheus compatible metrics endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_METRICS) |
 | `--slots` | enable slots monitoring endpoint (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
 | `--props` | enable changing global properties via POST /props (default: disabled)<br/>(env: LLAMA_ARG_ENDPOINT_PROPS) |

@@ -523,9 +524,30 @@ Takes a prefix and a suffix and returns the predicted completion as stream.

 - `input_prefix`: Set the prefix of the code to infill.
 - `input_suffix`: Set the suffix of the code to infill.
+- `prompt`: Added after the `FIM_MID` token
+- `extra_context`: Additional context inserted before the FIM prefix. See https://github.com/ggerganov/llama.cpp/pull/9874

 It also accepts all the options of `/completion`.

+If the model has `FIM_REPO` and `FIM_FILE_SEP` tokens, the [repo-level pattern](https://arxiv.org/pdf/2409.12186) is used:
+
+```txt
+<FIM_REP>myproject
+<FIM_SEP>{chunk 0 filename}
+{chunk 0 text}
+<FIM_SEP>{chunk 1 filename}
+{chunk 1 text}
+...
+<FIM_SEP>filename
+<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
+```
+
+If the tokens are missing, then the extra context is simply prefixed at the start:
+
+```txt
+[extra_context]<FIM_PRE>[input_prefix]<FIM_SUF>[input_suffix]<FIM_MID>[prompt]
+```
+
 ### **GET** `/props`: Get server global properties.

 This endpoint is public (no API key check). By default, it is read-only. To make POST request to change global properties, you need to start server with `--props`
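As a usage illustration (not part of this commit), a client could build an `/infill` request body carrying the new `extra_context` chunks as sketched below. The file name and code text are made up, and nlohmann::json is assumed for JSON construction:

```cpp
// Hypothetical client-side sketch: building an /infill request body that uses
// the new "extra_context" field documented above.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json body = {
        {"input_prefix", "int add(int a, int b) {\n    return "}, // code before the cursor
        {"input_suffix", ";\n}\n"},                               // code after the cursor
        {"prompt", ""},                                           // appended after the FIM_MID token
        {"extra_context", nlohmann::json::array({
            // each chunk: { "text": string, "filename": string (optional) }
            {{"filename", "math.h"}, {"text", "// helper declarations\nint add(int a, int b);\n"}},
        })},
    };

    // POST this JSON to the server's /infill endpoint
    std::cout << body.dump(2) << std::endl;
    return 0;
}
```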

examples/server/server.cpp

Lines changed: 156 additions & 7 deletions

@@ -139,6 +139,7 @@ struct slot_params {

     json input_prefix;
     json input_suffix;
+    json extra_context;
 };

@@ -170,6 +171,7 @@ struct server_slot {

     // when a task is submitted, we first tokenize the prompt and store it here
     std::vector<llama_token> prompt_tokens;
+    std::vector<llama_token> extra_tokens;

     std::string generated_text;
     std::vector<llama_token> cache_tokens;

@@ -800,7 +802,7 @@ struct server_context {
         int slot_prompt_len = slot_prompt.size();

         // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-        int lcp_len = common_part(slot_prompt, prompt);
+        int lcp_len = longest_common_prefix(slot_prompt, prompt);

         // fraction of the common substring length compared to the current slot's prompt length
         similarity = static_cast<float>(lcp_len) / slot_prompt_len;

@@ -908,8 +910,26 @@ struct server_context {
         }

         // infill
-        slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+        slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+        slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+        slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+        SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+        for (const auto & chunk : slot.params.extra_context) {
+            // { "text": string, "filename": string }
+            if (!chunk.contains("text") || !chunk["text"].is_string()) {
+                send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            // filename is optional
+            if (chunk.contains("filename") && !chunk["filename"].is_string()) {
+                send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+
+            SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
+        }

         // get prompt
         if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {

@@ -1938,13 +1958,66 @@ struct server_context {
            } break;
            case SERVER_TASK_CMPL_TYPE_INFILL:
                {
+                   // use FIM repo-level pattern:
+                   // ref: https://arxiv.org/pdf/2409.12186
+                   //
+                   // [FIM_REP]myproject
+                   // [FIM_SEP]filename0
+                   // extra chunk 0
+                   // [FIM_SEP]filename1
+                   // extra chunk 1
+                   // ...
+                   // [FIM_SEP]filename
+                   // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                   //
                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

-                   // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                   const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                   slot.extra_tokens.clear();
+                   if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+                       static const auto k_fim_repo = tokenize("myproject\n", false, false);
+
+                       slot.extra_tokens.push_back(llama_token_fim_rep(model));
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+                   }
+
+                   for (const auto & chunk : slot.params.extra_context) {
+                       // { "text": string, "filename": string }
+                       const std::string text     = chunk.value("text", "");
+                       const std::string filename = chunk.value("filename", "tmp");
+
+                       if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                           const auto k_fim_file = tokenize(filename + "\n", false, false);
+
+                           slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                           slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                       } else {
+                           // chunk separator in binary form to avoid confusing the AI
+                           static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                           static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                           slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+                       }
+
+                       const auto chunk_tokens = tokenize(text, false, false);
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+                   }
+
+                   if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                       // TODO: current filename
+                       static const auto k_fim_file = tokenize("filename\n", false, false);
+
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                   }
+
+                   // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                   const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);

+                   // fill the rest of the context with extra chunks
+                   const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
+
                    prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                    suffix_tokens.resize(n_suffix_take);
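To make the token-budget split above concrete, here is a worked example with assumed values (n_batch = 2048, n_ctx = 8192, n_predict = 512; these numbers are illustrative, not defaults taken from the commit):

```cpp
#include <algorithm>
#include <cstdio>

// Illustrative arithmetic for the prefix/suffix/extra-context budget, mirroring
// the expressions in the diff above with made-up request sizes.
int main() {
    const int n_batch   = 2048;
    const int n_ctx     = 8192;
    const int n_predict = 512;

    const int n_suffix = 900;   // pretend token counts for the request
    const int n_prefix = 3000;
    const int n_extra  = 10000;

    const int n_suffix_take = std::min(n_suffix, n_batch/4);                                 // 512
    const int n_prefix_take = std::min(n_prefix, (n_batch - 3) - n_suffix_take);             // 1533
    const int n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra); // 5120

    printf("take: suffix = %d, prefix = %d, extra = %d\n", n_suffix_take, n_prefix_take, n_extra_take);
    return 0;
}
```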

@@ -1958,6 +2031,11 @@ struct server_context {
                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                    }

+                   SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                   // put the extra context before the FIM prefix
+                   embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                    embd_inp.push_back(llama_token_fim_mid(model));
@@ -2016,7 +2094,7 @@ struct server_context {
            }
            slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

-           // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+           // if input prompt is too big, truncate it
            if (slot.n_prompt_tokens >= slot.n_ctx) {
                const int n_left = slot.n_ctx - slot.params.n_keep;
@@ -2046,12 +2124,82 @@ struct server_context {

        if (slot.params.cache_prompt) {
            // reuse any previously computed tokens that are common with the new prompt
-           slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+           slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

            // push the prompt into the sampling context (do not apply grammar)
            for (int i = 0; i < slot.n_past; ++i) {
                common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
            }
+
+           // reuse chunks from the cached prompt by shifting their KV cache in the new position
+           if (params.n_cache_reuse > 0) {
+               size_t head_c = slot.n_past; // cache
+               size_t head_p = slot.n_past; // current prompt
+
+               SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+               while (head_c < slot.cache_tokens.size() &&
+                      head_p < prompt_tokens.size()) {
+                   if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                       slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
+                       slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
+                       break;
+                   }
+
+                   if (llama_token_is_control(model, prompt_tokens[head_p]) &&
+                       prompt_tokens[head_p] != llama_token_fim_rep(model) &&
+                       prompt_tokens[head_p] != llama_token_fim_sep(model)) {
+                       break;
+                   }
+
+                   size_t n_match = 0;
+
+                   while (head_c + n_match < slot.cache_tokens.size() &&
+                          head_p + n_match < prompt_tokens.size() &&
+                          slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                       if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                           slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
+                           slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
+                           break;
+                       }
+
+                       if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                           prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
+                           prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
+                           break;
+                       }
+
+                       n_match++;
+                   }
+
+                   if (n_match >= (size_t) params.n_cache_reuse) {
+                       SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                       //for (size_t i = head_p; i < head_p + n_match; i++) {
+                       //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                       //}
+
+                       const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                       llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                       llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                       for (size_t i = 0; i < n_match; i++) {
+                           slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                           common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                           slot.n_past++;
+                       }
+
+                       head_c += n_match;
+                       head_p += n_match;
+                   } else {
+                       head_c += 1;
+                   }
+               }
+
+               SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+           }
        }
    }
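A simplified, self-contained sketch of the chunk-reuse matching above, using plain integer tokens and omitting the control-token checks and the actual `llama_kv_cache_seq_rm`/`llama_kv_cache_seq_add` calls (the token values and sizes are made up):

```cpp
#include <cstdio>
#include <vector>

// Two-pointer chunk matching: walk the cached tokens (head_c) and the new prompt
// (head_p); whenever a common run of at least n_cache_reuse tokens is found, the
// server would shift the corresponding KV cache entries by (head_p - head_c)
// instead of recomputing them. Here we only count how many tokens would be reused.
static int count_reusable_tokens(const std::vector<int> & cache_tokens,
                                 const std::vector<int> & prompt_tokens,
                                 size_t n_past, int n_cache_reuse) {
    int n_reused = 0;

    size_t head_c = n_past; // position in the cached prompt
    size_t head_p = n_past; // position in the new prompt

    while (head_c < cache_tokens.size() && head_p < prompt_tokens.size()) {
        size_t n_match = 0;
        while (head_c + n_match < cache_tokens.size() &&
               head_p + n_match < prompt_tokens.size() &&
               cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
            n_match++;
        }

        if (n_cache_reuse > 0 && n_match >= (size_t) n_cache_reuse) {
            n_reused += (int) n_match; // this run would be shifted into place
            head_c   += n_match;
            head_p   += n_match;
        } else {
            head_c   += 1;             // skip one cached token and try again
        }
    }

    return n_reused;
}

int main() {
    // the new prompt drops the cached tokens {3, 4}; the run {7, 8, 9, 10, 20, 21, 22}
    // reappears at a shifted position and can be reused
    const std::vector<int> cache  = {1, 2, 3, 4, 7, 8, 9, 10, 20, 21, 22};
    const std::vector<int> prompt = {1, 2, 7, 8, 9, 10, 20, 21, 22, 30};

    // n_past = 2 is the longest common prefix of the two token lists
    printf("reusable tokens: %d\n", count_reusable_tokens(cache, prompt, 2, 3)); // prints 7
    return 0;
}
```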

@@ -3261,6 +3409,7 @@ int main(int argc, char ** argv) {

    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));

examples/server/utils.hpp

Lines changed: 2 additions & 2 deletions

@@ -195,14 +195,14 @@ static std::string gen_chatcmplid() {
 // other common utils
 //

-static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
+static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

     return i;
 }

-static size_t common_part(const std::string & a, const std::string & b) {
+static size_t longest_common_prefix(const std::string & a, const std::string & b) {
     size_t i;
     for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
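A quick illustration of the renamed helper as a standalone sketch (here `llama_token` is just its underlying `int32_t` token id):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

// same logic as the renamed longest_common_prefix helper above
static size_t longest_common_prefix(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    size_t i;
    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}

    return i;
}

int main() {
    const std::vector<llama_token> cached = {10, 20, 30, 40};
    const std::vector<llama_token> prompt = {10, 20, 99, 40};

    // only the first two tokens match, so n_past starts at 2
    printf("longest common prefix: %zu\n", longest_common_prefix(cached, prompt));
    return 0;
}
```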
