@@ -443,20 +443,20 @@ struct completion_token_output {
     std::string text_to_send;
     struct token_prob {
         llama_token tok;
-        std::string tok_str;
+        std::string txt;
         float prob;
     };
     std::vector<token_prob> probs;

     json to_json(bool post_sampling_probs) const {
         json probs_for_token = json::array();
         for (const auto & p : probs) {
-            std::string tok_str(p.tok_str);
-            tok_str.resize(validate_utf8(tok_str));
+            std::string txt(p.txt);
+            txt.resize(validate_utf8(txt));
             probs_for_token.push_back(json {
                 {"id",      p.tok},
-                {"token",   tok_str},
-                {"bytes",   str_to_bytes(p.tok_str)},
+                {"token",   txt},
+                {"bytes",   str_to_bytes(p.txt)},
                 {
                     post_sampling_probs ? "prob" : "logprob",
                     post_sampling_probs ? p.prob : logarithm(p.prob)
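The hunk above copies each detokenized piece and trims it with validate_utf8 before it enters the JSON response, so a token that ends in the middle of a multi-byte UTF-8 sequence cannot produce invalid JSON output. The helper itself is outside this diff; below is a minimal sketch of what such a check could look like. The name utf8_valid_prefix_len and its exact behavior are assumptions, not the repository's implementation.

```cpp
#include <cstddef>
#include <string>

// Hypothetical sketch (not the repository's validate_utf8): return the length of
// the longest prefix of `s` that does not end in a truncated UTF-8 multi-byte
// sequence, so a partially streamed token piece can be placed in JSON safely.
static size_t utf8_valid_prefix_len(const std::string & s) {
    const size_t len = s.size();
    // a UTF-8 sequence is at most 4 bytes, so look back at most 4 bytes
    for (size_t back = 1; back <= 4 && back <= len; ++back) {
        const unsigned char c = s[len - back];
        if ((c & 0x80) == 0x00) {
            return len;                       // ASCII byte: string ends on a boundary
        }
        if ((c & 0xC0) == 0xC0) {             // lead byte of a multi-byte sequence
            const size_t need = (c & 0xE0) == 0xC0 ? 2
                              : (c & 0xF0) == 0xE0 ? 3
                              :                      4;
            return back == need ? len : len - back;   // drop an incomplete tail
        }
        // otherwise a continuation byte (10xxxxxx): keep scanning backwards
    }
    return len;
}
```

With such a helper, txt.resize(utf8_valid_prefix_len(txt)) would mirror the txt.resize(validate_utf8(txt)) call in the hunk.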
@@ -468,20 +468,20 @@ struct completion_token_output {

     static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
         json out = json::array();
-        for (const auto & it : probs) {
-            std::string tok_str(it.text_to_send);
-            tok_str.resize(validate_utf8(tok_str));
+        for (const auto & p : probs) {
+            std::string txt(p.text_to_send);
+            txt.resize(validate_utf8(txt));
             out.push_back(json {
-                {"id",           it.tok},
-                {"token",        tok_str},
-                {"bytes",        str_to_bytes(it.text_to_send)},
+                {"id",           p.tok},
+                {"token",        txt},
+                {"bytes",        str_to_bytes(p.text_to_send)},
                 {
                     post_sampling_probs ? "top_probs" : "top_logprobs",
-                    it.to_json(post_sampling_probs)
+                    p.to_json(post_sampling_probs)
                 },
                 {
                     post_sampling_probs ? "prob" : "logprob",
-                    post_sampling_probs ? it.prob : logarithm(it.prob)
+                    post_sampling_probs ? p.prob : logarithm(p.prob)
                 },
             });
         }
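Alongside the UTF-8-validated token text, both serializers also emit a bytes field built from the untruncated string (p.text_to_send here), so clients can reassemble characters that were split across streamed tokens. str_to_bytes is not shown in this diff; the sketch below assumes it simply exposes each byte as an unsigned value, and the name bytes_of is hypothetical.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical sketch (not the repository's str_to_bytes): expose a detokenized
// piece as raw byte values so a client can reassemble text even when a single
// codepoint is split across two streamed tokens.
static std::vector<uint8_t> bytes_of(const std::string & s) {
    return std::vector<uint8_t>(s.begin(), s.end());
}
```

A vector of unsigned bytes serializes to a plain JSON array of numbers, which matches the OpenAI-style bytes field in logprobs responses.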
@@ -1958,44 +1958,45 @@ struct server_context {
         size_t n_probs = slot.params.sampling.n_probs;
         int n_vocab = llama_n_vocab(llama_get_model(ctx));
         if (post_sampling) {
-            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+            // TODO: optimize this with min-p optimization
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;

             bool found_sampled_tok = false;
-            result.probs.reserve(n_probs);
-            for (int i = 0; i < n_vocab; i++) {
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < max_probs; i++) {
                 // set probability for sampled token
-                if (cur[i].id == result.tok) {
+                if (cur_p->data[i].id == result.tok) {
                     found_sampled_tok = true;
-                    result.prob = cur[i].p;
+                    result.prob = cur_p->data[i].p;
                 }
                 // set probability for top n_probs tokens
                 result.probs.push_back({
-                    cur[i].id,
-                    common_detokenize(ctx, {cur[i].id}, special),
-                    cur[i].p
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
                 });
                 // break if we have all the necessary data
                 if (result.probs.size() == n_probs && found_sampled_tok) {
                     break;
                 }
             }
         } else {
-            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
-            const size_t max_probs = cur_p->size;
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);

             bool found_sampled_tok = false;
-            result.probs.reserve(max_probs);
-            for (size_t i = 0; i < max_probs; i++) {
+            result.probs.reserve(n_probs);
+            for (int i = 0; i < n_vocab; i++) {
                 // set probability for sampled token
-                if (cur_p->data[i].id == result.tok) {
+                if (cur[i].id == result.tok) {
                     found_sampled_tok = true;
-                    result.prob = cur_p->data[i].p;
+                    result.prob = cur[i].p;
                 }
                 // set probability for top n_probs tokens
                 result.probs.push_back({
-                    cur_p->data[i].id,
-                    common_detokenize(ctx, {cur_p->data[i].id}, special),
-                    cur_p->data[i].p
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
                 });
                 // break if we have all the necessary data
                 if (result.probs.size() == n_probs && found_sampled_tok) {
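After this change, the post-sampling path reads probabilities straight from the sampler's candidate list (common_sampler_get_candidates), while the pre-sampling path walks the full vocabulary returned by get_token_probabilities(ctx, idx). That helper is not part of this diff; the following is a minimal sketch of what a full-vocabulary softmax at output position idx could look like, using llama.h calls that also appear above (llama_get_model, llama_n_vocab) plus llama_get_logits_ith. The function name full_vocab_probs is hypothetical, not the repository's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

#include "llama.h"

// Hypothetical sketch (not the repository's get_token_probabilities): compute
// softmax probabilities over the whole vocabulary for output position `idx`,
// sorted so the most likely candidates come first.
static std::vector<llama_token_data> full_vocab_probs(llama_context * ctx, int idx) {
    const float * logits  = llama_get_logits_ith(ctx, idx);
    const int     n_vocab = llama_n_vocab(llama_get_model(ctx));

    std::vector<llama_token_data> cur(n_vocab);
    for (llama_token id = 0; id < n_vocab; id++) {
        cur[id] = llama_token_data{id, logits[id], 0.0f};
    }

    // sort by logit, descending, so cur[0 .. n_probs) holds the top candidates
    std::sort(cur.begin(), cur.end(),
              [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; });

    // numerically stable softmax over all logits
    const float max_logit = cur.empty() ? 0.0f : cur.front().logit;
    float sum = 0.0f;
    for (auto & td : cur) {
        td.p = std::exp(td.logit - max_logit);
        sum += td.p;
    }
    for (auto & td : cur) {
        td.p /= sum;
    }
    return cur;
}
```

Either way, the loops above only need the top n_probs entries plus the sampled token's probability, which is why they break out as soon as both have been collected.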