@@ -112,6 +112,7 @@ struct slot_params {
    bool stream        = true;
    bool cache_prompt  = true;  // remember the prompt to avoid reprocessing all prompt
    bool return_tokens = false;
+   bool echo          = false;

    int32_t n_keep    =  0; // number of tokens to keep from initial prompt
    int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -160,6 +161,7 @@ struct slot_params {
        }

        return json {
+           {"echo",                      echo},
            {"n_predict",                 n_predict},     // Server configured n_predict
            {"seed",                      sampling.seed},
            {"temperature",               sampling.temp},
@@ -265,6 +267,7 @@ struct server_task {
        params.stream           = json_value(data, "stream",             false);
        params.cache_prompt     = json_value(data, "cache_prompt",       true);
        params.return_tokens    = json_value(data, "return_tokens",      false);
+       params.echo             = json_value(data, "echo",               false);
        params.n_predict        = json_value(data, "n_predict",          json_value(data, "max_tokens", defaults.n_predict));
        params.n_indent         = json_value(data, "n_indent",           defaults.n_indent);
        params.n_keep           = json_value(data, "n_keep",             defaults.n_keep);
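// Illustrative sketch, not from the patch: a request body that would exercise the new
// "echo" flag through the json_value() parsing above. Field names other than "echo"
// are the existing completion parameters; the prompt and values are made up.
#include <nlohmann/json.hpp>
#include <iostream>

int main() {
    nlohmann::json data = {
        {"prompt",    "The quick brown fox"},
        {"n_predict", 8},
        {"echo",      true},  // return the prompt in front of the generated text
        {"n_probs",   3},     // also request per-token top-3 logprobs
    };
    std::cout << data.dump(2) << std::endl;
    return 0;
}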
@@ -674,6 +677,91 @@ struct completion_token_output {
        return out;
    }

+   static json oaicompat_probs_vector_to_json(
+           const std::vector<completion_token_output> & probs_out,
+           bool post_sampling_probs,
+           bool echo,
+           const std::vector<completion_token_output> & prompt_probs = {}
+   ) {
+       json out = json::object();
+
+       std::vector<std::string> tokens;
+       std::vector<completion_token_output> all_probs;
+
+       if (echo && !prompt_probs.empty()) {
+           all_probs.insert(all_probs.end(), prompt_probs.begin(), prompt_probs.end());
+       }
+
+       all_probs.insert(all_probs.end(), probs_out.begin(), probs_out.end());
+
+       tokens.reserve(all_probs.size());
+       for (const auto & p : all_probs) {
+           std::string piece = p.text_to_send;
+           piece.resize(validate_utf8(piece));
+           tokens.push_back(piece);
+       }
+
+       int text_offset = 0;
+       std::vector<int> text_offsets;
+       text_offsets.reserve(tokens.size());
+
+       int current_off = text_offset;
+       for (const auto & tok : tokens) {
+           text_offsets.push_back(current_off);
+           current_off += static_cast<int>(tok.size());
+       }
+
+       std::vector<std::optional<float>> token_logprobs;
+       token_logprobs.reserve(all_probs.size());
+
+       std::vector<std::optional<std::unordered_map<std::string, float>>> top_logprobs;
+       top_logprobs.reserve(all_probs.size());
+
+       for (size_t i = 0; i < all_probs.size(); ++i) {
+           const auto & p = all_probs[i];
+
+           if (std::isinf(p.prob) && p.prob < 0) {
+               token_logprobs.push_back(std::nullopt);
+               top_logprobs.push_back(std::nullopt);
+           } else {
+               float logprob_value = p.prob;
+               if (!post_sampling_probs) {
+                   logprob_value = p.prob;
+               } else {
+                   logprob_value = p.prob > 0.0f ? std::log(p.prob) : -std::numeric_limits<float>::infinity();
+               }
+
+               token_logprobs.push_back(std::optional<float>(logprob_value));
+
+               std::unordered_map<std::string, float> top_map;
+               for (const auto & cand : p.probs) {
+                   std::string cand_txt = cand.txt;
+                   cand_txt.resize(validate_utf8(cand_txt));
+
+                   float cand_logprob;
+                   if (!post_sampling_probs) {
+                       cand_logprob = cand.prob;
+                   } else {
+                       cand_logprob = cand.prob > 0.0f ? std::log(cand.prob) : -std::numeric_limits<float>::infinity();
+                   }
+
+                   top_map[cand_txt] = cand_logprob;
+               }
+
+               top_logprobs.push_back(std::move(top_map));
+           }
+       }
+
+       out = json{
+           {"text_offset",    text_offsets},
+           {"token_logprobs", token_logprobs},
+           {"tokens",         tokens},
+           {"top_logprobs",   top_logprobs}
+       };
+
+       return out;
+   }
+
    static float logarithm(float x) {
        // nlohmann::json converts -inf to null, so we need to prevent that
        return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
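// Illustrative sketch, not from the patch: how the "text_offset" array built by
// oaicompat_probs_vector_to_json() relates to the token pieces. Each offset is the
// cumulative byte length of the preceding pieces, starting at 0; the token strings
// here are made up.
#include <cstdio>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> tokens = { "The", " quick", " brown", " fox" };

    std::vector<int> text_offsets;
    int off = 0;
    for (const auto & tok : tokens) {
        text_offsets.push_back(off);
        off += (int) tok.size();
    }

    for (size_t i = 0; i < tokens.size(); ++i) {
        std::printf("%2d -> \"%s\"\n", text_offsets[i], tokens[i].c_str());
    }
    // prints: 0 -> "The", 3 -> " quick", 9 -> " brown", 15 -> " fox"
    return 0;
}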
@@ -697,6 +785,7 @@ struct server_task_result_cmpl_final : server_task_result {
    bool stream;
    result_timings timings;
    std::string prompt;
+   bool echo = false;

    bool truncated;
    int32_t n_decoded;
@@ -708,6 +797,7 @@ struct server_task_result_cmpl_final : server_task_result {

    bool post_sampling_probs;
    std::vector<completion_token_output> probs_output;
+   std::vector<completion_token_output> prompt_probs_output;
    std::vector<std::string> response_fields;

    slot_params generation_params;
@@ -769,19 +859,26 @@ struct server_task_result_cmpl_final : server_task_result {
    json to_json_oaicompat() {
        std::time_t t = std::time(0);
        json logprobs = json(nullptr); // OAI default to null
-       if (!stream && probs_output.size() > 0) {
-           logprobs = json{
-               {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
-           };
+       if (!stream && (probs_output.size() > 0 || (echo && prompt_probs_output.size() > 0))) {
+           logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+               probs_output,
+               post_sampling_probs,
+               echo,
+               prompt_probs_output
+           );
        }
        json finish_reason = "length";
        if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
            finish_reason = "stop";
        }
+       std::string response_text = content;
+       if (echo && !stream) {
+           response_text = prompt + content;
+       }
        json res = json {
            {"choices", json::array({
                json{
-                   {"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
+                   {"text", stream ? "" : response_text}, // in stream mode, content is already in last partial chunk
                    {"index", index},
                    {"logprobs", logprobs},
                    {"finish_reason", finish_reason},
@@ -940,6 +1037,10 @@ struct server_task_result_cmpl_partial : server_task_result {
    std::string oaicompat_cmpl_id;
    std::vector<common_chat_msg_diff> oaicompat_msg_diffs;

+   bool echo = false;
+   std::string prompt_text;
+   bool is_first_chunk = false;
+
    virtual int get_index() override {
        return index;
    }
@@ -986,14 +1087,21 @@ struct server_task_result_cmpl_partial : server_task_result {
        std::time_t t = std::time(0);
        json logprobs = json(nullptr); // OAI default to null
        if (prob_output.probs.size() > 0) {
-           logprobs = json{
-               {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
-           };
+           logprobs = completion_token_output::oaicompat_probs_vector_to_json(
+               std::vector<completion_token_output>{prob_output},
+               post_sampling_probs,
+               echo
+           );
+       }
+
+       std::string response_text = content;
+       if (echo && is_first_chunk) {
+           response_text = prompt_text + content;
        }
        json res = json {
            {"choices", json::array({
                json{
-                   {"text", content},
+                   {"text", response_text},
                    {"index", index},
                    {"logprobs", logprobs},
                    {"finish_reason", nullptr},
@@ -1321,6 +1429,8 @@ struct server_slot {

    // input prompt tokens
    server_tokens prompt_tokens;
+   std::string prompt_text;
+   std::vector<completion_token_output> prompt_token_probs;

    size_t last_nl_pos = 0;

@@ -1368,6 +1478,7 @@ struct server_slot {
        SLT_DBG(*this, "%s", "\n");

        n_prompt_tokens = 0;
+       prompt_text     = "";
        last_nl_pos     = 0;
        generated_text  = "";
        has_new_line    = false;
@@ -1381,6 +1492,7 @@ struct server_slot {

        generated_tokens.clear();
        generated_token_probs.clear();
+       prompt_token_probs.clear();
        chat_msg = {};
        json_schema = json();
        generated_tool_call_ids.clear();
@@ -2240,6 +2352,113 @@ struct server_context {
        slot.params        = std::move(task.params);
        slot.prompt_tokens = std::move(task.prompt_tokens);

+       if (slot.params.echo) {
+           slot.prompt_text = slot.prompt_tokens.detokenize(ctx, true);
+
+           if (slot.params.sampling.n_probs > 0 && slot.prompt_tokens.size() > 1 && slot.prompt_token_probs.empty()) {
+               slot.prompt_token_probs.reserve(slot.prompt_tokens.size());
+
+               llama_memory_clear(llama_get_memory(ctx), true);
+
+               const int n_batch     = llama_n_batch(ctx);
+               const int num_batches = (slot.prompt_tokens.size() + n_batch - 1) / n_batch;
+               const int n_vocab     = llama_vocab_n_tokens(vocab);
+
+               std::vector<float> all_logits;
+               if (num_batches > 1) {
+                   all_logits.reserve(slot.prompt_tokens.size() * n_vocab);
+               }
+
+               for (int batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
+                   const int batch_start = batch_idx * n_batch;
+                   const int batch_size  = std::min((int) slot.prompt_tokens.size() - batch_start, n_batch);
+
+                   llama_batch batch = llama_batch_init(batch_size, 0, 1);
+                   for (int i = 0; i < batch_size; ++i) {
+                       common_batch_add(batch, slot.prompt_tokens[batch_start + i], batch_start + i, { 0 }, true);
+                   }
+
+                   if (llama_decode(ctx, batch) == 0) {
+                       const float * batch_logits = llama_get_logits(ctx);
+                       if (num_batches > 1) {
+                           all_logits.insert(all_logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+                       }
+                   } else {
+                       llama_batch_free(batch);
+                       break;
+                   }
+                   llama_batch_free(batch);
+               }
+
+               for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                   completion_token_output prompt_token;
+                   prompt_token.tok          = slot.prompt_tokens[i];
+                   prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+
+                   if (i == 0) {
+                       prompt_token.prob = -std::numeric_limits<float>::infinity();
+                   } else {
+                       const float * logits = num_batches > 1 ?
+                           all_logits.data() + (i - 1) * n_vocab :
+                           llama_get_logits_ith(ctx, i - 1);
+
+                       if (logits != nullptr) {
+                           float max_logit = logits[0];
+                           for (int j = 1; j < n_vocab; ++j) {
+                               max_logit = std::max(max_logit, logits[j]);
+                           }
+
+                           double sum_exp = 0.0;
+                           for (int j = 0; j < n_vocab; ++j) {
+                               sum_exp += expf(logits[j] - max_logit);
+                           }
+
+                           const float log_sum_exp = max_logit + logf(sum_exp);
+                           prompt_token.prob = logits[slot.prompt_tokens[i]] - log_sum_exp;
+
+                           if (slot.params.sampling.n_probs > 0) {
+                               std::vector<std::pair<float, llama_token>> logits_id;
+                               logits_id.reserve(n_vocab);
+
+                               for (int j = 0; j < n_vocab; j++) {
+                                   const float logprob = logits[j] - log_sum_exp;
+                                   logits_id.emplace_back(logprob, j);
+                               }
+
+                               std::partial_sort(logits_id.begin(),
+                                   logits_id.begin() + std::min((size_t) slot.params.sampling.n_probs, logits_id.size()),
+                                   logits_id.end(),
+                                   [](const auto & a, const auto & b) { return a.first > b.first; });
+
+                               prompt_token.probs.clear();
+                               size_t top_k = std::min(logits_id.size(), static_cast<size_t>(slot.params.sampling.n_probs));
+                               for (size_t k = 0; k < top_k; ++k) {
+                                   completion_token_output::prob_info prob_info;
+                                   prob_info.tok  = logits_id[k].second;
+                                   prob_info.prob = logits_id[k].first;
+                                   prob_info.txt  = common_token_to_piece(ctx, logits_id[k].second, true);
+                                   prompt_token.probs.push_back(prob_info);
+                               }
+                           }
+                       } else {
+                           prompt_token.prob = -std::numeric_limits<float>::infinity();
+                       }
+                   }
+
+                   slot.prompt_token_probs.push_back(prompt_token);
+               }
+           } else {
+               for (size_t i = 0; i < slot.prompt_tokens.size(); ++i) {
+                   completion_token_output prompt_token;
+                   prompt_token.tok          = slot.prompt_tokens[i];
+                   prompt_token.text_to_send = common_token_to_piece(ctx, slot.prompt_tokens[i], true);
+                   prompt_token.prob         = -std::numeric_limits<float>::infinity();
+                   slot.prompt_token_probs.push_back(prompt_token);
+               }
+           }
+       }
+
+
        if (!are_lora_equal(slot.params.lora, slot.lora)) {
            // if lora is changed, we cannot reuse cached tokens
            slot.cache_tokens.clear();
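// Illustrative sketch, not from the patch: the numerically stable log-softmax used
// above to turn raw logits into a prompt-token logprob. Subtracting the max logit
// before exponentiating avoids overflow; logprob(tok) = logit[tok] - logsumexp(logits).
// The logit values below are made up.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static float token_logprob(const std::vector<float> & logits, size_t tok) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());

    double sum_exp = 0.0;
    for (float l : logits) {
        sum_exp += std::exp((double) (l - max_logit));
    }

    const float log_sum_exp = max_logit + (float) std::log(sum_exp);
    return logits[tok] - log_sum_exp;
}

int main() {
    const std::vector<float> logits = { 2.0f, 1.0f, 0.1f };
    std::printf("logprob of token 0: %.3f\n", token_logprob(logits, 0)); // ~ -0.417
    return 0;
}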
@@ -2529,6 +2748,10 @@ struct server_context {
        res->content = tkn.text_to_send;
        res->tokens  = { tkn.tok };

+       res->echo           = slot.params.echo;
+       res->prompt_text    = slot.prompt_text;
+       res->is_first_chunk = (slot.n_decoded == 1);
+
        res->n_decoded           = slot.n_decoded;
        res->n_prompt_tokens     = slot.n_prompt_tokens;
        res->post_sampling_probs = slot.params.post_sampling_probs;
@@ -2562,7 +2785,9 @@ struct server_context {
        res->content = slot.generated_text;
        res->tokens  = std::move(slot.generated_tokens);
        res->timings = slot.get_timings();
-       res->prompt  = slot.prompt_tokens.detokenize(ctx, true);
+
+       res->echo    = slot.params.echo;
+       res->prompt  = slot.params.echo ? slot.prompt_text : slot.prompt_tokens.detokenize(ctx, true);
        res->response_fields = std::move(slot.params.response_fields);

        res->truncated = slot.truncated;
@@ -2595,6 +2820,10 @@ struct server_context {
                    slot.generated_token_probs.begin(),
                    slot.generated_token_probs.end());
            }
+
+           if (slot.params.echo && !slot.prompt_token_probs.empty()) {
+               res->prompt_probs_output = slot.prompt_token_probs;
+           }
        }

        res->generation_params = slot.params; // copy the parameters