Skip to content

Commit 147470b

Browse files
committed
llama : improve infill support
1 parent 79b8343 commit 147470b

File tree

10 files changed

+517
-377
lines changed

10 files changed

+517
-377
lines changed

common/arg.cpp

Lines changed: 110 additions & 136 deletions
Large diffs are not rendered by default.

common/common.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,21 @@ std::string gpt_params_get_system_info(const gpt_params & params) {
400400
// String utils
401401
//
402402

403+
std::string string_format(const char * fmt, ...) {
404+
va_list ap;
405+
va_list ap2;
406+
va_start(ap, fmt);
407+
va_copy(ap2, ap);
408+
int size = vsnprintf(NULL, 0, fmt, ap);
409+
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
410+
std::vector<char> buf(size + 1);
411+
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
412+
GGML_ASSERT(size2 == size);
413+
va_end(ap2);
414+
va_end(ap);
415+
return std::string(buf.data(), size);
416+
}
417+
403418
std::vector<std::string> string_split(std::string input, char separator) {
404419
std::vector<std::string> parts;
405420
size_t separator_pos = input.find(separator);

common/common.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,28 @@ void gpt_init();
349349

350350
std::string gpt_params_get_system_info(const gpt_params & params);
351351

352-
bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
353-
bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
354-
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
352+
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
353+
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
354+
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
355355
bool set_process_priority(enum ggml_sched_priority prio);
356356

357357
//
358358
// String utils
359359
//
360360

361+
#ifdef __GNUC__
362+
#ifdef __MINGW32__
363+
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
364+
#else
365+
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
366+
#endif
367+
#else
368+
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
369+
#endif
370+
371+
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
372+
std::string string_format(const char * fmt, ...);
373+
361374
std::vector<std::string> string_split(std::string input, char separator);
362375

363376
std::string string_strip(const std::string & str);

examples/infill/infill.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,11 @@ int main(int argc, char ** argv) {
205205
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
206206
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
207207

208-
GGML_ASSERT(llama_token_prefix(model) >= 0);
209-
GGML_ASSERT(llama_token_suffix(model) >= 0);
208+
GGML_ASSERT(llama_token_fim_pre(model) >= 0);
209+
GGML_ASSERT(llama_token_fim_suf(model) >= 0);
210210

211-
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
212-
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
211+
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
212+
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
213213

214214
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
215215
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
@@ -218,7 +218,7 @@ int main(int argc, char ** argv) {
218218
}
219219
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
220220

221-
const llama_token middle_token = llama_token_middle(model);
221+
const llama_token middle_token = llama_token_fim_mid(model);
222222
if (middle_token >= 0) {
223223
embd_inp.push_back(middle_token);
224224
}
@@ -508,8 +508,8 @@ int main(int argc, char ** argv) {
508508
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
509509
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
510510

511-
inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
512-
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
511+
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
512+
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
513513

514514
embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
515515
embd_end = params.spm_infill ? inp_pfx : inp_sfx;

examples/server/README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,6 @@ Takes a prefix and a suffix and returns the predicted completion as stream.
525525

526526
`input_suffix`: Set the suffix of the code to infill.
527527

528-
It also accepts all the options of `/completion` except `stream` and `prompt`.
529-
530528
- **GET** `/props`: Return current server settings.
531529

532530
**Response format**

examples/server/server.cpp

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -753,12 +753,7 @@ struct server_context {
753753
metrics.init();
754754
}
755755

756-
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
757-
// TODO: currently, we tokenize using special tokens by default
758-
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
759-
// but it's better compared to completely ignoring ChatML and other chat templates
760-
const bool TMP_FORCE_SPECIAL = true;
761-
756+
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
762757
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
763758
// or the first element of the json_prompt array is a string.
764759
std::vector<llama_token> prompt_tokens;
@@ -771,10 +766,10 @@ struct server_context {
771766

772767
std::vector<llama_token> p;
773768
if (first) {
774-
p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
769+
p = ::llama_tokenize(ctx, s, add_special, parse_special);
775770
first = false;
776771
} else {
777-
p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
772+
p = ::llama_tokenize(ctx, s, false, parse_special);
778773
}
779774

780775
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -788,7 +783,7 @@ struct server_context {
788783
}
789784
} else {
790785
auto s = json_prompt.template get<std::string>();
791-
prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
786+
prompt_tokens = ::llama_tokenize(ctx, s, add_special, parse_special);
792787
}
793788

794789
return prompt_tokens;
@@ -1220,7 +1215,7 @@ struct server_context {
12201215
slot.params.n_predict, n_ctx_train);
12211216
}
12221217

1223-
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
1218+
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
12241219

12251220
return slot.has_next_token; // continue
12261221
}
@@ -1488,9 +1483,8 @@ struct server_context {
14881483
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
14891484
data["index"] = 0;
14901485
create_task(data, false, nullptr);
1491-
}
1492-
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
1493-
else if (prompt.is_array()) {
1486+
} else if (prompt.is_array()) {
1487+
// otherwise, it's a multiple-prompt task, we break it into smaller tasks
14941488
std::vector<json> prompts = prompt;
14951489
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
14961490
// prompts[0] is the question
@@ -1515,9 +1509,8 @@ struct server_context {
15151509
}
15161510
}
15171511
}
1518-
}
1519-
// invalid case
1520-
else {
1512+
} else {
1513+
// invalid case
15211514
throw std::runtime_error(error_msg);
15221515
}
15231516

