
Commit 13c3768

fix: allow support for include_usage in streaming using OpenAIChatGenerator (#8968)

* fix error in handling usage completion chunk

1 parent ab67a76 commit 13c3768
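
For context, a minimal sketch of how a caller opts into usage reporting on a streamed response after this fix. The model name, prompt, and callback are illustrative; `stream_options={"include_usage": True}` is the OpenAI API flag, forwarded through the component's `generation_kwargs`:

    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    generator = OpenAIChatGenerator(
        model="gpt-4o-mini",  # illustrative; any OpenAI chat model
        streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
        generation_kwargs={"stream_options": {"include_usage": True}},
    )

    result = generator.run(messages=[ChatMessage.from_user("Summarize Haystack in one line.")])
    # After this fix, meta["usage"] carries the token counts instead of an empty dict.
    print(result["replies"][0].meta["usage"])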

File tree

4 files changed: +44 -16 lines

haystack/components/generators/chat/openai.py

Lines changed: 38 additions & 8 deletions

@@ -399,29 +399,39 @@ def _prepare_api_call( # noqa: PLR0913
     def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
         chunks: List[StreamingChunk] = []
         chunk = None
+        chunk_delta: StreamingChunk

         for chunk in chat_completion:  # pylint: disable=not-an-iterable
-            assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
-            chunk_delta: StreamingChunk = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            # choices is an empty array for the usage chunk when include_usage is set to True
+            if chunk.usage is not None:
+                chunk_delta = self._convert_usage_chunk_to_streaming_chunk(chunk)
+
+            else:
+                assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
+                chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
             chunks.append(chunk_delta)

             callback(chunk_delta)
-
         return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]

     async def _handle_async_stream_response(
         self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
     ) -> List[ChatMessage]:
         chunks: List[StreamingChunk] = []
         chunk = None
+        chunk_delta: StreamingChunk

         async for chunk in chat_completion:  # pylint: disable=not-an-iterable
-            assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
-            chunk_delta: StreamingChunk = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
+            # choices is an empty array for the usage chunk when include_usage is set to True
+            if chunk.usage is not None:
+                chunk_delta = self._convert_usage_chunk_to_streaming_chunk(chunk)
+
+            else:
+                assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
+                chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
             chunks.append(chunk_delta)

             await callback(chunk_delta)
-
         return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]

     def _check_finish_reason(self, meta: Dict[str, Any]) -> None:

@@ -447,6 +457,8 @@ def _convert_streaming_chunks_to_chat_message(

         :param chunk: The last chunk returned by the OpenAI API.
         :param chunks: The list of all `StreamingChunk` objects.
+
+        :returns: The ChatMessage.
         """
         text = "".join([chunk.content for chunk in chunks])
         tool_calls = []

@@ -486,12 +498,15 @@ def _convert_streaming_chunks_to_chat_message(
                 _arguments=call_data["arguments"],
             )

+        # finish_reason is in the last chunk if usage is not included, and in the second-to-last chunk if it is
+        finish_reason = (chunks[-2] if chunk.usage and len(chunks) >= 2 else chunks[-1]).meta.get("finish_reason")
+
         meta = {
             "model": chunk.model,
             "index": 0,
-            "finish_reason": chunk.choices[0].finish_reason,
+            "finish_reason": finish_reason,
             "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": {},  # we don't have usage data for streaming responses
+            "usage": chunk.usage or {},
         }

         return ChatMessage.from_assistant(text=text or None, tool_calls=tool_calls, meta=meta)

@@ -559,3 +574,18 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
             }
         )
         return chunk_message
+
+    def _convert_usage_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
+        """
+        Converts the usage chunk received from the OpenAI API when `include_usage` is set to `True` to a StreamingChunk.
+
+        :param chunk: The usage chunk returned by the OpenAI API.
+
+        :returns:
+            The StreamingChunk.
+        """
+        chunk_message = StreamingChunk(content="")
+        chunk_message.meta.update(
+            {"model": chunk.model, "usage": chunk.usage, "received_at": datetime.now().isoformat()}
+        )
+        return chunk_message
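
The finish_reason selection added above is easy to misread, so here is a standalone sketch of the same logic with invented data (plain dicts standing in for StreamingChunk.meta): when the stream ends with a usage-only chunk, finish_reason lives in the second-to-last chunk; otherwise it is in the last one.

    from typing import Any, Dict, List, Optional

    def pick_finish_reason(metas: List[Dict[str, Any]], last_has_usage: bool) -> Optional[str]:
        # Mirrors: (chunks[-2] if chunk.usage and len(chunks) >= 2 else chunks[-1]).meta.get("finish_reason")
        source = metas[-2] if last_has_usage and len(metas) >= 2 else metas[-1]
        return source.get("finish_reason")

    # include_usage=True: the trailing chunk carries usage, the one before it the finish_reason
    metas = [{"finish_reason": None}, {"finish_reason": "stop"}, {"usage": {"total_tokens": 97}}]
    assert pick_finish_reason(metas, last_has_usage=True) == "stop"

    # include_usage=False: finish_reason is in the last chunk
    assert pick_finish_reason(metas[:-1], last_has_usage=False) == "stop"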

test/components/generators/chat/test_openai.py

Lines changed: 1 addition & 2 deletions

@@ -9,7 +9,7 @@
 import os
 from datetime import datetime

-from openai import AsyncOpenAI, OpenAIError
+from openai import OpenAIError
 from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion import Choice
 from openai.types.chat.chat_completion_message_tool_call import Function

@@ -63,7 +63,6 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream):
             )
         ],
         created=int(datetime.now().timestamp()),
-        usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
     )
     mock_chat_completion_create.return_value = openai_mock_stream(
         completion, cast_to=None, response=None, client=None

test/components/generators/chat/test_openai_async.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ def mock_chat_completion_chunk_with_tools(openai_mock_stream_async):
             )
         ],
         created=int(datetime.now().timestamp()),
-        usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        usage=None,
     )
     mock_chat_completion_create.return_value = openai_mock_stream_async(completion)
     yield mock_chat_completion_create
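
The mock changes above follow from the new branching: the handler now routes any chunk whose `usage` is not None to the usage converter, so ordinary delta chunks in fixtures must carry `usage=None` to take the normal path. A toy illustration with stand-in objects, not the openai SDK types:

    class FakeChunk:
        # Stand-in for ChatCompletionChunk with only the fields the branch inspects.
        def __init__(self, usage=None, choices=()):
            self.usage = usage
            self.choices = list(choices)

    def route(chunk: FakeChunk) -> str:
        # Mirrors the if/else added in _handle_stream_response.
        return "usage" if chunk.usage is not None else "delta"

    assert route(FakeChunk(choices=["a delta"])) == "delta"
    assert route(FakeChunk(usage={"total_tokens": 97}, choices=[])) == "usage"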

test/components/generators/conftest.py

Lines changed: 4 additions & 5 deletions

@@ -7,9 +7,8 @@

 import pytest
 from openai import AsyncStream, Stream
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
-from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
-from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, ChatCompletionMessageToolCall
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
+from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.chat import chat_completion_chunk


@@ -146,7 +145,7 @@ def openai_mock_chat_completion_chunk():
             )
         ],
         created=int(datetime.now().timestamp()),
-        usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        usage=None,
     )
     mock_chat_completion_create.return_value = OpenAIMockStream(
         completion, cast_to=None, response=None, client=None

@@ -175,7 +174,7 @@ async def openai_mock_async_chat_completion_chunk():
             )
         ],
         created=int(datetime.now().timestamp()),
-        usage={"prompt_tokens": 57, "completion_tokens": 40, "total_tokens": 97},
+        usage=None,
     )
     mock_chat_completion_create.return_value = OpenAIAsyncMockStream(completion)
     yield mock_chat_completion_create
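
For completeness, a hedged sketch of how a test could fabricate the trailing usage-only chunk that `include_usage` appends to a stream (empty `choices`, populated `usage`). The types are the openai SDK's; the field values are invented:

    from datetime import datetime

    from openai.types import CompletionUsage
    from openai.types.chat import ChatCompletionChunk

    usage_chunk = ChatCompletionChunk(
        id="chatcmpl-test",  # invented id
        model="gpt-4o-mini",  # invented model name
        object="chat.completion.chunk",
        created=int(datetime.now().timestamp()),
        choices=[],  # empty by design for the trailing usage chunk
        usage=CompletionUsage(prompt_tokens=57, completion_tokens=40, total_tokens=97),
    )

    assert usage_chunk.choices == [] and usage_chunk.usage is not None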
