
Commit 6829f50

Fixes streaming for llama3 models. (#116)
1 parent 8a17693 commit 6829f50


3 files changed: +27 -22 lines changed


libs/aws/langchain_aws/chat_models/bedrock.py

Lines changed: 9 additions & 13 deletions
@@ -565,20 +565,16 @@ def _generate(
                 )
             else:
                 usage_metadata = None
+
         llm_output["model_id"] = self.model_id
-        if len(tool_calls) > 0:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                tool_calls=cast(List[ToolCall], tool_calls),
-                usage_metadata=usage_metadata,
-            )
-        else:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                usage_metadata=usage_metadata,
-            )
+
+        msg = AIMessage(
+            content=completion,
+            additional_kwargs=llm_output,
+            tool_calls=cast(List[ToolCall], tool_calls),
+            usage_metadata=usage_metadata,
+        )
+
         return ChatResult(
             generations=[
                 ChatGeneration(
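Why the old if/else could be collapsed: AIMessage takes tool_calls as an optional list, so passing the (possibly empty) cast list covers both of the previous branches. A minimal sketch, outside the commit, assuming a langchain-core version that exposes usage_metadata on AIMessage:

from langchain_core.messages import AIMessage

# When no tools were called, tool_calls is simply an empty list, so a single
# constructor call behaves like the removed two-branch construction.
msg = AIMessage(
    content="Hello!",
    additional_kwargs={"model_id": "meta.llama3-8b-instruct-v1:0"},
    tool_calls=[],
    usage_metadata=None,
)
print(msg.tool_calls)  # []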

libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 9 additions & 1 deletion
@@ -117,7 +117,11 @@ def _stream_response_to_generation_chunk(
             return None
     else:
         # chunk obj format varies with provider
-        generation_info = {k: v for k, v in stream_response.items() if k != output_key}
+        generation_info = {
+            k: v
+            for k, v in stream_response.items()
+            if k not in [output_key, "prompt_token_count", "generation_token_count"]
+        }
         return GenerationChunk(
             text=(
                 stream_response[output_key]
@@ -347,6 +351,10 @@ def prepare_output_stream(
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return

+            elif provider == "meta" and chunk_obj.get("stop_reason", "") == "stop":
+                yield _get_invocation_metrics_chunk(chunk_obj)
+                return
+
             elif messages_api and (chunk_obj.get("type") == "message_stop"):
                 yield _get_invocation_metrics_chunk(chunk_obj)
                 return
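Two things change here for llama3 (provider "meta"): the token-count fields are kept out of generation_info, and the stream now terminates on stop_reason == "stop" after yielding an invocation-metrics chunk. A small sketch of the filtering only, with a hypothetical chunk shape (field names taken from the diff; output_key assumed to be "generation" for the meta provider):

# Hypothetical meta/llama3 stream chunk; the exact shape is an assumption.
stream_response = {
    "generation": " world",
    "prompt_token_count": 12,
    "generation_token_count": 3,
    "stop_reason": None,
}
output_key = "generation"  # assumed output key for the meta provider

# Same comprehension as the patched helper: drop the text and token counts.
generation_info = {
    k: v
    for k, v in stream_response.items()
    if k not in [output_key, "prompt_token_count", "generation_token_count"]
}
print(generation_info)  # {'stop_reason': None}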

libs/aws/tests/integration_tests/chat_models/test_bedrock.py

Lines changed: 9 additions & 8 deletions
@@ -84,17 +84,18 @@ def test_chat_bedrock_streaming() -> None:
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_llama3() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
-    callback_handler = FakeCallbackHandler()
     chat = ChatBedrock(  # type: ignore[call-arg]
-        model_id="meta.llama3-8b-instruct-v1:0",
-        streaming=True,
-        callbacks=[callback_handler],
-        verbose=True,
+        model_id="meta.llama3-8b-instruct-v1:0"
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
-    assert callback_handler.llm_streams > 0
-    assert isinstance(response, BaseMessage)
+
+    response = AIMessageChunk(content="")
+    for chunk in chat.stream([message]):
+        response += chunk  # type: ignore[assignment]
+
+    assert response.content
+    assert response.response_metadata
+    assert response.usage_metadata


 @pytest.mark.scheduled
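The updated integration test exercises the fix the way a caller would. A minimal standalone usage sketch along the same lines (AWS credentials and Bedrock model access are assumed to be configured):

from langchain_aws import ChatBedrock
from langchain_core.messages import AIMessageChunk, HumanMessage

chat = ChatBedrock(model_id="meta.llama3-8b-instruct-v1:0")  # type: ignore[call-arg]

response = AIMessageChunk(content="")
for chunk in chat.stream([HumanMessage(content="Hello")]):
    response += chunk  # accumulate streamed chunks into one message

# The stream now ends on stop_reason == "stop" and the final chunk carries
# invocation metrics, so usage_metadata is populated.
print(response.content)
print(response.usage_metadata)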
