@@ -112,6 +112,7 @@ struct slot_params {
     bool stream        = true;
     bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
     bool return_tokens = false;
+    bool echo          = false;
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
     }
 
     return json {
+        {"echo",        echo},
         {"n_predict",   n_predict}, // Server configured n_predict
         {"seed",        sampling.seed},
         {"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
         params.stream        = json_value(data, "stream",        false);
         params.cache_prompt  = json_value(data, "cache_prompt",  true);
         params.return_tokens = json_value(data, "return_tokens", false);
+        params.echo          = json_value(data, "echo",          false);
         params.n_predict     = json_value(data, "n_predict",     json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent      = json_value(data, "n_indent",      defaults.n_indent);
         params.n_keep        = json_value(data, "n_keep",        defaults.n_keep);
@@ -674,6 +677,98 @@ struct completion_token_output {
         return out;
     }
 
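+    // convert token probabilities into an OAI-compatible /v1/completions "logprobs"
+    // object made of parallel arrays:
+    //   {"text_offset": [...], "token_logprobs": [...], "tokens": [...], "top_logprobs": [...]}
+    // with echo, the prompt tokens are prepended and the very first prompt token gets
+    // null logprobs, since it has no preceding context to be scored against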
+    static json oaicompat_probs_vector_to_json(
+        const std::vector<completion_token_output> & probs_out,
+        bool post_sampling_probs,
+        bool echo,
+        const std::vector<completion_token_output> & prompt_probs = {}
+    ) {
+        // when echoing, the prompt tokens come first, followed by the generated tokens
+        std::vector<completion_token_output> all_probs;
+        if (echo && !prompt_probs.empty()) {
+            all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+        }
+        all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+        std::vector<std::string> tokens;
+        tokens.reserve(all_probs.size());
+        for (const auto & p : all_probs) {
+            std::string piece = p.text_to_send;
+            piece.resize(validate_utf8(piece));
+            tokens.push_back(piece);
+        }
+
+        // text_offset[i] is the byte offset of token i within the concatenated text
+        std::vector<int> text_offsets;
+        text_offsets.reserve(tokens.size());
+        int current_off = 0;
+        for (const auto & tok : tokens) {
+            text_offsets.push_back(current_off);
+            current_off += (int) tok.size();
+        }
+
+        std::vector<std::optional<float>> token_logprobs;
+        token_logprobs.reserve(all_probs.size());
+
+        std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+        top_logprobs.reserve(all_probs.size());
+
+        for (size_t i = 0; i < all_probs.size(); ++i) {
+            const auto & p = all_probs[i];
+
+            // the first prompt token has no preceding context, and -inf marks a
+            // missing value; both serialize as null
+            const bool is_first_prompt_token = echo && i == 0;
+            if (is_first_prompt_token || (std::isinf(p.prob) && p.prob < 0)) {
+                token_logprobs.push_back(std::nullopt);
+                top_logprobs.push_back(std::nullopt);
+                continue;
+            }
+
+            // p.prob already holds a logprob, unless post-sampling probabilities were
+            // requested, in which case it holds a probability
+            const float logprob_value = !post_sampling_probs
+                ? p.prob
+                : (p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity());
+
+            token_logprobs.push_back(logprob_value);
+
+            std::unordered_map<std::string, float> top_map;
+            for (const auto & cand : p.probs) {
+                std::string cand_txt = cand.txt;
+                cand_txt.resize(validate_utf8(cand_txt));
+
+                const float cand_logprob = !post_sampling_probs
+                    ? cand.prob
+                    : (cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity());
+
+                top_map[cand_txt] = cand_logprob;
+            }
+
+            top_logprobs.push_back(std::move(top_map));
+        }
+
+        return json {
+            {"text_offset",    text_offsets},
+            {"token_logprobs", token_logprobs},
+            {"tokens",         tokens},
+            {"top_logprobs",   top_logprobs},
+        };
+    }
+
     static float logarithm(float x) {
         // nlohmann::json converts -inf to null, so we need to prevent that
         return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
@@ -697,6 +792,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool stream;
     result_timings timings;
     std::string prompt;
+    bool echo = false;
 
     bool truncated;
     int32_t n_decoded;
@@ -708,6 +804,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<completion_token_output> prompt_probs_output;
     std::vector<std::string> response_fields;
 
     slot_params generation_params;
@@ -769,19 +866,26 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat() {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
-        if (!stream && probs_output.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-            };
+        if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                probs_output,
+                post_sampling_probs,
+                echo,
+                prompt_probs_output
+            );
         }
         json finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             finish_reason = "stop";
         }
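+        // with echo, the prompt is prepended to the generated text (non-streaming only;
+        // in streaming mode the prompt travels with the first chunk instead)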
+        std::string response_text = content;
+        if (echo && !stream) {
+            response_text = prompt + content;
+        }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", finish_reason},
@@ -940,6 +1044,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::string oaicompat_cmpl_id;
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
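+    // echo support: prompt text to prepend to the first streamed chunk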
+    bool echo = false;
+    std::string prompt_text;
+    bool is_first_chunk = false;
+
     virtual int get_index() override {
         return index;
     }
@@ -986,14 +1094,21 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
         if (prob_output.probs.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                std::vector<completion_token_output>{prob_output},
+                post_sampling_probs,
+                echo
+            );
+        }
+
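+        // on the first chunk of an echo request, prepend the prompt text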
+        std::string response_text = content;
+        if (echo && is_first_chunk) {
+            response_text = prompt_text + content;
         }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text", content},
+                    {"text", response_text},
                     {"index", index},
                     {"logprobs", logprobs},
                     {"finish_reason", nullptr},
@@ -1321,6 +1436,8 @@ struct server_slot {
 
     // input prompt tokens
     server_tokens prompt_tokens;
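+    // echo support: detokenized prompt text and per-token probabilities of the prompt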
+    std::string prompt_text;
+    std::vector<completion_token_output> prompt_token_probs;
 
     size_t last_nl_pos = 0;
 
@@ -1368,6 +1485,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens = 0;
+        prompt_text = "";
         last_nl_pos = 0;
         generated_text = "";
         has_new_line = false;
@@ -1381,6 +1499,7 @@ struct server_slot {
13811499
13821500 generated_tokens.clear ();
13831501 generated_token_probs.clear ();
1502+ prompt_token_probs.clear ();
13841503 chat_msg = {};
13851504 json_schema = json ();
13861505 generated_tool_call_ids.clear ();
@@ -2240,6 +2359,77 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
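+        // echo support: keep the detokenized prompt text; if logprobs were requested,
+        // the prompt tokens are scored up-front with a dedicated decode pass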
+        if (slot.params.echo) {
+            slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+            if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 0 && slot.prompt_token_probs.empty()) {
+                slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
+                // decode the whole prompt once, requesting logits at every position, so
+                // that each prompt token can be scored against its preceding context
+                llama_batch batch = llama_batch_init(slot.prompt_tokens.size(), 0, 1);
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    common_batch_add(batch, slot.prompt_tokens[i], i, { 0 }, true);
+                }
+
+                if (llama_decode(ctx, batch) == 0) {
+                    const int n_vocab = llama_vocab_n_tokens(vocab);
+                    for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                        completion_token_output prompt_token;
+                        prompt_token.tok          = slot.prompt_tokens[i];
+                        prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                        // the first token has no preceding context to be scored against
+                        const float * logits = i > 0 ? llama_get_logits_ith(ctx, i - 1) : nullptr;
+                        if (logits != nullptr) {
+                            // log-softmax normalization, so that true logprobs are
+                            // stored rather than raw logits
+                            float max_logit = logits[0];
+                            for (int j = 1; j < n_vocab; j++) {
+                                max_logit = std::max(max_logit, logits[j]);
+                            }
+                            double sum_exp = 0.0;
+                            for (int j = 0; j < n_vocab; j++) {
+                                sum_exp += std::exp((double) (logits[j] - max_logit));
+                            }
+                            const float log_z = max_logit + (float) std::log(sum_exp);
+
+                            // the logits come in vocab order, so the top n_probs
+                            // candidates have to be selected explicitly
+                            std::vector<std::pair<float, llama_token>> logits_id;
+                            logits_id.reserve(n_vocab);
+                            for (int j = 0; j < n_vocab; j++) {
+                                logits_id.emplace_back(logits[j], j);
+                            }
+
+                            const size_t top_k = std::min(logits_id.size(), (size_t) slot.params.sampling.n_probs);
+                            std::partial_sort(logits_id.begin(), logits_id.begin() + top_k, logits_id.end(),
+                                    [](const auto & a, const auto & b) { return a.first > b.first; });
+
+                            for (size_t k = 0; k < top_k; ++k) {
+                                completion_token_output::prob_info prob_info;
+                                prob_info.tok  = logits_id[k].second;
+                                prob_info.prob = logits_id[k].first - log_z;
+                                prob_info.txt  = common_token_to_piece(ctx, logits_id[k].second, true);
+                                prompt_token.probs.push_back(prob_info);
+                            }
+
+                            // logprob of the token that actually occurs in the prompt
+                            prompt_token.prob = logits[prompt_token.tok] - log_z;
+                        } else {
+                            prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        }
+
+                        slot.prompt_token_probs.push_back(prompt_token);
+                    }
+                } else {
+                    // decoding failed - keep the tokens, mark their probs as unavailable
+                    for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                        completion_token_output prompt_token;
+                        prompt_token.tok          = slot.prompt_tokens[i];
+                        prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                        prompt_token.prob         = -std::numeric_limits<float>::infinity();
+                        slot.prompt_token_probs.push_back(prompt_token);
+                    }
+                }
+
+                llama_batch_free(batch);
+            }
+        }
+
         if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
@@ -2529,6 +2719,10 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens  = { tkn.tok };
 
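+        // echo support: n_decoded == 1 identifies the first streamed chunk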
+        res->echo           = slot.params.echo;
+        res->prompt_text    = slot.prompt_text;
+        res->is_first_chunk = (slot.n_decoded == 1);
+
         res->n_decoded           = slot.n_decoded;
         res->n_prompt_tokens     = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2756,9 @@ struct server_context {
         res->content = slot.generated_text;
         res->tokens  = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
-        res->prompt  = slot.prompt_tokens.detokenize(ctx, true);
+
+        res->echo   = slot.params.echo;
+        res->prompt = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.params.response_fields);
 
         res->truncated = slot.truncated;
@@ -2595,6 +2791,10 @@ struct server_context {
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
+
+            if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+                res->prompt_probs_output = slot.prompt_token_probs;
+            }
         }
 
         res->generation_params = slot.params; // copy the parameters