@@ -558,6 +558,7 @@ struct slot_params {
     std::vector<std::string> antiprompt;

     bool timings_per_token = false;
+    bool post_sampling_probs = false;
     json input_prefix;
     json input_suffix;

@@ -1549,6 +1550,8 @@ struct server_context {
         slot.sparams.n_probs  = json_value(data, "n_probs",  default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

+        slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);
+
         // speculative decoding parameters
         slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
         slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
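Note: `json_value` above is the server's existing helper for reading an optional request field with a fallback default; it is not part of this diff. The sketch below only illustrates the assumed contract (names and error handling are illustrative, not the actual helper):

    // Assumed shape of json_value: return body[key] when present and of the
    // right type, otherwise fall back to default_value.
    template <typename T>
    static T json_value(const json & body, const std::string & key, const T & default_value) {
        if (body.contains(key) && !body.at(key).is_null()) {
            try {
                return body.at(key).get<T>();
            } catch (const nlohmann::json::exception &) {
                return default_value; // wrong type in the request; keep the default
            }
        }
        return default_value;
    }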
@@ -1951,26 +1954,7 @@ struct server_context {
         }

         // check if there is incomplete UTF-8 character at the end
-        bool incomplete = false;
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
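The hand-rolled byte scan is replaced by a single call to `validate_utf8`, which this change introduces in the server utilities (not shown in these hunks). The assumed contract, sketched below, is that it returns the length of the longest prefix of `text` that ends on a complete UTF-8 sequence, so a value smaller than `text.size()` signals a truncated multi-byte character:

    // Sketch of the assumed validate_utf8 semantics; the PR's actual helper may differ.
    static size_t validate_utf8(const std::string & text) {
        const size_t len = text.size();
        if (len == 0) {
            return 0;
        }
        // a UTF-8 sequence is at most 4 bytes, so only the tail needs checking
        for (size_t i = 1; i <= 4 && i <= len; ++i) {
            const unsigned char c = text[len - i];
            if ((c & 0x80) == 0x00) {
                return len; // ASCII byte: the string ends on a complete character
            }
            if ((c & 0xC0) == 0xC0) {
                // lead byte of the trailing sequence: 110xxxxx / 1110xxxx / 11110xxx
                const size_t seq_len = (c & 0xE0) == 0xC0 ? 2 : (c & 0xF0) == 0xE0 ? 3 : 4;
                return i < seq_len ? len - i : len; // cut off a truncated sequence
            }
            // continuation byte (10xxxxxx): keep scanning backwards
        }
        return len; // no lead byte within the last 4 bytes; leave the string as-is
    }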
@@ -2066,6 +2050,49 @@ struct server_context {
         return slot.has_next_token; // continue
     }

+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.sparams.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+
+        if (post_sampling) {
+            const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    llama_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            auto && [sampled_token_p, cur] = get_token_probabilities(ctx, idx, result.tok, n_probs);
+
+            // set probability for sampled token
+            result.prob = sampled_token_p;
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    llama_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
+    }
+
     json get_formated_generation(const server_slot & slot) const {
         const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
         const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
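The `post_sampling == false` branch relies on `get_token_probabilities`, which is defined outside these hunks. From the way its result is unpacked, it is assumed to normalize the raw logits at batch position `idx` (probabilities before top-k/top-p/temperature are applied), sort the candidates, and also report the probability of the token that was actually sampled. A minimal sketch under that assumption (standard headers such as <algorithm> and <cmath> assumed available, as they are in server.cpp):

    // Sketch only; the PR's real helper may differ in details.
    // Returns {probability of `sampled`, candidates sorted by descending probability}.
    static std::pair<float, std::vector<llama_token_data>> get_token_probabilities(
            llama_context * ctx, int idx, llama_token sampled, size_t n_probs) {
        const int     n_vocab = llama_n_vocab(llama_get_model(ctx));
        const float * logits  = llama_get_logits_ith(ctx, idx);

        std::vector<llama_token_data> cur(n_vocab);
        float max_logit = logits[0];
        for (int i = 1; i < n_vocab; i++) {
            max_logit = std::max(max_logit, logits[i]);
        }

        // softmax over the full vocabulary
        float sum = 0.0f;
        for (int i = 0; i < n_vocab; i++) {
            cur[i] = { (llama_token) i, logits[i], expf(logits[i] - max_logit) };
            sum += cur[i].p;
        }

        float sampled_p = 0.0f;
        for (int i = 0; i < n_vocab; i++) {
            cur[i].p /= sum;
            if (cur[i].id == sampled) {
                sampled_p = cur[i].p;
            }
        }

        // only the first n_probs entries need to be in order for the caller's loop
        const size_t k = std::min((size_t) n_vocab, n_probs);
        std::partial_sort(cur.begin(), cur.begin() + k, cur.end(),
                [](const llama_token_data & a, const llama_token_data & b) { return a.p > b.p; });

        return {sampled_p, std::move(cur)};
    }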
@@ -2163,6 +2190,7 @@ struct server_context {
         res.stop    = false;
         res.stream  = slot.params.stream;
         res.content = tkn.text_to_send;
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat         = slot.params.oaicompat;
         res.oaicompat_model   = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2175,26 +2203,18 @@ struct server_context {
             {"multimodal", false}
         };
         slot.update_chat_msg(res.oaicompat_msg_diffs);
-        if (slot.sparams.n_probs > 0) {
-            const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
-            const size_t probs_pos      = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
-            const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
-
-            std::vector<completion_token_output> probs_output;
-            if (probs_pos < probs_stop_pos) {
-                probs_output = std::vector<completion_token_output>(
-                    slot.generated_token_probs.begin() + probs_pos,
-                    slot.generated_token_probs.begin() + probs_stop_pos);
-            }
-            slot.n_sent_token_probs = probs_stop_pos;

-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
+        // populate res.probs_output
+        if (slot.sparams.n_probs > 0) {
+            res.probs_output = {tkn}; // copy the token probs
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
         }

         if (slot.oaicompat) {
             res.data["oaicompat_token_ctr"] = slot.n_decoded;
             res.data["model"] = slot.oaicompat_model;
         }
+
         // populate timings if this is final response or timings_per_token is enabled
         if (slot.params.timings_per_token) {
             res.timings = slot.get_timings();
@@ -2212,6 +2232,8 @@ struct server_context {
         res.stream        = slot.params.stream;
         res.include_usage = slot.params.include_usage;
         res.content       = slot.generated_text;
+        res.timings = slot.get_timings();
+        res.post_sampling_probs = slot.params.post_sampling_probs;
         res.oaicompat         = slot.params.oaicompat;
         res.oaicompat_model   = slot.params.oaicompat_model;
         res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2239,26 +2261,23 @@ struct server_context {
             // {"oaicompat_chat_format", slot.params.oaicompat_chat_format},
         };

+        // populate res.probs_output
         if (slot.sparams.n_probs > 0) {
-            std::vector<completion_token_output> probs;
             if (!slot.params.stream && slot.stopped_word) {
                 const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);

                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end() - safe_offset);
             } else {
-                probs = std::vector<completion_token_output>(
+                res.probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
-            // res.generation_params = slot.params;
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
+            res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
         }

-        res.timings = slot.get_timings();
-
         if (slot.oaicompat) {
             res.data["oaicompat_token_ctr"] = slot.n_decoded;
             res.data["model"] = slot.oaicompat_model;
@@ -3199,7 +3218,8 @@ struct server_context {
            }

            completion_token_output result;
-           const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+           const int tok_idx = slot.i_batch - i;
+           const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

            llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

@@ -3215,35 +3235,12 @@ struct server_context {

            slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

-           llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
            result.tok = id;
+           result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
            result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));

-           const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-           if (n_probs > 0) {
-               const size_t n_valid = slot.ctx_sampling->n_valid;
-
-               // Make sure at least n_probs top tokens are at the front of the vector:
-               if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                   llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-               }
-
-               if (slot.sparams.temp == 0.0f) {
-                   // With greedy sampling the probabilities have possibly not been calculated.
-                   for (size_t i = 0; i < n_probs; ++i) {
-                       result.probs.push_back({
-                           cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                           i == 0 ? 1.0f : 0.0f
-                       });
-                   }
-               } else {
-                   for (size_t i = 0; i < n_probs; ++i) {
-                       result.probs.push_back({
-                           cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
-                           i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                       });
-                   }
-               }
+           if (slot.sparams.n_probs > 0) {
+               populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
            }

            if (!process_token(result, slot)) {
@@ -3348,7 +3345,11 @@ struct server_context {

            result.tok = ids[i];
            result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
-           // result.prob = 1.0f; // set later
+           result.prob = 1.0f; // set later
+
+           if (slot.sparams.n_probs > 0) {
+               populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
+           }

            if (!process_token(result, slot)) {
                // release slot because of stop condition
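For reference, the fields written throughout this diff (`result.tok`, `result.text_to_send`, `result.prob`, `result.probs`, and the three-element initializers pushed into `probs`) imply a `completion_token_output` roughly of the following shape. The actual definition lives outside the hunks shown, and the member names of the nested entry are assumptions inferred from usage:

    // Assumed layout, inferred from usage above (not copied from the PR).
    struct completion_token_output {
        llama_token tok;          // sampled token id
        std::string text_to_send; // detokenized piece streamed to the client
        float       prob;         // probability of the sampled token

        struct prob_info {
            llama_token id;   // candidate token id
            std::string text; // detokenized candidate
            float       p;    // candidate probability
        };
        std::vector<prob_info> probs; // top n_probs candidates for this position
    };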