@@ -499,6 +499,9 @@ class GetPopulation(BaseModel):
499499 """Modify the likelihood of specified tokens appearing in the completion."""
500500 streaming : bool = False
501501 """Whether to stream the results or not."""
502+ stream_usage : bool | None = None
503+ """Whether to include usage metadata in streaming output. If True, an additional
504+ message chunk will be generated during the stream including usage metadata."""
502505 n : int | None = None
503506 """Number of chat completions to generate for each prompt."""
504507 top_p : float | None = None
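
For context, a minimal usage sketch of the new flag (this assumes the field above is added to `ChatHuggingFace`; the model id is a placeholder and not part of this diff):

```python
# Sketch only: assumes the `stream_usage` field added above lives on
# ChatHuggingFace and that the backing Hugging Face endpoint reports usage.
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

llm = HuggingFaceEndpoint(
    repo_id="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
    task="text-generation",
)
chat = ChatHuggingFace(llm=llm, stream_usage=True)

for chunk in chat.stream("Hello!"):
    # With stream_usage=True, one extra chunk at the end of the stream
    # carries usage_metadata (input/output/total token counts).
    if chunk.usage_metadata:
        print(chunk.usage_metadata)
```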
@@ -634,14 +637,40 @@ async def _agenerate(
         )
         return self._to_chat_result(llm_result)
 
+    def _should_stream_usage(
+        self, *, stream_usage: bool | None = None, **kwargs: Any
+    ) -> bool | None:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of precedence
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
     def _stream(
         self,
         messages: list[BaseMessage],
         stop: list[str] | None = None,
         run_manager: CallbackManagerForLLMRun | None = None,
+        *,
+        stream_usage: bool | None = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         if _is_huggingface_endpoint(self.llm):
+            stream_usage = self._should_stream_usage(
+                stream_usage=stream_usage, **kwargs
+            )
+            if stream_usage:
+                kwargs["stream_options"] = {"include_usage": stream_usage}
             message_dicts, params = self._create_message_dicts(messages, stop)
             params = {**params, **kwargs, "stream": True}
 
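
The helper above walks an ordered list of sources and returns the first explicit boolean, so a per-call argument beats `stream_options` in kwargs, which beats `model_kwargs`, which beats the instance default. A standalone sketch of that precedence rule (function and parameter names here are illustrative, not part of the diff):

```python
# Illustrative re-implementation of the precedence walk in _should_stream_usage.
def resolve_stream_usage(
    call_arg: bool | None,
    call_kwargs: dict,
    model_kwargs: dict,
    default: bool | None,
) -> bool | None:
    sources = [  # highest precedence first
        call_arg,
        call_kwargs.get("stream_options", {}).get("include_usage"),
        model_kwargs.get("stream_options", {}).get("include_usage"),
        default,
    ]
    for source in sources:
        if isinstance(source, bool):
            return source
    return default


assert resolve_stream_usage(None, {}, {}, None) is None
assert resolve_stream_usage(None, {"stream_options": {"include_usage": True}}, {}, False) is True
assert resolve_stream_usage(False, {"stream_options": {"include_usage": True}}, {}, True) is False
```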
@@ -650,7 +679,20 @@ def _stream(
                 messages=message_dicts, **params
             ):
                 if len(chunk["choices"]) == 0:
+                    if usage := chunk.get("usage"):
+                        usage_msg = AIMessageChunk(
+                            content="",
+                            additional_kwargs={},
+                            response_metadata={},
+                            usage_metadata={
+                                "input_tokens": usage.get("prompt_tokens", 0),
+                                "output_tokens": usage.get("completion_tokens", 0),
+                                "total_tokens": usage.get("total_tokens", 0),
+                            },
+                        )
+                        yield ChatGenerationChunk(message=usage_msg)
                     continue
+
                 choice = chunk["choices"][0]
                 message_chunk = _convert_chunk_to_message_chunk(
                     chunk, default_chunk_class
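
The usage-bearing chunk arrives with an empty `choices` list, so it is converted into an `AIMessageChunk` whose `usage_metadata` maps the Hugging Face `usage` payload (`prompt_tokens`, `completion_tokens`, `total_tokens`) onto LangChain's field names. A hedged sketch of how a caller could total usage across a stream, reusing the `chat` model from the earlier example:

```python
# Sketch: accumulate streamed chunks; concatenating AIMessageChunk objects
# merges usage_metadata, so the running total lands on the combined chunk.
full = None
for chunk in chat.stream("Explain streaming usage metadata."):
    full = chunk if full is None else full + chunk

if full is not None and full.usage_metadata:
    print(full.usage_metadata["input_tokens"], full.usage_metadata["total_tokens"])
```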
@@ -688,8 +730,13 @@ async def _astream(
         messages: list[BaseMessage],
         stop: list[str] | None = None,
         run_manager: AsyncCallbackManagerForLLMRun | None = None,
+        *,
+        stream_usage: bool | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        stream_usage = self._should_stream_usage(stream_usage=stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
 
@@ -699,7 +746,20 @@ async def _astream(
             messages=message_dicts, **params
         ):
             if len(chunk["choices"]) == 0:
+                if usage := chunk.get("usage"):
+                    usage_msg = AIMessageChunk(
+                        content="",
+                        additional_kwargs={},
+                        response_metadata={},
+                        usage_metadata={
+                            "input_tokens": usage.get("prompt_tokens", 0),
+                            "output_tokens": usage.get("completion_tokens", 0),
+                            "total_tokens": usage.get("total_tokens", 0),
+                        },
+                    )
+                    yield ChatGenerationChunk(message=usage_msg)
                 continue
+
             choice = chunk["choices"][0]
             message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
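
The async path mirrors the sync one: the flag is resolved before `stream_options` is injected into kwargs, and the usage chunk is yielded in the same shape. A brief consumption sketch for the async stream, again assuming the `chat` model from the first example:

```python
# Async counterpart: the final chunk carries usage_metadata when enabled.
import asyncio


async def main() -> None:
    async for chunk in chat.astream("Hello again!"):
        if chunk.usage_metadata:
            print(chunk.usage_metadata)


asyncio.run(main())
```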