@@ -112,6 +112,7 @@ struct slot_params {
    bool stream        = true;
    bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
    bool return_tokens = false;
+    bool echo          = false;

    int32_t n_keep    = 0; // number of tokens to keep from initial prompt
    int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
        }

        return json {
+            {"echo",        echo},
            {"n_predict",   n_predict}, // Server configured n_predict
            {"seed",        sampling.seed},
            {"temperature", sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
        params.stream        = json_value(data, "stream",        false);
        params.cache_prompt  = json_value(data, "cache_prompt",  true);
        params.return_tokens = json_value(data, "return_tokens", false);
+        params.echo          = json_value(data, "echo",          false);
        params.n_predict     = json_value(data, "n_predict",     json_value(data, "max_tokens", defaults.n_predict));
        params.n_indent      = json_value(data, "n_indent",      defaults.n_indent);
        params.n_keep        = json_value(data, "n_keep",        defaults.n_keep);
@@ -674,6 +677,98 @@ struct completion_token_output {
        return out;
    }

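+    // Builds an OpenAI /v1/completions style "logprobs" object (text_offset, token_logprobs,
+    // tokens, top_logprobs). With echo enabled, prompt token probabilities are prepended to the
+    // completion ones, and the very first prompt token gets a null logprob since it has no
+    // preceding context to be scored against.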
+    static json oaicompat_probs_vector_to_json(
+        const std::vector<completion_token_output> & probs_out,
+        bool post_sampling_probs,
+        bool echo,
+        const std::vector<completion_token_output> & prompt_probs = {}
+    ) {
+        json out = json::object();
+
+        std::vector<std::string> tokens;
+        std::vector<completion_token_output> all_probs;
+
+        if (echo && !prompt_probs.empty()) {
+            all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+        }
+
+        all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+        tokens.reserve(all_probs.size());
+        for (const auto & p : all_probs) {
+            std::string piece = p.text_to_send;
+            piece.resize(validate_utf8(piece));
+            tokens.push_back(piece);
+        }
+
+        int text_offset = 0;
+        std::vector<int> text_offsets;
+        text_offsets.reserve(tokens.size());
+
+        int current_off = text_offset;
+        for (const auto & tok : tokens) {
+            text_offsets.push_back(current_off);
+            current_off += static_cast<int>(tok.size());
+        }
+
+        std::vector<std::optional<float>> token_logprobs;
+        token_logprobs.reserve(all_probs.size());
+
+        std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+        top_logprobs.reserve(all_probs.size());
+
+        for (size_t i = 0; i < all_probs.size(); ++i) {
+            const auto & p = all_probs[i];
+
+            bool is_first_prompt_token = echo && (i == 0);
+
+            if (is_first_prompt_token) {
+                token_logprobs.push_back(std::nullopt);
+                top_logprobs.push_back(std::nullopt);
+            } else {
+                if (std::isinf(p.prob) && p.prob < 0) {
+                    token_logprobs.push_back(std::nullopt);
+                    top_logprobs.push_back(std::nullopt);
+                } else {
+                    float logprob_value = p.prob;
+                    if (!post_sampling_probs) {
+                        logprob_value = p.prob;
+                    } else {
+                        logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
+                    }
+
+                    token_logprobs.push_back(std::optional<float>(logprob_value));
+
+                    std::unordered_map<std::string, float> top_map;
+                    for (const auto & cand : p.probs) {
+                        std::string cand_txt = cand.txt;
+                        cand_txt.resize(validate_utf8(cand_txt));
+
+                        float cand_logprob;
+                        if (!post_sampling_probs) {
+                            cand_logprob = cand.prob;
+                        } else {
+                            cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
+                        }
+
+                        top_map[cand_txt] = cand_logprob;
+                    }
+
+                    top_logprobs.push_back(std::move(top_map));
+                }
+            }
+        }
+
+        out = json{
+            {"text_offset",    text_offsets},
+            {"token_logprobs", token_logprobs},
+            {"tokens",         tokens},
+            {"top_logprobs",   top_logprobs}
+        };
+
+        return out;
+    }
+
    static float logarithm(float x) {
        // nlohmann::json converts -inf to null, so we need to prevent that
        return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
@@ -697,6 +792,7 @@ struct server_task_result_cmpl_final : server_task_result {
    bool stream;
    result_timings timings;
    std::string prompt;
+    bool echo = false;

    bool truncated;
    int32_t n_decoded;
@@ -708,6 +804,7 @@ struct server_task_result_cmpl_final : server_task_result {

    bool post_sampling_probs;
    std::vector<completion_token_output> probs_output;
+    std::vector<completion_token_output> prompt_probs_output;
    std::vector<std::string> response_fields;

    slot_params generation_params;
@@ -769,19 +866,26 @@ struct server_task_result_cmpl_final : server_task_result {
    json to_json_oaicompat() {
        std::time_t t = std::time(0);
        json logprobs = json(nullptr); // OAI default to null
-        if (!stream && probs_output.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-            };
+        if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                probs_output,
+                post_sampling_probs,
+                echo,
+                prompt_probs_output
+            );
        }
        json finish_reason = "length";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
            finish_reason = "stop";
        }
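+        // echo on a non-streaming request: return the original prompt followed by the completion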
+        std::string response_text = content;
+        if (echo && !stream) {
+            response_text = prompt + content;
+        }
        json res = json {
            {"choices", json::array({
                json{
-                    {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                    {"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                    {"index", index},
                    {"logprobs", logprobs},
                    {"finish_reason", finish_reason},
@@ -940,6 +1044,10 @@ struct server_task_result_cmpl_partial : server_task_result {
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

+    bool echo = false;
+    std::string prompt_text;
+    bool is_first_chunk = false;
+
    virtual int get_index() override {
        return index;
    }
@@ -986,14 +1094,21 @@ struct server_task_result_cmpl_partial : server_task_result {
        std::time_t t = std::time(0);
        json logprobs = json(nullptr); // OAI default to null
        if (prob_output.probs.size() > 0) {
-            logprobs = json{
-                {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-            };
+            logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+                std::vector<completion_token_output>{prob_output},
+                post_sampling_probs,
+                echo
+            );
+        }
+
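+        // with echo, the prompt text is prepended only to the first streamed chunk;
+        // later chunks carry just the newly generated text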
+        std::string response_text = content;
+        if (echo && is_first_chunk) {
+            response_text = prompt_text + content;
        }
        json res = json {
            {"choices", json::array({
                json{
-                    {"text", content},
+                    {"text", response_text},
                    {"index", index},
                    {"logprobs", logprobs},
                    {"finish_reason", nullptr},
@@ -1321,6 +1436,8 @@ struct server_slot {

    // input prompt tokens
    server_tokens prompt_tokens;
+    std::string prompt_text;
+    std::vector<completion_token_output> prompt_token_probs;

    size_t last_nl_pos = 0;

@@ -1368,6 +1485,7 @@ struct server_slot {
        SLT_DBG(*this, "%s", "\n");

        n_prompt_tokens = 0;
+        prompt_text = "";
        last_nl_pos = 0;
        generated_text = "";
        has_new_line = false;
@@ -1381,6 +1499,7 @@ struct server_slot {

        generated_tokens.clear();
        generated_token_probs.clear();
+        prompt_token_probs.clear();
        chat_msg = {};
        json_schema = json();
        generated_tool_call_ids.clear();
@@ -2240,6 +2359,77 @@ struct server_context {
        slot.params = std::move(task.params);
        slot.prompt_tokens = std::move(task.prompt_tokens);

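+        // echo support: keep the detokenized prompt text around and, when logprobs are requested,
+        // decode the prompt once up front to collect per-token candidate scores.
+        // The logits at position i - 1 score prompt token i, so token 0 is stored as -inf
+        // (rendered as null in the logprobs output).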
+        if (slot.params.echo) {
+            slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+            if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 0 && slot.prompt_token_probs.empty()) {
+                slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
+                llama_batch batch = llama_batch_init(slot.prompt_tokens.size(), 0, 1);
+                for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                    common_batch_add(batch, slot.prompt_tokens[i], i, {0}, 1);
+                }
+
+                if (llama_decode(ctx, batch) == 0) {
+                    const int n_vocab = llama_vocab_n_tokens(vocab);
+                    for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                        completion_token_output prompt_token;
+                        prompt_token.tok = slot.prompt_tokens[i];
+                        prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                        if (i > 0 && i < slot.prompt_tokens.size()) {
+                            const float * logits = llama_get_logits_ith(ctx, i - 1);
+                            if (logits != nullptr) {
+                                std::vector<std::pair<float, llama_token>> logits_id;
+                                logits_id.reserve(n_vocab);
+
+                                for (int j = 0; j < n_vocab; j++) {
+                                    logits_id.emplace_back(logits[j], j);
+                                }
+
+                                prompt_token.probs.clear();
+                                size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
+                                for (size_t k = 0; k < top_k; ++k) {
+                                    completion_token_output::prob_info prob_info;
+                                    prob_info.tok = logits_id[k].second;
+                                    prob_info.prob = logits_id[k].first;
+                                    prob_info.txt = common_token_to_piece(ctx, logits_id[k].second, true);
+                                    prompt_token.probs.push_back(prob_info);
+                                }
+
+                                auto actual_token_it = std::find_if(logits_id.begin(), logits_id.end(),
+                                    [&](const std::pair<float, llama_token> & pair) {
+                                        return pair.second == slot.prompt_tokens[i];
+                                    });
+
+                                if (actual_token_it != logits_id.end()) {
+                                    prompt_token.prob = actual_token_it->first;
+                                } else {
+                                    prompt_token.prob = -std::numeric_limits<float>::infinity();
+                                }
+                            } else {
+                                prompt_token.prob = -std::numeric_limits<float>::infinity();
+                            }
+                        } else {
+                            prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        }
+
+                        slot.prompt_token_probs.push_back(prompt_token);
+                    }
+                } else {
+                    for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                        completion_token_output prompt_token;
+                        prompt_token.tok = slot.prompt_tokens[i];
+                        prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                        prompt_token.prob = -std::numeric_limits<float>::infinity();
+                        slot.prompt_token_probs.push_back(prompt_token);
+                    }
+                }
+
+                llama_batch_free(batch);
+            }
+        }
+
        if (!are_lora_equal(slot.params.lora, slot.lora)) {
            // if lora is changed, we cannot reuse cached tokens
            slot.cache_tokens.clear();
@@ -2529,6 +2719,10 @@ struct server_context {
        res->content = tkn.text_to_send;
        res->tokens = { tkn.tok };

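+        // n_decoded is 1 right after the first sampled token, so this flags the first streamed chunk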
+        res->echo = slot.params.echo;
+        res->prompt_text = slot.prompt_text;
+        res->is_first_chunk = (slot.n_decoded == 1);
+
        res->n_decoded           = slot.n_decoded;
        res->n_prompt_tokens     = slot.n_prompt_tokens;
        res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2756,9 @@ struct server_context {
        res->content = slot.generated_text;
        res->tokens  = std::move(slot.generated_tokens);
        res->timings = slot.get_timings();
-        res->prompt = slot.prompt_tokens.detokenize(ctx, true);
+
+        res->echo = slot.params.echo;
+        res->prompt = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
        res->response_fields = std::move(slot.params.response_fields);

        res->truncated = slot.truncated;
@@ -2595,6 +2791,10 @@ struct server_context {
                    slot.generated_token_probs.begin(),
                    slot.generated_token_probs.end());
            }
+
+            if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+                res->prompt_probs_output = slot.prompt_token_probs;
+            }
        }

        res->generation_params = slot.params; // copy the parameters