45 changes: 31 additions & 14 deletions tools/server/server.cpp
@@ -130,6 +130,9 @@ struct slot_params {
bool timings_per_token = false;
bool post_sampling_probs = false;

// stream_options for OAI compatible chat/completions streaming
json stream_options = json::object();

struct common_params_sampling sampling;
struct common_params_speculative speculative;

@@ -321,6 +324,7 @@ struct server_task {
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
params.response_fields = json_value(data, "response_fields", std::vector<std::string>());
params.stream_options = json_value(data, "stream_options", json::object());

params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -984,22 +988,35 @@ struct server_task_result_cmpl_final : server_task_result {

// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
deltas.push_back({
{"choices", json::array()},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
});

// Check if stream_options.include_usage is true
bool include_usage = false;
if (generation_params.stream_options.contains("include_usage")) {
include_usage = json_value(generation_params.stream_options, "include_usage", false);
}

if (include_usage) {
deltas.push_back({
{"choices", json::array()},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"system_fingerprint", build_info},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
});
}

if (timings.prompt_n >= 0) {
deltas.back().push_back({"timings", timings.to_json()});
if (include_usage && !deltas.empty()) {
deltas.back().push_back({"timings", timings.to_json()});
} else if (!include_usage && deltas.size() >= 2) {
deltas[deltas.size() - 2].push_back({"timings", timings.to_json()});
}
}

// extra fields for debugging purposes
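Taken together, the server.cpp changes make the trailing usage chunk opt-in: it is only appended when the request carries stream_options.include_usage = true. A minimal sketch of the opt-in request body and of the shape of the final chunk it should produce (field names follow the OpenAI streaming spec; token counts are illustrative, not taken from this PR):

# Sketch only: the request body a client would send to opt in to the usage chunk.
request_body = {
    "messages": [
        {"role": "system", "content": "Book"},
        {"role": "user", "content": "What is the best book"},
    ],
    "max_tokens": 8,
    "stream": True,
    "stream_options": {"include_usage": True},
}

# Illustrative shape of the final streamed chunk when include_usage is true:
# an empty "choices" array plus the aggregated token counts.
final_chunk_example = {
    "object": "chat.completion.chunk",
    "choices": [],
    "usage": {
        "prompt_tokens": 12,       # example values only
        "completion_tokens": 8,
        "total_tokens": 20,
    },
}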
101 changes: 99 additions & 2 deletions tools/server/tests/unit/test_chat_completion.py
@@ -93,8 +93,9 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
assert choice["finish_reason"] is None
content += choice["delta"]["content"] or ''
else:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
# After our changes, usage is only included when stream_options.include_usage is true
# Since this test doesn't specify stream_options, no usage should be present
assert "usage" not in data


def test_chat_completion_with_openai_library():
@@ -450,3 +451,99 @@ def make_cmpl_request():
assert last_progress["total"] > 0
assert last_progress["processed"] == last_progress["total"]
assert total_batch_count == batch_count


@pytest.mark.parametrize(
"include_usage,expect_usage_chunk",
[
(True, True), # stream_options.include_usage = true should include usage
(False, False), # stream_options.include_usage = false should NOT include usage
]
)
def test_chat_completion_stream_options_include_usage(include_usage: bool, expect_usage_chunk: bool):
global server
server.start()

# Create the request data
request_data = {
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
"stream": True,
"stream_options": {
"include_usage": include_usage
}
}

res = server.make_stream_request("POST", "/chat/completions", data=request_data)

found_usage_chunk = False
content = ""
last_cmpl_id = None

for i, data in enumerate(res):
if data["choices"]:
choice = data["choices"][0]
if i == 0:
# Check first role message for stream=True
assert choice["delta"]["content"] is None
assert choice["delta"]["role"] == "assistant"
else:
assert "role" not in choice["delta"]
assert data["system_fingerprint"].startswith("b")
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
if choice["finish_reason"] in ["stop", "length"]:
assert "content" not in choice["delta"]
assert choice["finish_reason"] == "length"
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"] or ''
else:
# This is the final chunk with empty choices - should contain usage if include_usage is true
found_usage_chunk = True
if expect_usage_chunk:
assert "usage" in data
assert "prompt_tokens" in data["usage"]
assert "completion_tokens" in data["usage"]
assert "total_tokens" in data["usage"]
assert data["usage"]["total_tokens"] == data["usage"]["prompt_tokens"] + data["usage"]["completion_tokens"]
else:
assert "usage" not in data

# We should only find a usage chunk if include_usage is true
assert found_usage_chunk == expect_usage_chunk


def test_chat_completion_stream_without_stream_options():
"""Test that streaming without stream_options behaves as before (no usage included)"""
global server
server.start()

request_data = {
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
"stream": True,
# No stream_options provided
}

res = server.make_stream_request("POST", "/chat/completions", data=request_data)

found_usage_chunk = False

for data in res:
if not data["choices"]:
# This is the final chunk with empty choices
found_usage_chunk = True
# Should not contain usage when stream_options is not provided
assert "usage" not in data

# Should not find any usage chunk when stream_options is not provided
assert not found_usage_chunk
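

For completeness, the new behaviour can also be exercised end to end with the official openai Python client (not part of this diff; assumes openai-python >= 1.26, which added stream_options, and a llama.cpp server listening on localhost:8080):

# Sketch, not part of the PR: stream with include_usage against a local server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8080/v1", api_key="sk-no-key-required")

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",  # the server ignores the name, but the client requires one
    messages=[{"role": "user", "content": "What is the best book"}],
    max_tokens=8,
    stream=True,
    stream_options={"include_usage": True},
)

usage = None
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:
        usage = chunk.usage  # final chunk: empty choices, usage attached

print()
if usage is not None:
    print(f"prompt={usage.prompt_tokens} completion={usage.completion_tokens} total={usage.total_tokens}")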