Skip to content

Commit 1bc664a

Browse files
authored
server: fix OpenAI API compatibility for usage statistics in chat streams (ggml-org#15444)
1 parent 13aeb7a commit 1bc664a

File tree

3 files changed

+105
-82
lines changed

3 files changed

+105
-82
lines changed

tools/server/server.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,17 @@ struct server_task_result_cmpl_final : server_task_result {
911911
{"model", oaicompat_model},
912912
{"system_fingerprint", build_info},
913913
{"object", "chat.completion.chunk"},
914+
});
915+
916+
// OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage
917+
// https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices
918+
deltas.push_back({
919+
{"choices", json::array()},
920+
{"created", t},
921+
{"id", oaicompat_cmpl_id},
922+
{"model", oaicompat_model},
923+
{"system_fingerprint", build_info},
924+
{"object", "chat.completion.chunk"},
914925
{"usage", json {
915926
{"completion_tokens", n_decoded},
916927
{"prompt_tokens", n_prompt_tokens},

tools/server/tests/unit/test_chat_completion.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -72,27 +72,29 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
7272
content = ""
7373
last_cmpl_id = None
7474
for i, data in enumerate(res):
75-
choice = data["choices"][0]
76-
if i == 0:
77-
# Check first role message for stream=True
78-
assert choice["delta"]["content"] is None
79-
assert choice["delta"]["role"] == "assistant"
75+
if data["choices"]:
76+
choice = data["choices"][0]
77+
if i == 0:
78+
# Check first role message for stream=True
79+
assert choice["delta"]["content"] is None
80+
assert choice["delta"]["role"] == "assistant"
81+
else:
82+
assert "role" not in choice["delta"]
83+
assert data["system_fingerprint"].startswith("b")
84+
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
85+
if last_cmpl_id is None:
86+
last_cmpl_id = data["id"]
87+
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
88+
if choice["finish_reason"] in ["stop", "length"]:
89+
assert "content" not in choice["delta"]
90+
assert match_regex(re_content, content)
91+
assert choice["finish_reason"] == finish_reason
92+
else:
93+
assert choice["finish_reason"] is None
94+
content += choice["delta"]["content"] or ''
8095
else:
81-
assert "role" not in choice["delta"]
82-
assert data["system_fingerprint"].startswith("b")
83-
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
84-
if last_cmpl_id is None:
85-
last_cmpl_id = data["id"]
86-
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
87-
if choice["finish_reason"] in ["stop", "length"]:
8896
assert data["usage"]["prompt_tokens"] == n_prompt
8997
assert data["usage"]["completion_tokens"] == n_predicted
90-
assert "content" not in choice["delta"]
91-
assert match_regex(re_content, content)
92-
assert choice["finish_reason"] == finish_reason
93-
else:
94-
assert choice["finish_reason"] is None
95-
content += choice["delta"]["content"] or ''
9698

9799

98100
def test_chat_completion_with_openai_library():
@@ -278,12 +280,14 @@ def test_chat_completion_with_timings_per_token():
278280
assert data["choices"][0]["delta"]["role"] == "assistant"
279281
assert "timings" not in data, f'First event should not have timings: {data}'
280282
else:
281-
assert "role" not in data["choices"][0]["delta"]
282-
assert "timings" in data
283-
assert "prompt_per_second" in data["timings"]
284-
assert "predicted_per_second" in data["timings"]
285-
assert "predicted_n" in data["timings"]
286-
assert data["timings"]["predicted_n"] <= 10
283+
if data["choices"]:
284+
assert "role" not in data["choices"][0]["delta"]
285+
else:
286+
assert "timings" in data
287+
assert "prompt_per_second" in data["timings"]
288+
assert "predicted_per_second" in data["timings"]
289+
assert "predicted_n" in data["timings"]
290+
assert data["timings"]["predicted_n"] <= 10
287291

288292

289293
def test_logprobs():
@@ -332,24 +336,25 @@ def test_logprobs_stream():
332336
output_text = ''
333337
aggregated_text = ''
334338
for i, data in enumerate(res):
335-
choice = data.choices[0]
336-
if i == 0:
337-
# Check first role message for stream=True
338-
assert choice.delta.content is None
339-
assert choice.delta.role == "assistant"
340-
else:
341-
assert choice.delta.role is None
342-
if choice.finish_reason is None:
343-
if choice.delta.content:
344-
output_text += choice.delta.content
345-
assert choice.logprobs is not None
346-
assert choice.logprobs.content is not None
347-
for token in choice.logprobs.content:
348-
aggregated_text += token.token
349-
assert token.logprob <= 0.0
350-
assert token.bytes is not None
351-
assert token.top_logprobs is not None
352-
assert len(token.top_logprobs) > 0
339+
if data.choices:
340+
choice = data.choices[0]
341+
if i == 0:
342+
# Check first role message for stream=True
343+
assert choice.delta.content is None
344+
assert choice.delta.role == "assistant"
345+
else:
346+
assert choice.delta.role is None
347+
if choice.finish_reason is None:
348+
if choice.delta.content:
349+
output_text += choice.delta.content
350+
assert choice.logprobs is not None
351+
assert choice.logprobs.content is not None
352+
for token in choice.logprobs.content:
353+
aggregated_text += token.token
354+
assert token.logprob <= 0.0
355+
assert token.bytes is not None
356+
assert token.top_logprobs is not None
357+
assert len(token.top_logprobs) > 0
353358
assert aggregated_text == output_text
354359

355360

tools/server/tests/utils.py

Lines changed: 47 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -318,46 +318,53 @@ def make_any_request(
318318
arguments_parts = 0
319319

320320
for chunk in self.make_stream_request(method, path, data, headers):
321-
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
322-
choice = chunk['choices'][0]
323-
if choice['delta'].get('content') is not None:
324-
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
325-
content.append(choice['delta']['content'])
326-
content_parts += 1
327-
if choice['delta'].get('reasoning_content') is not None:
328-
assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
329-
reasoning_content.append(choice['delta']['reasoning_content'])
330-
reasoning_content_parts += 1
331-
if choice['delta'].get('finish_reason') is not None:
332-
finish_reason = choice['delta']['finish_reason']
333-
for tc in choice['delta'].get('tool_calls', []):
334-
if 'function' not in tc:
335-
raise ValueError(f"Expected function type, got {tc['type']}")
336-
if tc['index'] >= len(tool_calls):
337-
assert 'id' in tc
338-
assert tc.get('type') == 'function'
339-
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
340-
f"Expected function call with name, got {tc.get('function')}"
341-
tool_calls.append(dict(
342-
id="",
343-
type="function",
344-
function=dict(
345-
name="",
346-
arguments="",
347-
)
348-
))
349-
tool_call = tool_calls[tc['index']]
350-
if tc.get('id') is not None:
351-
tool_call['id'] = tc['id']
352-
fct = tc['function']
353-
assert 'id' not in fct, f"Function call should not have id: {fct}"
354-
if fct.get('name') is not None:
355-
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
356-
if fct.get('arguments') is not None:
357-
tool_call['function']['arguments'] += fct['arguments']
358-
arguments_parts += 1
359-
tool_call_parts += 1
360-
321+
if chunk['choices']:
322+
assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}'
323+
choice = chunk['choices'][0]
324+
if choice['delta'].get('content') is not None:
325+
assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!'
326+
content.append(choice['delta']['content'])
327+
content_parts += 1
328+
if choice['delta'].get('reasoning_content') is not None:
329+
assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!'
330+
reasoning_content.append(choice['delta']['reasoning_content'])
331+
reasoning_content_parts += 1
332+
if choice['delta'].get('finish_reason') is not None:
333+
finish_reason = choice['delta']['finish_reason']
334+
for tc in choice['delta'].get('tool_calls', []):
335+
if 'function' not in tc:
336+
raise ValueError(f"Expected function type, got {tc['type']}")
337+
if tc['index'] >= len(tool_calls):
338+
assert 'id' in tc
339+
assert tc.get('type') == 'function'
340+
assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \
341+
f"Expected function call with name, got {tc.get('function')}"
342+
tool_calls.append(dict(
343+
id="",
344+
type="function",
345+
function=dict(
346+
name="",
347+
arguments="",
348+
)
349+
))
350+
tool_call = tool_calls[tc['index']]
351+
if tc.get('id') is not None:
352+
tool_call['id'] = tc['id']
353+
fct = tc['function']
354+
assert 'id' not in fct, f"Function call should not have id: {fct}"
355+
if fct.get('name') is not None:
356+
tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name']
357+
if fct.get('arguments') is not None:
358+
tool_call['function']['arguments'] += fct['arguments']
359+
arguments_parts += 1
360+
tool_call_parts += 1
361+
else:
362+
# When `include_usage` is True (the default), we expect the last chunk of the stream
363+
# immediately preceding the `data: [DONE]` message to contain a `choices` field with an empty array
364+
# and a `usage` field containing the usage statistics (n.b., llama-server also returns `timings` in
365+
# the last chunk)
366+
assert 'usage' in chunk, f"Expected finish_reason in chunk: {chunk}"
367+
assert 'timings' in chunk, f"Expected finish_reason in chunk: {chunk}"
361368
print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts')
362369
result = dict(
363370
choices=[

0 commit comments

Comments
 (0)