
Commit 0e36185

fix(huggingface): add stream_usage support for ChatHuggingFace invoke/stream (#32708)
1 parent 6617865

File tree: 2 files changed (+82 -0 lines)

libs/partners/huggingface/langchain_huggingface/chat_models/huggingface.py

Lines changed: 60 additions & 0 deletions
@@ -499,6 +499,9 @@ class GetPopulation(BaseModel):
     """Modify the likelihood of specified tokens appearing in the completion."""
     streaming: bool = False
     """Whether to stream the results or not."""
+    stream_usage: bool | None = None
+    """Whether to include usage metadata in streaming output. If True, an additional
+    message chunk will be generated during the stream including usage metadata."""
     n: int | None = None
     """Number of chat completions to generate for each prompt."""
     top_p: float | None = None
@@ -634,14 +637,40 @@ async def _agenerate(
         )
         return self._to_chat_result(llm_result)

+    def _should_stream_usage(
+        self, *, stream_usage: bool | None = None, **kwargs: Any
+    ) -> bool | None:
+        """Determine whether to include usage metadata in streaming output.
+
+        For backwards compatibility, we check for `stream_options` passed
+        explicitly to kwargs or in the model_kwargs and override self.stream_usage.
+        """
+        stream_usage_sources = [  # order of precedence
+            stream_usage,
+            kwargs.get("stream_options", {}).get("include_usage"),
+            self.model_kwargs.get("stream_options", {}).get("include_usage"),
+            self.stream_usage,
+        ]
+        for source in stream_usage_sources:
+            if isinstance(source, bool):
+                return source
+        return self.stream_usage
+
     def _stream(
         self,
         messages: list[BaseMessage],
         stop: list[str] | None = None,
         run_manager: CallbackManagerForLLMRun | None = None,
+        *,
+        stream_usage: bool | None = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         if _is_huggingface_endpoint(self.llm):
+            stream_usage = self._should_stream_usage(
+                stream_usage=stream_usage, **kwargs
+            )
+            if stream_usage:
+                kwargs["stream_options"] = {"include_usage": stream_usage}
             message_dicts, params = self._create_message_dicts(messages, stop)
             params = {**params, **kwargs, "stream": True}

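The new _should_stream_usage helper resolves the effective setting from several sources. Below is a minimal, standalone sketch of that precedence using a hypothetical name (resolve_stream_usage is not part of the library): an explicit per-call stream_usage wins, then stream_options in the call kwargs, then stream_options in model_kwargs, then the instance-level stream_usage field.

from typing import Any


def resolve_stream_usage(
    stream_usage: bool | None,
    call_kwargs: dict[str, Any],
    model_kwargs: dict[str, Any],
    instance_default: bool | None,
) -> bool | None:
    # Mirror of the helper's precedence list: the first explicit bool wins.
    sources = [
        stream_usage,
        call_kwargs.get("stream_options", {}).get("include_usage"),
        model_kwargs.get("stream_options", {}).get("include_usage"),
        instance_default,
    ]
    for source in sources:
        if isinstance(source, bool):
            return source
    return instance_default  # may be None, i.e. "not configured"


# An explicit per-call argument overrides everything else.
assert resolve_stream_usage(False, {"stream_options": {"include_usage": True}}, {}, True) is False
# stream_options in model_kwargs is honored when nothing else is set.
assert resolve_stream_usage(None, {}, {"stream_options": {"include_usage": True}}, None) is True
# With no setting anywhere, the result stays None and no stream_options are injected.
assert resolve_stream_usage(None, {}, {}, None) is None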
@@ -650,7 +679,20 @@ def _stream(
                 messages=message_dicts, **params
             ):
                 if len(chunk["choices"]) == 0:
+                    if usage := chunk.get("usage"):
+                        usage_msg = AIMessageChunk(
+                            content="",
+                            additional_kwargs={},
+                            response_metadata={},
+                            usage_metadata={
+                                "input_tokens": usage.get("prompt_tokens", 0),
+                                "output_tokens": usage.get("completion_tokens", 0),
+                                "total_tokens": usage.get("total_tokens", 0),
+                            },
+                        )
+                        yield ChatGenerationChunk(message=usage_msg)
                     continue
+
                 choice = chunk["choices"][0]
                 message_chunk = _convert_chunk_to_message_chunk(
                     chunk, default_chunk_class
@@ -688,8 +730,13 @@ async def _astream(
         messages: list[BaseMessage],
         stop: list[str] | None = None,
         run_manager: AsyncCallbackManagerForLLMRun | None = None,
+        *,
+        stream_usage: bool | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
+        stream_usage = self._should_stream_usage(stream_usage=stream_usage, **kwargs)
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}

@@ -699,7 +746,20 @@ async def _astream(
             messages=message_dicts, **params
         ):
             if len(chunk["choices"]) == 0:
+                if usage := chunk.get("usage"):
+                    usage_msg = AIMessageChunk(
+                        content="",
+                        additional_kwargs={},
+                        response_metadata={},
+                        usage_metadata={
+                            "input_tokens": usage.get("prompt_tokens", 0),
+                            "output_tokens": usage.get("completion_tokens", 0),
+                            "total_tokens": usage.get("total_tokens", 0),
+                        },
+                    )
+                    yield ChatGenerationChunk(message=usage_msg)
                 continue
+
             choice = chunk["choices"][0]
             message_chunk = _convert_chunk_to_message_chunk(chunk, default_chunk_class)
             generation_info = {}
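Because _astream yields the same empty-choices usage chunk, the async path surfaces token counts as well. A hedged consumption sketch: the endpoint and model setup below mirror the integration test in the next file, assume working Hugging Face credentials, and are not part of this commit.

import asyncio

from langchain_core.messages import AIMessageChunk
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint


async def main() -> None:
    llm = HuggingFaceEndpoint(  # type: ignore[call-arg]
        repo_id="google/gemma-3-27b-it",
        task="conversational",
        provider="nebius",
    )
    model = ChatHuggingFace(llm=llm, stream_usage=True)

    full: AIMessageChunk | None = None
    async for chunk in model.astream("hello"):
        # The final empty-choices chunk carries usage_metadata; "+" merges it in.
        full = chunk if full is None else full + chunk

    assert full is not None and full.usage_metadata is not None
    print(full.usage_metadata)  # input_tokens / output_tokens / total_tokens


asyncio.run(main())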
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+from langchain_core.messages import AIMessageChunk
+
+from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+
+
+def test_stream_usage() -> None:
+    """Test we are able to configure stream options on models that require it."""
+    llm = HuggingFaceEndpoint(  # type: ignore[call-arg] # (model is inferred in class)
+        repo_id="google/gemma-3-27b-it",
+        task="conversational",
+        provider="nebius",
+    )
+
+    model = ChatHuggingFace(llm=llm, stream_usage=True)
+
+    full: AIMessageChunk | None = None
+    for chunk in model.stream("hello"):
+        assert isinstance(chunk, AIMessageChunk)
+        full = chunk if full is None else full + chunk
+
+    assert isinstance(full, AIMessageChunk)
+    assert full.usage_metadata
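The test exercises the constructor flag; the backwards-compatible stream_options path that _should_stream_usage also honors can be exercised per call instead. A minimal sketch under the same endpoint assumptions as the test above (not part of this commit), relying on stream() forwarding extra kwargs to _stream:

from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

llm = HuggingFaceEndpoint(  # type: ignore[call-arg]
    repo_id="google/gemma-3-27b-it",
    task="conversational",
    provider="nebius",
)
model = ChatHuggingFace(llm=llm)  # stream_usage left unset (None)

# Passing stream_options as a call kwarg turns usage reporting on for this call only.
for chunk in model.stream("hello", stream_options={"include_usage": True}):
    if chunk.usage_metadata:
        print(chunk.usage_metadata)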
