@@ -111,6 +111,7 @@ static bool server_task_type_need_logits(server_task_type task_type) {
 
 struct slot_params {
     bool stream          = true;
+    bool include_usage   = false;
     bool cache_prompt    = true; // remember the prompt to avoid reprocessing all prompt
     bool return_tokens   = false;
     bool return_progress = false;
@@ -310,17 +311,19 @@ struct server_task {
         params.verbose           = params_base.verbosity > 9;
         params.timings_per_token = json_value(data, "timings_per_token", false);
 
-        params.stream            = json_value(data, "stream",            false);
-        params.cache_prompt      = json_value(data, "cache_prompt",      true);
-        params.return_tokens     = json_value(data, "return_tokens",     false);
-        params.return_progress   = json_value(data, "return_progress",   false);
-        params.n_predict         = json_value(data, "n_predict",         json_value(data, "max_tokens", defaults.n_predict));
-        params.n_indent          = json_value(data, "n_indent",          defaults.n_indent);
-        params.n_keep            = json_value(data, "n_keep",            defaults.n_keep);
-        params.n_discard         = json_value(data, "n_discard",         defaults.n_discard);
-        // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
-        params.t_max_predict_ms  = json_value(data, "t_max_predict_ms",  defaults.t_max_predict_ms);
-        params.response_fields   = json_value(data, "response_fields",   std::vector<std::string>());
+        params.stream            = json_value(data, "stream",            false);
+        auto stream_opt          = json_value(data, "stream_options",    json::object());
+        params.include_usage     = json_value(stream_opt, "include_usage", false);
+        params.cache_prompt      = json_value(data, "cache_prompt",      true);
+        params.return_tokens     = json_value(data, "return_tokens",     false);
+        params.return_progress   = json_value(data, "return_progress",   false);
+        params.n_predict         = json_value(data, "n_predict",         json_value(data, "max_tokens", defaults.n_predict));
+        params.n_indent          = json_value(data, "n_indent",          defaults.n_indent);
+        params.n_keep            = json_value(data, "n_keep",            defaults.n_keep);
+        params.n_discard         = json_value(data, "n_discard",         defaults.n_discard);
+        // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
+        params.t_max_predict_ms  = json_value(data, "t_max_predict_ms",  defaults.t_max_predict_ms);
+        params.response_fields   = json_value(data, "response_fields",   std::vector<std::string>());
 
         params.sampling.top_k    = json_value(data, "top_k",             defaults.sampling.top_k);
         params.sampling.top_p    = json_value(data, "top_p",             defaults.sampling.top_p);
@@ -775,6 +778,7 @@ struct server_task_result_cmpl_final : server_task_result {
     llama_tokens tokens;
 
     bool stream;
+    bool include_usage;
     result_timings timings;
     std::string prompt;
 
@@ -982,21 +986,23 @@ struct server_task_result_cmpl_final : server_task_result {
             {"object",             "chat.completion.chunk"},
         });
 
-        // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
-        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
-        deltas.push_back({
-            {"choices",            json::array()},
-            {"created",            t},
-            {"id",                 oaicompat_cmpl_id},
-            {"model",              oaicompat_model},
-            {"system_fingerprint", build_info},
-            {"object",             "chat.completion.chunk"},
-            {"usage", json {
-                {"completion_tokens", n_decoded},
-                {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens},
-            }},
-        });
+        if (include_usage) {
+            // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
+            // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+            deltas.push_back({
+                {"choices",            json::array()},
+                {"created",            t},
+                {"id",                 oaicompat_cmpl_id},
+                {"model",              oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object",             "chat.completion.chunk"},
+                {"usage", json {
+                    {"completion_tokens", n_decoded},
+                    {"prompt_tokens",     n_prompt_tokens},
+                    {"total_tokens",      n_decoded + n_prompt_tokens},
+                }},
+            });
+        }
 
         if (timings.prompt_n >= 0) {
             deltas.back().push_back({"timings", timings.to_json()});
@@ -2815,6 +2821,7 @@ struct server_context {
 
         res->verbose           = slot.params.verbose;
         res->stream            = slot.params.stream;
+        res->include_usage     = slot.params.include_usage;
         res->oaicompat         = slot.params.oaicompat;
         res->oaicompat_model   = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
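
A minimal standalone sketch (not part of the patch) of how the new `stream_options` parsing above behaves. It assumes nlohmann::json and a json_value() helper with the same "return the field if present, otherwise the default" contract as the server's; the helper and the request body here are illustrative only:

#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Hypothetical stand-in for the server's json_value() helper:
// return data[key] if present and non-null, otherwise the default.
template <typename T>
static T json_value(const json & data, const std::string & key, const T & default_value) {
    if (data.contains(key) && !data.at(key).is_null()) {
        return data.at(key).get<T>();
    }
    return default_value;
}

int main() {
    // Request body as an OpenAI-compatible streaming client might send it.
    json data = json::parse(R"({
        "stream": true,
        "stream_options": { "include_usage": true }
    })");

    bool stream        = json_value(data, "stream", false);
    json stream_opt    = json_value(data, "stream_options", json::object());
    bool include_usage = json_value(stream_opt, "include_usage", false);

    // Prints "stream=1 include_usage=1" for the request above; omitting
    // "stream_options" leaves include_usage at its false default, so no
    // extra usage chunk would be appended to the stream.
    std::cout << "stream=" << stream << " include_usage=" << include_usage << "\n";
    return 0;
}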