
Commit 2b6b55a

server : include usage statistics only when user request them (ggml-org#16052)
* server : include usage statistics only when user request them

  When serving the OpenAI-compatible API, we should check whether
  {"stream_options": {"include_usage": true}} is set in the request when
  deciding whether to send usage statistics.

  closes: ggml-org#16048

* add unit test
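For reference, a minimal client-side sketch of the behaviour this commit introduces, assuming a llama.cpp server listening on http://localhost:8080 (the URL and prompt are placeholders, not part of the commit). Without "stream_options": {"include_usage": true} in the request, the final usage chunk is simply not sent.

import json
import requests

# Hypothetical client: stream a chat completion and opt in to usage statistics.
resp = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "test"}],
        "stream": True,
        "stream_options": {"include_usage": True},  # omit this and no usage chunk is sent
    },
    stream=True,
)

for raw in resp.iter_lines():
    if not raw or not raw.startswith(b"data: ") or raw == b"data: [DONE]":
        continue
    chunk = json.loads(raw[len(b"data: "):])
    if chunk.get("usage"):
        # Per the OpenAI spec, the usage chunk has an empty `choices` array and
        # reports prompt_tokens, completion_tokens and total_tokens.
        print(chunk["usage"])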
1 parent e58174c commit 2b6b55a

File tree

2 files changed: +37 −26 lines

tools/server/server.cpp
tools/server/tests/unit/test_chat_completion.py


tools/server/server.cpp

Lines changed: 33 additions & 26 deletions
@@ -111,6 +111,7 @@ static bool server_task_type_need_logits(server_task_type task_type) {

 struct slot_params {
     bool stream = true;
+    bool include_usage = false;
     bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
     bool return_tokens = false;
     bool return_progress = false;
@@ -310,17 +311,19 @@ struct server_task {
         params.verbose = params_base.verbosity > 9;
         params.timings_per_token = json_value(data, "timings_per_token", false);

-        params.stream = json_value(data, "stream", false);
-        params.cache_prompt = json_value(data, "cache_prompt", true);
-        params.return_tokens = json_value(data, "return_tokens", false);
-        params.return_progress = json_value(data, "return_progress", false);
-        params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
-        params.n_indent = json_value(data, "n_indent", defaults.n_indent);
-        params.n_keep = json_value(data, "n_keep", defaults.n_keep);
-        params.n_discard = json_value(data, "n_discard", defaults.n_discard);
-        //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
-        params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
-        params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
+        params.stream = json_value(data, "stream", false);
+        auto stream_opt = json_value(data, "stream_options", json::object());
+        params.include_usage = json_value(stream_opt, "include_usage", false);
+        params.cache_prompt = json_value(data, "cache_prompt", true);
+        params.return_tokens = json_value(data, "return_tokens", false);
+        params.return_progress = json_value(data, "return_progress", false);
+        params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
+        params.n_indent = json_value(data, "n_indent", defaults.n_indent);
+        params.n_keep = json_value(data, "n_keep", defaults.n_keep);
+        params.n_discard = json_value(data, "n_discard", defaults.n_discard);
+        //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
+        params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
+        params.response_fields = json_value(data, "response_fields", std::vector<std::string>());

         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -775,6 +778,7 @@ struct server_task_result_cmpl_final : server_task_result {
     llama_tokens tokens;

     bool stream;
+    bool include_usage;
     result_timings timings;
     std::string prompt;

@@ -982,21 +986,23 @@
             {"object", "chat.completion.chunk"},
         });

-        // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
-        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
-        deltas.push_back({
-            {"choices", json::array()},
-            {"created", t},
-            {"id", oaicompat_cmpl_id},
-            {"model", oaicompat_model},
-            {"system_fingerprint", build_info},
-            {"object", "chat.completion.chunk"},
-            {"usage", json {
-                {"completion_tokens", n_decoded},
-                {"prompt_tokens", n_prompt_tokens},
-                {"total_tokens", n_decoded + n_prompt_tokens},
-            }},
-        });
+        if (include_usage) {
+            // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
+            // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
+            deltas.push_back({
+                {"choices", json::array()},
+                {"created", t},
+                {"id", oaicompat_cmpl_id},
+                {"model", oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object", "chat.completion.chunk"},
+                {"usage", json {
+                    {"completion_tokens", n_decoded},
+                    {"prompt_tokens", n_prompt_tokens},
+                    {"total_tokens", n_decoded + n_prompt_tokens},
+                }},
+            });
+        }

         if (timings.prompt_n >= 0) {
             deltas.back().push_back({"timings", timings.to_json()});
@@ -2815,6 +2821,7 @@ struct server_context {

         res->verbose = slot.params.verbose;
         res->stream = slot.params.stream;
+        res->include_usage = slot.params.include_usage;
         res->oaicompat = slot.params.oaicompat;
         res->oaicompat_model = slot.params.oaicompat_model;
         res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;

tools/server/tests/unit/test_chat_completion.py

Lines changed: 4 additions & 0 deletions
@@ -271,8 +271,10 @@ def test_chat_completion_with_timings_per_token():
         "max_tokens": 10,
         "messages": [{"role": "user", "content": "test"}],
         "stream": True,
+        "stream_options": {"include_usage": True},
         "timings_per_token": True,
     })
+    stats_received = False
     for i, data in enumerate(res):
         if i == 0:
             # Check first role message for stream=True
@@ -288,6 +290,8 @@ def test_chat_completion_with_timings_per_token():
             assert "predicted_per_second" in data["timings"]
             assert "predicted_n" in data["timings"]
             assert data["timings"]["predicted_n"] <= 10
+            stats_received = True
+    assert stats_received


 def test_logprobs():
