@@ -112,6 +112,7 @@ struct slot_params {
     bool stream        = true;
     bool cache_prompt  = true;  // remember the prompt to avoid reprocessing all prompt
     bool return_tokens = false;
+    bool echo          = false;
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
         }
 
         return json {
+            {"echo",        echo},
             {"n_predict",   n_predict}, // Server configured n_predict
             {"seed",        sampling.seed},
             {"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
         params.stream        = json_value(data, "stream",        false);
         params.cache_prompt  = json_value(data, "cache_prompt",  true);
         params.return_tokens = json_value(data, "return_tokens", false);
+        params.echo          = json_value(data, "echo",          false);
         params.n_predict     = json_value(data, "n_predict",     json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent      = json_value(data, "n_indent",      defaults.n_indent);
         params.n_keep        = json_value(data, "n_keep",        defaults.n_keep);
@@ -674,6 +677,91 @@ struct completion_token_output {
         return out;
     }
 
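+    // convert token probabilities into the OpenAI /v1/completions "logprobs" object
+    // (text_offset / token_logprobs / tokens / top_logprobs); when echo is enabled,
+    // the prompt token probabilities are prepended so the arrays also cover the echoed prompt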
+    static json oaicompat_probs_vector_to_json(
+            const std::vector<completion_token_output> & probs_out,
+            bool post_sampling_probs,
+            bool echo,
+            const std::vector<completion_token_output> & prompt_probs = {}
+    ) {
+        json out = json::object();
+
+        std::vector<std::string> tokens;
+        std::vector<completion_token_output> all_probs;
+
+        if (echo && !prompt_probs.empty()) {
+            all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+        }
+
+        all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+        tokens.reserve(all_probs.size());
+        for (const auto & p : all_probs) {
+            std::string piece = p.text_to_send;
+            piece.resize(validate_utf8(piece));
+            tokens.push_back(piece);
+        }
+        int text_offset = 0;
+        std::vector<int> text_offsets;
+        text_offsets.reserve(tokens.size());
+
+        int current_off = text_offset;
+        for (const auto & tok : tokens) {
+            text_offsets.push_back(current_off);
+            current_off += static_cast<int>(tok.size());
+        }
+
+        std::vector<std::optional<float>> token_logprobs;
+        token_logprobs.reserve(all_probs.size());
+
+        std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+        top_logprobs.reserve(all_probs.size());
+
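+        // tokens without a computed logprob carry prob == -inf; they are pushed as
+        // std::nullopt so the corresponding logprob entries come out as null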
+        for (size_t i = 0; i < all_probs.size(); ++i) {
+            const auto & p = all_probs[i];
+
+            if (std::isinf(p.prob) && p.prob < 0) {
+                token_logprobs.push_back(std::nullopt);
+                top_logprobs.push_back(std::nullopt);
+            } else {
+                float logprob_value = p.prob;
+                if (!post_sampling_probs) {
+                    logprob_value = p.prob;
+                } else {
+                    logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
+                }
+
+                token_logprobs.push_back(std::optional<float>(logprob_value));
+
+                std::unordered_map<std::string, float> top_map;
+                for (const auto & cand : p.probs) {
+                    std::string cand_txt = cand.txt;
+                    cand_txt.resize(validate_utf8(cand_txt));
+
+                    float cand_logprob;
+                    if (!post_sampling_probs) {
+                        cand_logprob = cand.prob;
+                    } else {
+                        cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
+                    }
+
+                    top_map[cand_txt] = cand_logprob;
+                }
+
+                top_logprobs.push_back(std::move(top_map));
+            }
+        }
+
+        out = json {
+            {"text_offset",    text_offsets},
+            {"token_logprobs", token_logprobs},
+            {"tokens",         tokens},
+            {"top_logprobs",   top_logprobs}
+        };
+
+        return out;
+    }
+
     static float logarithm(float x) {
         // nlohmann::json converts -inf to null, so we need to prevent that
         return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
@@ -697,6 +785,7 @@ struct server_task_result_cmpl_final : server_task_result {
     bool stream;
     result_timings timings;
     std::string prompt;
+    bool echo = false;
 
     bool truncated;
     int32_t n_decoded;
@@ -708,6 +797,7 @@ struct server_task_result_cmpl_final : server_task_result {
 
     bool post_sampling_probs;
     std::vector<completion_token_output> probs_output;
+    std::vector<completion_token_output> prompt_probs_output;
     std::vector<std::string> response_fields;
 
     slot_params generation_params;
@@ -769,19 +859,26 @@ struct server_task_result_cmpl_final : server_task_result {
     json to_json_oaicompat() {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
-        if (!stream && probs_output.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-            };
+        if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                probs_output,
+                post_sampling_probs,
+                echo,
+                prompt_probs_output
+            );
         }
         json finish_reason = "length";
         if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
             finish_reason = "stop";
         }
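+        // with echo enabled (non-streaming), the prompt is prepended to the returned text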
+        std::string response_text = content;
+        if (echo && !stream) {
+            response_text = prompt + content;
+        }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text",          stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text",          stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                     {"index",         index},
                     {"logprobs",      logprobs},
                     {"finish_reason", finish_reason},
@@ -940,6 +1037,10 @@ struct server_task_result_cmpl_partial : server_task_result {
     std::string oaicompat_cmpl_id;
     std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
 
+    bool echo = false;
+    std::string prompt_text;
+    bool is_first_chunk = false;
+
     virtual int get_index() override {
         return index;
     }
@@ -986,14 +1087,21 @@ struct server_task_result_cmpl_partial : server_task_result {
         std::time_t t = std::time(0);
         json logprobs = json(nullptr); // OAI default to null
         if (prob_output.probs.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                std::vector<completion_token_output>{prob_output},
+                post_sampling_probs,
+                echo
+            );
+        }
+
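+        // in streaming mode the prompt is echoed only once, in the first chunk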
+        std::string response_text = content;
+        if (echo && is_first_chunk) {
+            response_text = prompt_text + content;
         }
         json res = json {
             {"choices", json::array({
                 json{
-                    {"text",          content},
+                    {"text",          response_text},
                     {"index",         index},
                     {"logprobs",      logprobs},
                     {"finish_reason", nullptr},
@@ -1321,6 +1429,8 @@ struct server_slot {
 
     // input prompt tokens
     server_tokens prompt_tokens;
+    std::string prompt_text;
+    std::vector<completion_token_output> prompt_token_probs;
 
     size_t last_nl_pos = 0;
 
@@ -1368,6 +1478,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens = 0;
+        prompt_text     = "";
         last_nl_pos     = 0;
         generated_text  = "";
         has_new_line    = false;
@@ -1381,6 +1492,7 @@ struct server_slot {
 
         generated_tokens.clear();
         generated_token_probs.clear();
+        prompt_token_probs.clear();
         chat_msg = {};
         json_schema = json();
         generated_tool_call_ids.clear();
@@ -2240,6 +2352,113 @@ struct server_context {
         slot.params        = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (slot.params.echo) {
+            slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+            if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 1 && slot.prompt_token_probs.empty()) {
+                slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
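+                // re-run the prompt through the model with logits enabled so that
+                // per-token logprobs for the echoed prompt can be computed;
+                // the KV cache is cleared before decoding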
+                llama_memory_clear(llama_get_memory(ctx), true);
+
+                const int n_batch     = llama_n_batch(ctx);
+                const int num_batches = (slot.prompt_tokens.size() + n_batch - 1) / n_batch;
+                const int n_vocab     = llama_vocab_n_tokens(vocab);
+
+                std::vector<float> all_logits;
+                if (num_batches > 1) {
+                    all_logits.reserve(slot.prompt_tokens.size() * n_vocab);
+                }
+
+                for (int batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
+                    const int batch_start = batch_idx * n_batch;
+                    const int batch_size  = std::min((int) slot.prompt_tokens.size() - batch_start, n_batch);
+
+                    llama_batch batch = llama_batch_init(batch_size, 0, 1);
+                    for (int i = 0; i < batch_size; ++i) {
+                        common_batch_add(batch, slot.prompt_tokens[batch_start + i], batch_start + i, { 0 }, true);
+                    }
+
+                    if (llama_decode(ctx, batch) == 0) {
+                        const float * batch_logits = llama_get_logits(ctx);
+                        if (num_batches > 1) {
+                            all_logits.insert(all_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                        }
+                    } else {
+                        llama_batch_free(batch);
+                        break;
+                    }
+                    llama_batch_free(batch);
+                }
+
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    completion_token_output prompt_token;
+                    prompt_token.tok          = slot.prompt_tokens[i];
+                    prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                    if (i == 0) {
+                        prompt_token.prob = -std::numeric_limits<float>::infinity();
+                    } else {
+                        const float * logits = num_batches > 1 ?
+                            all_logits.data() + (i - 1) * n_vocab :
+                            llama_get_logits_ith(ctx, i - 1);
+
+                        if (logits != nullptr) {
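+                            // logprob of the actual prompt token via a numerically stable
+                            // log-softmax: logit - (max_logit + log(sum(exp(logit - max_logit))))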
+                            float max_logit = logits[0];
+                            for (int j = 1; j < n_vocab; ++j) {
+                                max_logit = std::max(max_logit, logits[j]);
+                            }
+
+                            double sum_exp = 0.0;
+                            for (int j = 0; j < n_vocab; ++j) {
+                                sum_exp += expf(logits[j] - max_logit);
+                            }
+
+                            const float log_sum_exp = max_logit + logf(sum_exp);
+                            prompt_token.prob = logits[slot.prompt_tokens[i]] - log_sum_exp;
+
+                            if (slot.params.sampling.n_probs > 0) {
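+                                // collect the top-n_probs candidate tokens for this position, sorted by logprob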
+                                std::vector<std::pair<float, llama_token>> logits_id;
+                                logits_id.reserve(n_vocab);
+
+                                for (int j = 0; j < n_vocab; j++) {
+                                    const float logprob = logits[j] - log_sum_exp;
+                                    logits_id.emplace_back(logprob, j);
+                                }
+
+                                std::partial_sort(logits_id.begin(),
+                                        logits_id.begin() + std::min((size_t) slot.params.sampling.n_probs, logits_id.size()),
+                                        logits_id.end(),
+                                        [](const auto & a, const auto & b) { return a.first > b.first; });
+
+                                prompt_token.probs.clear();
+                                size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
+                                for (size_t k = 0; k < top_k; ++k) {
+                                    completion_token_output::prob_info prob_info;
+                                    prob_info.tok  = logits_id[k].second;
+                                    prob_info.prob = logits_id[k].first;
+                                    prob_info.txt  = common_token_to_piece(ctx, logits_id[k].second, true);
+                                    prompt_token.probs.push_back(prob_info);
+                                }
+                            }
+                        } else {
+                            prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        }
+                    }
+
+                    slot.prompt_token_probs.push_back(prompt_token);
+                }
+            } else {
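+                // echo without prompt logprobs: just record the prompt tokens,
+                // leaving prob = -inf so no logprob is reported for them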
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    completion_token_output prompt_token;
+                    prompt_token.tok          = slot.prompt_tokens[i];
+                    prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                    prompt_token.prob         = -std::numeric_limits<float>::infinity();
+                    slot.prompt_token_probs.push_back(prompt_token);
+                }
+            }
+        }
+
+
         if (!are_lora_equal(slot.params.lora, slot.lora)) {
             // if lora is changed, we cannot reuse cached tokens
             slot.cache_tokens.clear();
@@ -2529,6 +2748,10 @@ struct server_context {
         res->content = tkn.text_to_send;
         res->tokens  = { tkn.tok };
 
+        res->echo           = slot.params.echo;
+        res->prompt_text    = slot.prompt_text;
+        res->is_first_chunk = (slot.n_decoded == 1);
+
         res->n_decoded           = slot.n_decoded;
         res->n_prompt_tokens     = slot.n_prompt_tokens;
         res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2785,9 @@ struct server_context {
         res->content = slot.generated_text;
         res->tokens  = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
-        res->prompt  = slot.prompt_tokens.detokenize(ctx, true);
+
+        res->echo    = slot.params.echo;
+        res->prompt  = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
         res->response_fields = std::move(slot.params.response_fields);
 
         res->truncated = slot.truncated;
@@ -2595,6 +2820,10 @@ struct server_context {
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
+
+            if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+                res->prompt_probs_output = slot.prompt_token_probs;
+            }
         }
 
         res->generation_params = slot.params; // copy the parameters