@@ -1988,31 +1981,23 @@ struct server_context {
19881981

19891982
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
19901983
const bool add_bos = llama_add_bos_token(model);
1991-
bool suff_rm_leading_spc = true;
1992-
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
1993-
params.input_suffix.erase(0, 1);
1994-
suff_rm_leading_spc = false;
1995-
}
19961984

1997-
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
1998-
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
1985+
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
1986+
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
19991987

2000-
const int space_token = 29871; // TODO: this should not be hardcoded
2001-
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
2002-
suffix_tokens.erase(suffix_tokens.begin());
2003-
}
2004-
2005-
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
2006-
suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
1988+
prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
1989+
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
20071990

20081991
auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
20091992
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
1993+
20101994
if (add_bos) {
20111995
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
20121996
}
1997+
20131998
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
20141999

2015-
const llama_token middle_token = llama_token_middle(model);
2000+
const llama_token middle_token = llama_token_fim_mid(model);
20162001
if (middle_token >= 0) {
20172002
embd_inp.push_back(middle_token);
20182003
}
@@ -2031,28 +2016,28 @@ struct server_context {
20312016
prompt_tokens.clear();
20322017
prompt_tokens.push_back(llama_token_bos(model));
20332018
{
2034-
const auto part = tokenize(slot.prompt[0], false);
2019+
const auto part = tokenize(slot.prompt[0], false, false);
20352020
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
20362021
}
20372022
prompt_tokens.push_back(llama_token_eos(model));
20382023
prompt_tokens.push_back(llama_token_sep(model));
20392024
{
2040-
const auto part = tokenize(slot.prompt[1], false);
2025+
const auto part = tokenize(slot.prompt[1], false, false);
20412026
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
20422027
}
20432028
prompt_tokens.push_back(llama_token_eos(model));
20442029
} else {
2045-
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
2030+
prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
20462031
}
20472032

20482033
slot.n_past = 0;
20492034
slot.n_prompt_tokens = prompt_tokens.size();
20502035

20512036
SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
20522037

2053-
// print tokens:
2038+
// print prompt tokens:
20542039
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
2055-
SLT_INF(slot, "prompt token %3d: %6d (%s)\n", i, prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
2040+
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], llama_token_to_piece(ctx, prompt_tokens[i]).c_str());
20562041
}
20572042

20582043
// empty prompt passed -> release the slot and send empty response
@@ -2947,7 +2932,23 @@ int main(int argc, char ** argv) {
29472932
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
29482933
};
29492934

2950-
const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2935+
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2936+
std::string err;
2937+
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
2938+
err += "prefix token is missing. ";
2939+
}
2940+
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
2941+
err += "suffix token is missing. ";
2942+
}
2943+
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
2944+
err += "middle token is missing. ";
2945+
}
2946+
2947+
if (!err.empty()) {
2948+
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
2949+
return;
2950+
}
2951+
29512952
json data = json::parse(req.body);
29522953
return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
29532954
};
@@ -3033,7 +3034,8 @@ int main(int argc, char ** argv) {
30333034
if (body.count("content") != 0) {
30343035
const bool add_special = json_value(body, "add_special", false);
30353036
const bool with_pieces = json_value(body, "with_pieces", false);
3036-
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
3037+
3038+
std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
30373039

30383040
if (with_pieces) {
30393041
for (const auto& token : tokens) {

include/llama.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,7 @@ extern "C" {
896896
// Special tokens
897897
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
898898
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
899+
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
899900
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
900901
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
901902
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@@ -904,11 +905,17 @@ extern "C" {
904905
LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
905906
LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
906907

907-
// Codellama infill tokens
908-
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
909-
LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
910-
LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
911-
LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
908+
// infill tokens
909+
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
910+
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
911+
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
912+
913+
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
914+
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
915+
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
916+
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
917+
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
918+
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
912919

913920
//
914921
// Tokenization

src/llama-vocab.cpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
16631663
return vocab.special_eos_id;
16641664
}
16651665

1666+
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1667+
return vocab.special_eot_id;
1668+
}
1669+
1670+
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1671+
return vocab.special_eom_id;
1672+
}
1673+
16661674
llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
16671675
return vocab.special_cls_id;
16681676
}
@@ -1688,23 +1696,39 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
16881696
}
16891697

16901698
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
1691-
return vocab.special_prefix_id;
1699+
return vocab.special_fim_pre_id;
16921700
}
16931701

16941702
llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
1695-
return vocab.special_middle_id;
1703+
return vocab.special_fim_mid_id;
16961704
}
16971705

16981706
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
1699-
return vocab.special_suffix_id;
1707+
return vocab.special_fim_suf_id;
17001708
}
17011709

1702-
llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
1703-
return vocab.special_eot_id;
1710+
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
1711+
return vocab.special_fim_pre_id;
17041712
}
17051713

1706-
llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
1707-
return vocab.special_eom_id;
1714+
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
1715+
return vocab.special_fim_suf_id;
1716+
}
1717+
1718+
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
1719+
return vocab.special_fim_mid_id;
1720+
}
1721+
1722+
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
1723+
return vocab.special_fim_pad_id;
1724+
}
1725+
1726+
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
1727+
return vocab.special_fim_rep_id;
1728+
}
1729+
1730+
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
1731+
return vocab.special_fim_sep_id;
17081732
}
17091733

17101734
int32_t llama_tokenize_impl(

0 commit comments

Comments
 (0)