@@ -399,29 +399,39 @@ def _prepare_api_call( # noqa: PLR0913
def _handle_stream_response(self, chat_completion: Stream, callback: SyncStreamingCallbackT) -> List[ChatMessage]:
    """
    Consume a synchronous streaming completion and assemble it into a single ChatMessage.

    Each chunk is converted to a `StreamingChunk`, collected, and handed to `callback` as soon as it
    has been converted.

    :param chat_completion: The stream of chat completion chunks returned by the OpenAI API.
    :param callback: Callable invoked once for every converted `StreamingChunk`.

    :returns: A single-element list holding the ChatMessage built from all received chunks.
    """
    chunks: List[StreamingChunk] = []
    # Stays None if the stream yields nothing; it is forwarded as-is to the final conversion.
    chunk = None
    chunk_delta: StreamingChunk

    for chunk in chat_completion:  # pylint: disable=not-an-iterable
        # `choices` is an empty array for the usage chunk when include_usage is set to True.
        # Branch on the missing choices rather than on the presence of `usage`: a chunk that carries
        # both a choice and usage must still go through the regular conversion so that its content
        # and tool-call delta are not dropped, and a chunk without any choice must never reach the
        # single-choice assertion below.
        if not chunk.choices:
            chunk_delta = self._convert_usage_chunk_to_streaming_chunk(chunk)
        else:
            assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
        chunks.append(chunk_delta)

        callback(chunk_delta)
    return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
411416
async def _handle_async_stream_response(
    self, chat_completion: AsyncStream, callback: AsyncStreamingCallbackT
) -> List[ChatMessage]:
    """
    Consume an asynchronous streaming completion and assemble it into a single ChatMessage.

    Each chunk is converted to a `StreamingChunk`, collected, and passed to the awaited `callback`
    as soon as it has been converted.

    :param chat_completion: The asynchronous stream of chat completion chunks returned by the OpenAI API.
    :param callback: Coroutine function awaited once for every converted `StreamingChunk`.

    :returns: A single-element list holding the ChatMessage built from all received chunks.
    """
    chunks: List[StreamingChunk] = []
    # Stays None if the stream yields nothing; it is forwarded as-is to the final conversion.
    chunk = None
    chunk_delta: StreamingChunk

    async for chunk in chat_completion:  # pylint: disable=not-an-iterable
        # `choices` is an empty array for the usage chunk when include_usage is set to True.
        # Branch on the missing choices rather than on the presence of `usage`: a chunk that carries
        # both a choice and usage must still go through the regular conversion so that its content
        # and tool-call delta are not dropped, and a chunk without any choice must never reach the
        # single-choice assertion below.
        if not chunk.choices:
            chunk_delta = self._convert_usage_chunk_to_streaming_chunk(chunk)
        else:
            assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
            chunk_delta = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
        chunks.append(chunk_delta)

        await callback(chunk_delta)
    return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]
426436
427437 def _check_finish_reason (self , meta : Dict [str , Any ]) -> None :
@@ -447,6 +457,8 @@ def _convert_streaming_chunks_to_chat_message(
447457
448458 :param chunk: The last chunk returned by the OpenAI API.
449459 :param chunks: The list of all `StreamingChunk` objects.
460+
461+ :returns: The ChatMessage.
450462 """
451463 text = "" .join ([chunk .content for chunk in chunks ])
452464 tool_calls = []
@@ -486,12 +498,15 @@ def _convert_streaming_chunks_to_chat_message(
486498 _arguments = call_data ["arguments" ],
487499 )
488500
501+ # finish_reason is in the last chunk if usage is not included, and in the second last chunk if usage is included
502+ finish_reason = (chunks [- 2 ] if chunk .usage and len (chunks ) >= 2 else chunks [- 1 ]).meta .get ("finish_reason" )
503+
489504 meta = {
490505 "model" : chunk .model ,
491506 "index" : 0 ,
492- "finish_reason" : chunk . choices [ 0 ]. finish_reason ,
507+ "finish_reason" : finish_reason ,
493508 "completion_start_time" : chunks [0 ].meta .get ("received_at" ), # first chunk received
494- "usage" : {}, # we don't have usage data for streaming responses
509+ "usage" : chunk . usage or {},
495510 }
496511
497512 return ChatMessage .from_assistant (text = text or None , tool_calls = tool_calls , meta = meta )
@@ -559,3 +574,18 @@ def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletio
559574 }
560575 )
561576 return chunk_message
577+
def _convert_usage_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
    """
    Builds a content-less StreamingChunk from the usage chunk of a streamed response.

    The OpenAI API sends this chunk when `include_usage` is set to `True`.

    :param chunk: The usage chunk returned by the OpenAI API.

    :returns:
        A StreamingChunk with empty content whose meta holds the model, the usage and the time of reception.
    """
    usage_chunk = StreamingChunk(content="")
    usage_meta = {
        "model": chunk.model,
        "usage": chunk.usage,
        # timestamp taken at conversion time
        "received_at": datetime.now().isoformat(),
    }
    usage_chunk.meta.update(usage_meta)
    return usage_chunk
0 commit comments