Skip to content

Commit 9ab4dbe

Browse files
authored
Fix: fix the streaming issue for nova models (#580)
### Description This PR fixes the bug that cause streaming is broken for Amazon Nova model. The response was returned as a whole rather than streaming back chunks to customers, even though `.stream()` is used. ### Root Cause and Fix First of all, when using Amazon Nova model, langchain-aws will use ChatBedrockConverse even though we create ChatBedrock ([code](https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/langchain_aws/chat_models/bedrock.py#L704)). Inside ChatBedrockConverse, the input parameter `provider` is default to empty string rather than None ([code](https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/langchain_aws/chat_models/bedrock_converse.py#L407)), therefore, the [set_disable_streaming](https://github.com/langchain-ai/langchain-aws/blob/main/libs/aws/langchain_aws/chat_models/bedrock_converse.py#L582) function will use `provider=""` to disable the streaming, even though the model can support streaming. To fix it, this PR updates **set_disable_streaming** function to extract provider from model Id correctly when `provider` is empty string as default value, so it can setup the streaming flag correctly. Alternatively, user can set the `provider` parameter explicitly when initializing the ChatBedrock as below, which can make streaming work for Nova models without any code change. ``` llm = ChatBedrock( provider="amazon", model="us.amazon.nova-pro-v1:0", ) for chunk in llm.stream(prompt): print(chunk) ``` ### Issue #544
1 parent 6d5df57 commit 9ab4dbe

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def set_disable_streaming(cls, values: Dict) -> Any:
584584

585585
# Extract provider from the model_id
586586
# (e.g., "amazon", "anthropic", "ai21", "meta", "mistral")
587-
if "provider" not in values:
587+
if "provider" not in values or values["provider"] == "":
588588
if model_id.startswith("arn"):
589589
raise ValueError(
590590
"Model provider should be supplied when passing a model ARN as model_id."

libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1571,4 +1571,9 @@ def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
15711571
)
15721572

15731573
# The streaming should be disabled for models with no streaming support
1574-
assert chat_model.disable_streaming is True
1574+
assert chat_model.disable_streaming is True
1575+
1576+
def test_nova_provider_extraction() -> None:
1577+
"""Test that provider is correctly extracted from Nova model ID when not provided."""
1578+
model = ChatBedrockConverse(client=mock.MagicMock(), model="us.amazon.nova-pro-v1:0", region_name="us-west-2")
1579+
assert model.provider == "amazon"

0 commit comments

Comments
 (0)