diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 39ef439e94f97..ed8b31595fc92 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -130,6 +130,9 @@ struct slot_params {
     bool timings_per_token = false;
     bool post_sampling_probs = false;
 
+    // stream_options for OAI compatible chat/completions streaming
+    json stream_options = json::object();
+
     struct common_params_sampling sampling;
     struct common_params_speculative speculative;
 
@@ -321,6 +324,7 @@ struct server_task {
         //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
         params.response_fields  = json_value(data, "response_fields",  std::vector<std::string>());
+        params.stream_options   = json_value(data, "stream_options",   json::object());
 
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -984,22 +988,35 @@ struct server_task_result_cmpl_final : server_task_result {
 
         // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
        // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
-        deltas.push_back({
-            {"choices",            json::array()},
-            {"created",            t},
-            {"id",                 oaicompat_cmpl_id},
-            {"model",              oaicompat_model},
-            {"system_fingerprint", build_info},
-            {"object",             "chat.completion.chunk"},
-            {"usage", json {
-                {"completion_tokens", n_decoded},
-                {"prompt_tokens",     n_prompt_tokens},
-                {"total_tokens",      n_decoded + n_prompt_tokens},
-            }},
-        });
+
+        // Check if stream_options.include_usage is true
+        bool include_usage = false;
+        if (generation_params.stream_options.contains("include_usage")) {
+            include_usage = json_value(generation_params.stream_options, "include_usage", false);
+        }
+
+        if (include_usage) {
+            deltas.push_back({
+                {"choices",            json::array()},
+                {"created",            t},
+                {"id",                 oaicompat_cmpl_id},
+                {"model",              oaicompat_model},
+                {"system_fingerprint", build_info},
+                {"object",             "chat.completion.chunk"},
+                {"usage", json {
+                    {"completion_tokens", n_decoded},
+                    {"prompt_tokens",     n_prompt_tokens},
+                    {"total_tokens",      n_decoded + n_prompt_tokens},
+                }},
+            });
+        }
 
         if (timings.prompt_n >= 0) {
-            deltas.back().push_back({"timings", timings.to_json()});
+            if (include_usage && !deltas.empty()) {
+                deltas.back().push_back({"timings", timings.to_json()});
+            } else if (!include_usage && deltas.size() >= 2) {
+                deltas[deltas.size() - 2].push_back({"timings", timings.to_json()});
+            }
         }
 
         // extra fields for debugging purposes
diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py
index 53421d1b57351..68b7677aea4de 100644
--- a/tools/server/tests/unit/test_chat_completion.py
+++ b/tools/server/tests/unit/test_chat_completion.py
@@ -93,8 +93,9 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
                     assert choice["finish_reason"] is None
                     content += choice["delta"]["content"] or ''
         else:
-            assert data["usage"]["prompt_tokens"] == n_prompt
-            assert data["usage"]["completion_tokens"] == n_predicted
+            # After our changes, usage is only included when stream_options.include_usage is true
+            # Since this test doesn't specify stream_options, no usage should be present
+            assert "usage" not in data
 
 
 def test_chat_completion_with_openai_library():
@@ -450,3 +451,99 @@ def make_cmpl_request():
     assert last_progress["total"] > 0
     assert last_progress["processed"] == last_progress["total"]
     assert total_batch_count == batch_count
+
+
+@pytest.mark.parametrize(
+    "include_usage,expect_usage_chunk",
+    [
+        (True, True),    # stream_options.include_usage = true should include usage
+        (False, False),  # stream_options.include_usage = false should NOT include usage
+    ]
+)
+def test_chat_completion_stream_options_include_usage(include_usage: bool, expect_usage_chunk: bool):
+    global server
+    server.start()
+
+    # Create the request data
+    request_data = {
+        "max_tokens": 8,
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        "stream": True,
+        "stream_options": {
+            "include_usage": include_usage
+        }
+    }
+
+    res = server.make_stream_request("POST", "/chat/completions", data=request_data)
+
+    found_usage_chunk = False
+    content = ""
+    last_cmpl_id = None
+
+    for i, data in enumerate(res):
+        if data["choices"]:
+            choice = data["choices"][0]
+            if i == 0:
+                # Check first role message for stream=True
+                assert choice["delta"]["content"] is None
+                assert choice["delta"]["role"] == "assistant"
+            else:
+                assert "role" not in choice["delta"]
+                assert data["system_fingerprint"].startswith("b")
+                assert "gpt-3.5" in data["model"]  # DEFAULT_OAICOMPAT_MODEL
+                if last_cmpl_id is None:
+                    last_cmpl_id = data["id"]
+                assert last_cmpl_id == data["id"]  # make sure the completion id is the same for all events in the stream
+                if choice["finish_reason"] in ["stop", "length"]:
+                    assert "content" not in choice["delta"]
+                    assert choice["finish_reason"] == "length"
+                else:
+                    assert choice["finish_reason"] is None
+                    content += choice["delta"]["content"] or ''
+        else:
+            # This is the final chunk with empty choices - should contain usage if include_usage is true
+            found_usage_chunk = True
+            if expect_usage_chunk:
+                assert "usage" in data
+                assert "prompt_tokens" in data["usage"]
+                assert "completion_tokens" in data["usage"]
+                assert "total_tokens" in data["usage"]
+                assert data["usage"]["total_tokens"] == data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"]
+            else:
+                assert "usage" not in data
+
+    # We should only find a usage chunk if include_usage is true
+    assert found_usage_chunk == expect_usage_chunk
+
+
+def test_chat_completion_stream_without_stream_options():
+    """Test that streaming without stream_options behaves as before (no usage included)"""
+    global server
+    server.start()
+
+    request_data = {
+        "max_tokens": 8,
+        "messages": [
+            {"role": "system", "content": "Book"},
+            {"role": "user", "content": "What is the best book"},
+        ],
+        "stream": True,
+        # No stream_options provided
+    }
+
+    res = server.make_stream_request("POST", "/chat/completions", data=request_data)
+
+    found_usage_chunk = False
+
+    for data in res:
+        if not data["choices"]:
+            # This is the final chunk with empty choices
+            found_usage_chunk = True
+            # Should not contain usage when stream_options is not provided
+            assert "usage" not in data
+
+    # Should not find any usage chunk when stream_options is not provided
+    assert not found_usage_chunk
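
For reference, a minimal client-side sketch of how the new stream_options field could be exercised against a running llama-server build with this patch applied. It is not part of the patch or its test suite; the host/port, the /v1/chat/completions route, and the use of the requests package are assumptions, and the SSE parsing is deliberately simplified.

# Hypothetical client sketch (assumes llama-server on localhost:8080 with the
# OAI-compatible /v1/chat/completions endpoint; not taken from the patch).
import json
import requests

payload = {
    "messages": [{"role": "user", "content": "What is the best book?"}],
    "max_tokens": 8,
    "stream": True,
    "stream_options": {"include_usage": True},  # request the trailing usage chunk
}

with requests.post("http://localhost:8080/v1/chat/completions", json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        # SSE events are prefixed with "data: "; skip keep-alives and blank lines
        if not line or not line.startswith(b"data: "):
            continue
        body = line[len(b"data: "):]
        if body == b"[DONE]":
            break
        chunk = json.loads(body)
        if chunk["choices"]:
            # normal delta chunk: print streamed content as it arrives
            print(chunk["choices"][0]["delta"].get("content") or "", end="")
        else:
            # final chunk: empty `choices`, carries `usage` only when include_usage is true
            print("\nusage:", chunk.get("usage"))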