@@ -753,12 +753,7 @@ struct server_context {
         metrics.init();
     }
 
-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
-        // TODO: currently, we tokenize using special tokens by default
-        //       this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
-        //       but it's better compared to completely ignoring ChatML and other chat templates
-        const bool TMP_FORCE_SPECIAL = true;
-
+    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special, bool parse_special) const {
         // If `add_bos` is true, we only add BOS, when json_prompt is a string,
         // or the first element of the json_prompt array is a string.
         std::vector<llama_token> prompt_tokens;
@@ -771,10 +766,10 @@ struct server_context {
 
                     std::vector<llama_token> p;
                     if (first) {
-                        p = common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+                        p = common_tokenize(ctx, s, add_special, parse_special);
                         first = false;
                     } else {
-                        p = common_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
+                        p = common_tokenize(ctx, s, false, parse_special);
                     }
 
                     prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@@ -788,7 +783,7 @@ struct server_context {
             }
         } else {
             auto s = json_prompt.template get<std::string>();
-            prompt_tokens = common_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+            prompt_tokens = common_tokenize(ctx, s, add_special, parse_special);
         }
 
         return prompt_tokens;
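
Note (illustration, not part of the patch): the new `parse_special` parameter replaces the hard-coded `TMP_FORCE_SPECIAL` flag, so each call site now decides whether special/control tokens in the text are parsed. A minimal sketch of the two call patterns, assuming the `common_tokenize(ctx, text, add_special, parse_special)` helper used in the hunk above; `prompt_text` and `fragment_text` are placeholder strings:

    // sketch only; assumes #include "common.h" and a valid llama_context * ctx
    // user-facing prompt: add BOS and parse chat-template/control tokens
    std::vector<llama_token> user_tokens = common_tokenize(ctx, prompt_text,   /*add_special=*/true,  /*parse_special=*/true);
    // infill/rerank fragment: tokenize the text literally, no special-token parsing
    std::vector<llama_token> raw_tokens  = common_tokenize(ctx, fragment_text, /*add_special=*/false, /*parse_special=*/false);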
@@ -1215,7 +1210,7 @@ struct server_context {
                     slot.params.n_predict, n_ctx_train);
         }
 
-        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
+        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
 
         return slot.has_next_token; // continue
     }
@@ -1483,9 +1478,8 @@ struct server_context {
         if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
             data["index"] = 0;
             create_task(data, false, nullptr);
-        }
-        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
-        else if (prompt.is_array()) {
+        } else if (prompt.is_array()) {
+            // otherwise, it's a multiple-prompt task, we break it into smaller tasks
             std::vector<json> prompts = prompt;
             if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                 // prompts[0] is the question
@@ -1510,9 +1504,8 @@ struct server_context {
                     }
                 }
             }
-        }
-        // invalid case
-        else {
+        } else {
+            // invalid case
             throw std::runtime_error(error_msg);
         }
 
@@ -1785,6 +1778,9 @@ struct server_context {
                 }
                 slot->cache_tokens.resize(token_count);
 
+                // TODO: maybe detokenize the slot->cache_tokens instead?
+                slot->prompt = string_format("[restored %d tokens from file]", (int) token_count);
+
                 const int64_t t_end = ggml_time_us();
                 const double t_restore_ms = (t_end - t_start) / 1000.0;
 
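
Note (illustration only, not in this commit): the TODO above could be addressed by detokenizing the restored cache into a prompt string with `common_token_to_piece`, roughly:

    // hypothetical helper; assumes #include "common.h"
    static std::string tokens_to_prompt(llama_context * ctx, const std::vector<llama_token> & cache_tokens) {
        std::string prompt;
        for (const llama_token tok : cache_tokens) {
            prompt += common_token_to_piece(ctx, tok); // append each token's text piece
        }
        return prompt;
    }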
@@ -1971,70 +1967,69 @@ struct server_context {
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_generation = 0;
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
-                    const bool add_bos = llama_add_bos_token(model);
-                    bool suff_rm_leading_spc = true;
-                    if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
-                        params.input_suffix.erase(0, 1);
-                        suff_rm_leading_spc = false;
-                    }
-
-                    auto prefix_tokens = tokenize(slot.params.input_prefix, false);
-                    auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
-                    const int space_token = 29871; // TODO: this should not be hardcoded
-                    if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
-                        suffix_tokens.erase(suffix_tokens.begin());
-                    }
-
-                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
-                    suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
-
-                    auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                    auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
-                    if (add_bos) {
-                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                    }
-                    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
-                    const llama_token middle_token = llama_token_middle(model);
-                    if (middle_token >= 0) {
-                        embd_inp.push_back(middle_token);
-                    }
-
-                    prompt_tokens = embd_inp;
-                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    // require slot.prompt to be array of 2 strings
-                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                        slot.release();
-                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                        continue;
-                    }
-
-                    // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                    prompt_tokens.clear();
-                    prompt_tokens.push_back(llama_token_bos(model));
-                    {
-                        const auto part = tokenize(slot.prompt[0], false);
-                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                    }
-                    prompt_tokens.push_back(llama_token_eos(model));
-                    prompt_tokens.push_back(llama_token_sep(model));
-                    {
-                        const auto part = tokenize(slot.prompt[1], false);
-                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                    }
-                    prompt_tokens.push_back(llama_token_eos(model));
-                } else {
-                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
+                switch (slot.cmpl_type) {
+                    case SERVER_TASK_CMPL_TYPE_NORMAL:
+                    case SERVER_TASK_CMPL_TYPE_EMBEDDING:
+                        {
+                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                        } break;
+                    case SERVER_TASK_CMPL_TYPE_RERANK:
+                        {
+                            // require slot.prompt to be array of 2 strings
+                            if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                                SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                                slot.release();
+                                send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                                continue;
+                            }
+
+                            // prompt: [BOS]query[EOS][SEP]doc[EOS]
+                            prompt_tokens.clear();
+                            prompt_tokens.push_back(llama_token_bos(model));
+                            {
+                                const auto part = tokenize(slot.prompt[0], false, false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                            prompt_tokens.push_back(llama_token_sep(model));
+                            {
+                                const auto part = tokenize(slot.prompt[1], false, false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                        } break;
+                    case SERVER_TASK_CMPL_TYPE_INFILL:
+                        {
+                            auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
+                            auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+
+                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
+                            suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+
+                            auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
+                            auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+
+                            if (llama_add_bos_token(model)) {
+                                embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                            }
+
+                            embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+                            embd_inp.push_back(llama_token_fim_mid(model));
+
+                            prompt_tokens = std::move(embd_inp);
+                        } break;
                 }
 
                 slot.n_past = 0;
                 slot.n_prompt_tokens = prompt_tokens.size();
 
                 SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
+                // print prompt tokens:
+                for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+                    SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                }
+
                 // empty prompt passed -> release the slot and send empty response
                 if (prompt_tokens.empty()) {
                     SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
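
Note (illustration, not part of the patch): the `SERVER_TASK_CMPL_TYPE_INFILL` case above assembles the prompt as `[BOS]<FIM_PRE>prefix<FIM_SUF>suffix<FIM_MID>`, with the prefix/suffix blocks swapped when `spm_infill` is set. A standalone sketch of that layout, using the same `llama_token_fim_*` getters as the diff; the helper name `build_infill_prompt` is hypothetical:

    // sketch of the layout built in the infill case above; assumes #include "llama.h"
    static std::vector<llama_token> build_infill_prompt(
            const llama_model * model,
            std::vector<llama_token> prefix_tokens,   // tokenized input_prefix (parse_special = false)
            std::vector<llama_token> suffix_tokens,   // tokenized input_suffix (parse_special = false)
            bool spm_infill) {
        prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
        suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

        // PSM order puts the prefix block first; SPM order puts the suffix block first
        auto embd_inp = spm_infill ? suffix_tokens : prefix_tokens;
        auto embd_end = spm_infill ? prefix_tokens : suffix_tokens;

        if (llama_add_bos_token(model)) {
            embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
        }

        embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
        embd_inp.push_back(llama_token_fim_mid(model)); // generation continues after <FIM_MID>

        return embd_inp;
    }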
@@ -2924,7 +2919,23 @@ int main(int argc, char ** argv) {
         return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
     };
 
-    const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        std::string err;
+        if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "prefix token is missing. ";
+        }
+        if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "suffix token is missing. ";
+        }
+        if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
+            err += "middle token is missing. ";
+        }
+
+        if (!err.empty()) {
+            res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         json data = json::parse(req.body);
         return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
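
Note (illustration, not part of the patch): the FIM-token validation added to `handle_infill` can also be read as a capability check; a hypothetical predicate expressing the same condition:

    // hypothetical helper; true only if the model defines all three FIM special tokens
    static bool model_supports_infill(const llama_model * model) {
        return llama_token_fim_pre(model) != LLAMA_TOKEN_NULL &&
               llama_token_fim_suf(model) != LLAMA_TOKEN_NULL &&
               llama_token_fim_mid(model) != LLAMA_TOKEN_NULL;
    }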
@@ -3010,7 +3021,8 @@ int main(int argc, char ** argv) {
         if (body.count("content") != 0) {
             const bool add_special = json_value(body, "add_special", false);
             const bool with_pieces = json_value(body, "with_pieces", false);
-            std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
+
+            std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special, true);
 
             if (with_pieces) {
                 for (const auto & token : tokens) {