@@ -342,6 +342,11 @@ struct server_task {
342342 }
343343 }
344344
345+ if (params.sampling .n_probs > 0 && params.cache_prompt ) {
346+ SRV_WRN (" cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n " , params.sampling .n_probs );
347+ params.cache_prompt = false ;
348+ }
349+
345350 std::string model_name = params_base.model_alias .empty () ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias ;
346351 params.oaicompat_model = json_value (data, " model" , model_name);
347352
@@ -416,6 +421,7 @@ inline std::string stop_type_to_str(stop_type type) {
416421
417422struct completion_token_output {
418423 llama_token tok;
424+ float prob;
419425 std::string text_to_send;
420426 struct token_prob {
421427 llama_token tok;
@@ -427,25 +433,41 @@ struct completion_token_output {
427433 json to_json () const {
428434 json probs_for_token = json::array ();
429435 for (const auto & p : probs) {
436+ std::string tok_str (p.tok_str );
437+ tok_str.resize (validate_utf8 (tok_str));
430438 probs_for_token.push_back (json {
431- {" tok_str" , p.tok_str },
432- {" prob" , p.prob },
439+ {" id" , p.tok },
440+ {" token" , tok_str},
441+ {" bytes" , str_to_bytes (p.tok_str )},
442+ {" logprob" , p.prob },
433443 });
434444 }
435445 return probs_for_token;
436446 }
437447
438448 static json probs_vector_to_json (const std::vector<completion_token_output> & probs) {
439449 json out = json::array ();
440- for (const auto & prob : probs) {
441- const std::string tok_str = prob.text_to_send ;
450+ for (const auto & it : probs) {
451+ std::string tok_str (it.text_to_send );
452+ tok_str.resize (validate_utf8 (tok_str));
442453 out.push_back (json {
443- {" content" , tok_str},
444- {" probs" , prob.to_json ()},
454+ {" id" , it.tok },
455+ {" token" , tok_str},
456+ {" logprob" , it.prob },
457+ {" bytes" , str_to_bytes (it.text_to_send )},
458+ {" top_logprobs" , it.to_json ()},
445459 });
446460 }
447461 return out;
448462 }
463+
// Returns the bytes of `str` as unsigned values, in order (empty input gives an
// empty vector). Each char is converted to unsigned char, so bytes >= 0x80 come
// out as 128..255 instead of negative numbers.
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
    // The range constructor sizes the buffer once instead of growing it per byte.
    return std::vector<unsigned char>(str.begin(), str.end());
}
449471};
450472
451473struct server_task_result_cmpl_final : server_task_result {
@@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result {
506528 {" tokens_cached" , n_tokens_cached},
507529 {" timings" , timings.to_json ()},
508530 };
509- if (!probs_output.empty ()) {
531+ if (!stream && ! probs_output.empty ()) {
510532 res[" completion_probabilities" ] = completion_token_output::probs_vector_to_json (probs_output);
511533 }
512534 return res;
@@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result {
518540 finish_reason = " stop" ;
519541 }
520542
521- json choices = json::array ({ json{
543+ json choice = json{
522544 {" finish_reason" , finish_reason},
523545 {" index" , 0 },
524546 {" message" , json{
525547 {" content" , content},
526548 {" role" , " assistant" }
527549 }
528- }}});
550+ }};
551+
552+ if (!stream && probs_output.size () > 0 ) {
553+ choice[" logprobs" ] = json{
554+ {" content" , completion_token_output::probs_vector_to_json (probs_output)},
555+ };
556+ }
529557
530558 std::time_t t = std::time (0 );
531559
532560 json res = json {
533- {" choices" , choices },
561+ {" choices" , json::array ({choice}) },
534562 {" created" , t},
535563 {" model" , oaicompat_model},
536564 {" object" , " chat.completion" },
@@ -560,12 +588,14 @@ struct server_task_result_cmpl_final : server_task_result {
560588 finish_reason = " stop" ;
561589 }
562590
563- json choices = json::array ({json{{" finish_reason" , finish_reason},
564- {" index" , 0 },
565- {" delta" , json::object ()}}});
591+ json choice = json{
592+ {" finish_reason" , finish_reason},
593+ {" index" , 0 },
594+ {" delta" , json::object ()}
595+ };
566596
567597 json ret = json {
568- {" choices" , choices },
598+ {" choices" , json::array ({choice}) },
569599 {" created" , t},
570600 {" id" , oaicompat_cmpl_id},
571601 {" model" , oaicompat_model},
@@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result {
592622 int32_t n_decoded;
593623 int32_t n_prompt_tokens;
594624
595- std::vector< completion_token_output> probs_output ;
625+ completion_token_output prob_output ;
596626 result_timings timings;
597627
598628 // OAI-compat fields
@@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result {
628658 if (timings.prompt_n > 0 ) {
629659 res.push_back ({" timings" , timings.to_json ()});
630660 }
631- if (!probs_output .empty ()) {
632- res[" completion_probabilities" ] = completion_token_output::probs_vector_to_json (probs_output );
661+ if (!prob_output. probs .empty ()) {
662+ res[" completion_probabilities" ] = completion_token_output::probs_vector_to_json ({prob_output} );
633663 }
634664 return res;
635665 }
@@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result {
681711 }});
682712 }
683713
714+ GGML_ASSERT (choices.size () >= 1 );
715+
716+ if (prob_output.probs .size () > 0 ) {
717+ choices[0 ][" logprobs" ] = json{
718+ {" content" , completion_token_output::probs_vector_to_json ({prob_output})},
719+ };
720+ }
721+
684722 json ret = json {
685723 {" choices" , choices},
686724 {" created" , t},
@@ -951,7 +989,6 @@ struct server_slot {
951989
952990 // stats
953991 size_t n_sent_text = 0 ; // number of sent text character
954- size_t n_sent_token_probs = 0 ;
955992
956993 int64_t t_start_process_prompt;
957994 int64_t t_start_generation;
@@ -973,7 +1010,6 @@ struct server_slot {
9731010 stopping_word = " " ;
9741011 n_past = 0 ;
9751012 n_sent_text = 0 ;
976- n_sent_token_probs = 0 ;
9771013 task_type = SERVER_TASK_TYPE_COMPLETION;
9781014
9791015 generated_token_probs.clear ();
@@ -1713,34 +1749,15 @@ struct server_context {
17131749
17141750 bool process_token (completion_token_output & result, server_slot & slot) {
17151751 // remember which tokens were sampled - used for repetition penalties during sampling
1716- const std::string token_str = common_token_to_piece (ctx, result.tok , params_base. special ) ;
1752+ const std::string token_str = result.text_to_send ;
17171753 slot.sampled = result.tok ;
17181754
17191755 // search stop word and delete it
17201756 slot.generated_text += token_str;
17211757 slot.has_next_token = true ;
17221758
17231759 // check if there is incomplete UTF-8 character at the end
1724- bool incomplete = false ;
1725- for (unsigned i = 1 ; i < 5 && i <= slot.generated_text .size (); ++i) {
1726- unsigned char c = slot.generated_text [slot.generated_text .size () - i];
1727- if ((c & 0xC0 ) == 0x80 ) {
1728- // continuation byte: 10xxxxxx
1729- continue ;
1730- }
1731- if ((c & 0xE0 ) == 0xC0 ) {
1732- // 2-byte character: 110xxxxx ...
1733- incomplete = i < 2 ;
1734- } else if ((c & 0xF0 ) == 0xE0 ) {
1735- // 3-byte character: 1110xxxx ...
1736- incomplete = i < 3 ;
1737- } else if ((c & 0xF8 ) == 0xF0 ) {
1738- // 4-byte character: 11110xxx ...
1739- incomplete = i < 4 ;
1740- }
1741- // else 1-byte character or invalid byte
1742- break ;
1743- }
1760+ bool incomplete = validate_utf8 (slot.generated_text ) < slot.generated_text .size ();
17441761
17451762 if (!incomplete) {
17461763 size_t pos = std::min (slot.n_sent_text , slot.generated_text .size ());
@@ -1869,6 +1886,29 @@ struct server_context {
18691886 return slot.has_next_token ; // continue
18701887 }
18711888
1889+ void populate_token_probs (const server_slot & slot, completion_token_output & result) {
1890+ const auto * cur_p = common_sampler_get_candidates (slot.smpl );
1891+ const size_t max_probs = cur_p->size ;
1892+
1893+ // set prob for the sampled token
1894+ for (size_t i = 0 ; i < max_probs; ++i) {
1895+ if (result.tok == cur_p->data [i].id ) {
1896+ result.prob = cur_p->data [i].p ;
1897+ break ;
1898+ }
1899+ }
1900+
1901+ // set probs for the top n tokens
1902+ for (size_t i = 0 ; i < std::min (max_probs, (size_t ) slot.params .sampling .n_probs ); ++i) {
1903+ auto tok_id = cur_p->data [i].id ;
1904+ result.probs .push_back ({
1905+ tok_id,
1906+ tokens_to_output_formatted_string (ctx, tok_id),
1907+ cur_p->data [i].p ,
1908+ });
1909+ }
1910+ }
1911+
18721912 void send_error (const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
18731913 send_error (task.id , error, type);
18741914 }
@@ -1906,17 +1946,7 @@ struct server_context {
19061946
19071947 // populate res.probs_output
19081948 if (slot.params .sampling .n_probs > 0 ) {
1909- const llama_tokens to_send_toks = common_tokenize (ctx, tkn.text_to_send , false );
1910-
1911- const size_t probs_pos = std::min (slot.n_sent_token_probs , slot.generated_token_probs .size ());
1912- const size_t probs_stop_pos = std::min (slot.n_sent_token_probs + to_send_toks.size (), slot.generated_token_probs .size ());
1913-
1914- std::vector<completion_token_output> probs_output;
1915- if (probs_pos < probs_stop_pos) {
1916- res->probs_output = std::vector<completion_token_output>(
1917- slot.generated_token_probs .begin () + probs_pos,
1918- slot.generated_token_probs .begin () + probs_stop_pos);
1919- }
1949+ res->prob_output = tkn; // copy the token probs
19201950 }
19211951
19221952 // populate timings if this is final response or timings_per_token is enabled
@@ -2747,17 +2777,12 @@ struct server_context {
27472777 slot.t_token_generation = (t_current - slot.t_start_generation ) / 1e3 ;
27482778
27492779 completion_token_output result;
2750- result.tok = id;
2780+ result.tok = id;
2781+ result.text_to_send = common_token_to_piece (ctx, result.tok , params_base.special );
2782+ result.prob = 1 .0f ; // set later
27512783
2752- const auto * cur_p = common_sampler_get_candidates (slot.smpl );
2753-
2754- for (size_t i = 0 ; i < (size_t ) slot.params .sampling .n_probs ; ++i) {
2755- auto tok_id = cur_p->data [i].id ;
2756- result.probs .push_back ({
2757- tok_id,
2758- tokens_to_output_formatted_string (ctx, tok_id),
2759- i >= cur_p->size ? 0 .0f : cur_p->data [i].p ,
2760- });
2784+ if (slot.params .sampling .n_probs > 0 ) {
2785+ populate_token_probs (slot, result);
27612786 }
27622787
27632788 if (!process_token (result, slot)) {
@@ -2841,7 +2866,9 @@ struct server_context {
28412866 for (size_t i = 0 ; i < ids.size (); ++i) {
28422867 completion_token_output result;
28432868
2844- result.tok = ids[i];
2869+ result.tok = ids[i];
2870+ result.text_to_send = common_token_to_piece (ctx, result.tok , params_base.special );
2871+ result.prob = 1 .0f ; // set later
28452872
28462873 if (!process_token (result, slot)) {
28472874 // release slot because of stop condition
0 commit comments