Commit ea4b274
Fix: Bedrock Application Inference Profile (AIP) is not streaming response with ChatBedrockConverse (#568)
### Description

This PR fixes the streaming issue in ChatBedrockConverse when the input model is an Application Inference Profile (AIP). The issue occurred because langchain-aws could not identify the foundation model behind an AIP (e.g., `arn:aws:bedrock:us-east-1:111111484058:application-inference-profile/c3myu2h6fllr`), and therefore could not set the streaming_support flag for the AIP correctly. As a result, the entire response was returned to the user as a whole rather than streamed back.

### Solution

We create a Bedrock client to call the Bedrock control plane [get_inference_profile API](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock/client/get_inference_profile.html) and parse the foundation model id from the response. However, the existing `set_disable_streaming` function works on the raw user input before the ChatBedrockConverse object is instantiated, so the logic for determining streaming_support is extracted into a common function, `_get_streaming_support`, which is invoked in both places: the original path that works on the raw user input, and the new path that works on the model id resolved by the get_inference_profile call (a sketch of this resolution step is shown below).

### Issue

#538

### Test

* Added new unit tests in libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py; all passed.
* Ran the integration tests for ChatBedrockConverse; all passed except one pre-existing failure. The main branch already fails this integration test, so it is not caused by this change:

```
FAILED tests/integration_tests/chat_models/test_bedrock_converse.py::test_structured_output_tool_choice_not_supported - assert 1 == 0
```

---------

Co-authored-by: Michael Chin <[email protected]>
1 parent bce4ed8 commit ea4b274
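As background for the Solution above, here is a minimal standalone sketch of the resolution step, calling the control plane API with boto3 directly rather than through the library's client helper. The AIP ARN is the placeholder example from the description, not a real profile:

```python
import boto3

# Ask the Bedrock control plane which foundation model backs the
# application inference profile, then derive the model id from its ARN.
bedrock = boto3.client("bedrock", region_name="us-east-1")

# Placeholder AIP ARN from the description above.
aip_arn = (
    "arn:aws:bedrock:us-east-1:111111484058:"
    "application-inference-profile/c3myu2h6fllr"
)

response = bedrock.get_inference_profile(inferenceProfileIdentifier=aip_arn)

# The response lists the backing model(s) as ARNs of the form
# arn:aws:bedrock:<region>::foundation-model/<provider>.<model-name>,
# so the foundation model id is the final path segment.
model_arn = response["models"][0]["modelArn"]
base_model_id = model_arn.split("/")[-1]
print(base_model_id)  # e.g. anthropic.claude-3-sonnet-20240229-v1:0
```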

File tree

2 files changed: +272 −25 lines changed

libs/aws/langchain_aws/chat_models/bedrock_converse.py

Lines changed: 83 additions & 25 deletions
```diff
@@ -312,6 +312,10 @@ class Joke(BaseModel):
     """ # noqa: E501

     client: Any = Field(default=None, exclude=True) #: :meta private:
+    """The bedrock runtime client for making data plane API calls"""
+
+    bedrock_client: Any = Field(default=None, exclude=True) #: :meta private:
+    """The bedrock client for making control plane API calls"""

     model_id: str = Field(alias="model")
     """Id of the model to call.
@@ -500,29 +504,15 @@ def build_extra(cls, values: dict[str, Any]) -> Any:
         }
         return values

-    @model_validator(mode="before")
     @classmethod
-    def set_disable_streaming(cls, values: Dict) -> Any:
-        model_id = values.get("model_id", values.get("model"))
-
-        # Extract provider from the model_id
-        # (e.g., "amazon", "anthropic", "ai21", "meta", "mistral")
-        if "provider" not in values:
-            if model_id.startswith("arn"):
-                raise ValueError(
-                    "Model provider should be supplied when passing a model ARN as model_id."
-                )
-            model_parts = model_id.split(".")
-            values["provider"] = (
-                model_parts[-2] if len(model_parts) > 1 else model_parts[0]
-            )
-
-        provider = values["provider"]
-
-        model_id_lower = values.get(
-            "base_model_id", values.get("base_model", model_id)
-        ).lower()
-
+    def _get_streaming_support(cls, provider: str, model_id_lower: str) -> Union[bool, str]:
+        """Determine streaming support for a given provider and model.
+
+        Returns:
+            True: Full streaming support
+            "no_tools": Streaming supported but not with tools
+            False: No streaming support
+        """
         # Determine if the model supports plain-text streaming (ConverseStream)
         # Here we check based on the updated AWS documentation.
         if (
@@ -550,7 +540,7 @@ def set_disable_streaming(cls, values: Dict) -> Any:
             # Cohere Command R models
             (provider == "cohere" and "command-r" in model_id_lower)
         ):
-            streaming_support = True
+            return True
         elif (
             # AI21 Jamba-Instruct model
             (provider == "ai21" and "jamba-instruct" in model_id_lower)
@@ -583,9 +573,34 @@ def set_disable_streaming(cls, values: Dict) -> Any:
             # Writer Palmyra models
             (provider == "writer" and "palmyra" in model_id_lower)
         ):
-            streaming_support = "no_tools"
+            return "no_tools"
         else:
-            streaming_support = False
+            return False
+
+    @model_validator(mode="before")
+    @classmethod
+    def set_disable_streaming(cls, values: Dict) -> Any:
+        model_id = values.get("model_id", values.get("model"))
+
+        # Extract provider from the model_id
+        # (e.g., "amazon", "anthropic", "ai21", "meta", "mistral")
+        if "provider" not in values:
+            if model_id.startswith("arn"):
+                raise ValueError(
+                    "Model provider should be supplied when passing a model ARN as model_id."
+                )
+            model_parts = model_id.split(".")
+            values["provider"] = (
+                model_parts[-2] if len(model_parts) > 1 else model_parts[0]
+            )
+
+        provider = values["provider"]
+
+        model_id_lower = values.get(
+            "base_model_id", values.get("base_model", model_id)
+        ).lower()
+
+        streaming_support = cls._get_streaming_support(provider, model_id_lower)

         # Set the disable_streaming flag accordingly:
         # - If streaming is supported (plain streaming),
@@ -606,6 +621,23 @@ def set_disable_streaming(cls, values: Dict) -> Any:
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate that AWS credentials to and python package exists in environment."""
+
+        # Create bedrock client for control plane API call
+        if self.bedrock_client is None:
+            self.bedrock_client = create_aws_client(
+                region_name=self.region_name,
+                credentials_profile_name=self.credentials_profile_name,
+                aws_access_key_id=self.aws_access_key_id,
+                aws_secret_access_key=self.aws_secret_access_key,
+                aws_session_token=self.aws_session_token,
+                endpoint_url=self.endpoint_url,
+                config=self.config,
+                service_name="bedrock",
+            )
+
+        # Handle streaming configuration for application inference profiles
+        if "application-inference-profile" in self.model_id:
+            self._configure_streaming_for_resolved_model()

         # As of 12/03/24:
         # only claude-3/4, mistral-large, and nova models support tool choice:
@@ -649,10 +681,36 @@ def validate_environment(self) -> Self:
                 "Provide a guardrail via `guardrail_config` or "
                 "disable `guard_last_turn_only`."
             )
+
         return self

     def _get_base_model(self) -> str:
+        # identify the base model id used in the application inference profile (AIP)
+        # Format: arn:aws:bedrock:us-east-1:<accountId>:application-inference-profile/<id>
+        if self.base_model_id is None and 'application-inference-profile' in self.model_id:
+            response = self.bedrock_client.get_inference_profile(
+                inferenceProfileIdentifier=self.model_id
+            )
+            if 'models' in response and len(response['models']) > 0:
+                model_arn = response['models'][0]['modelArn']
+                # Format: arn:aws:bedrock:region::foundation-model/provider.model-name
+                self.base_model_id = model_arn.split('/')[-1]
         return self.base_model_id if self.base_model_id else self.model_id
+
+    def _configure_streaming_for_resolved_model(self) -> None:
+        """Configure streaming support after resolving the base model for application inference profiles."""
+        base_model = self._get_base_model()
+        model_id_lower = base_model.lower()
+
+        streaming_support = self._get_streaming_support(self.provider, model_id_lower)
+
+        # Set the disable_streaming flag accordingly
+        if not streaming_support:
+            self.disable_streaming = True
+        elif streaming_support == "no_tools":
+            self.disable_streaming = "tool_calling"
+        else:
+            self.disable_streaming = False

     def _apply_guard_last_turn_only(self, messages: List[Dict[str, Any]]) -> None:
         for msg in reversed(messages):
```
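For context, here is a minimal usage sketch of the behavior this change enables, with a hypothetical AIP ARN; note that `provider` must still be supplied explicitly, since it cannot be parsed from an ARN:

```python
from langchain_aws import ChatBedrockConverse

# Hypothetical AIP ARN for illustration; substitute your own profile.
llm = ChatBedrockConverse(
    model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/my-profile",
    provider="anthropic",  # required when the model id is an ARN
    region_name="us-east-1",
)

# With this fix, disable_streaming is set from the resolved base model, so
# chunks arrive incrementally instead of as one final message.
for chunk in llm.stream("Tell me a short joke."):
    # For Converse models, chunk.content may be a list of content blocks.
    print(chunk.content)
```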

libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py

Lines changed: 189 additions & 0 deletions
```diff
@@ -1383,3 +1383,192 @@ def test_stream_guard_last_turn_only() -> None:
     assert bedrock_msgs[-1]["content"][0] == {
         "guardContent": {"text": {"text": "How are you?"}}
     }
+
+@mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client")
+def test_bedrock_client_creation(mock_create_client: mock.Mock) -> None:
+    """Test that bedrock_client is created during validation."""
+    mock_bedrock_client = mock.Mock()
+    mock_runtime_client = mock.Mock()
+
+    def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
+        if service_name == "bedrock":
+            return mock_bedrock_client
+        elif service_name == "bedrock-runtime":
+            return mock_runtime_client
+        return mock.Mock()
+
+    mock_create_client.side_effect = side_effect
+
+    chat_model = ChatBedrockConverse(
+        model="anthropic.claude-3-sonnet-20240229-v1:0",
+        region_name="us-west-2"
+    )
+
+    assert chat_model.bedrock_client == mock_bedrock_client
+    assert chat_model.client == mock_runtime_client
+    assert mock_create_client.call_count == 2
+
+
+@mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client")
+def test_get_base_model_with_application_inference_profile(mock_create_client: mock.Mock) -> None:
+    """Test _get_base_model method with application inference profile."""
+    mock_bedrock_client = mock.Mock()
+    mock_runtime_client = mock.Mock()
+
+    # Mock the get_inference_profile response
+    mock_bedrock_client.get_inference_profile.return_value = {
+        'models': [
+            {
+                'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0'
+            }
+        ]
+    }
+
+    def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
+        if service_name == "bedrock":
+            return mock_bedrock_client
+        elif service_name == "bedrock-runtime":
+            return mock_runtime_client
+        return mock.Mock()
+
+    mock_create_client.side_effect = side_effect
+
+    chat_model = ChatBedrockConverse(
+        model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile",
+        region_name="us-west-2",
+        provider="anthropic"
+    )
+
+    base_model = chat_model._get_base_model()
+    assert base_model == "anthropic.claude-3-sonnet-20240229-v1:0"
+    mock_bedrock_client.get_inference_profile.assert_called_once_with(
+        inferenceProfileIdentifier="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile"
+    )
+
+
+@mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client")
+def test_get_base_model_without_application_inference_profile(mock_create_client: mock.Mock) -> None:
+    """Test _get_base_model method without application inference profile."""
+    mock_bedrock_client = mock.Mock()
+    mock_runtime_client = mock.Mock()
+
+    def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
+        if service_name == "bedrock":
+            return mock_bedrock_client
+        elif service_name == "bedrock-runtime":
+            return mock_runtime_client
+        return mock.Mock()
+
+    mock_create_client.side_effect = side_effect
+
+    chat_model = ChatBedrockConverse(
+        model="anthropic.claude-3-sonnet-20240229-v1:0",
+        region_name="us-west-2",
+        provider="anthropic"
+    )
+
+    base_model = chat_model._get_base_model()
+    assert base_model == "anthropic.claude-3-sonnet-20240229-v1:0"
+    mock_bedrock_client.get_inference_profile.assert_not_called()
+
+
+@mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client")
+def test_configure_streaming_for_resolved_model(mock_create_client: mock.Mock) -> None:
+    """Test _configure_streaming_for_resolved_model method."""
+    mock_bedrock_client = mock.Mock()
+    mock_runtime_client = mock.Mock()
+
+    # Mock the get_inference_profile response for a model with full streaming support
+    mock_bedrock_client.get_inference_profile.return_value = {
+        'models': [
+            {
+                'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0'
+            }
+        ]
+    }
+
+    def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
+        if service_name == "bedrock":
+            return mock_bedrock_client
+        elif service_name == "bedrock-runtime":
+            return mock_runtime_client
+        return mock.Mock()
+
+    mock_create_client.side_effect = side_effect
+
+    chat_model = ChatBedrockConverse(
+        model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile",
+        region_name="us-west-2",
+        provider="anthropic"
+    )
+
+    # The streaming should be configured based on the resolved model
+    assert chat_model.disable_streaming is False
+
+
+@mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client")
+def test_configure_streaming_for_resolved_model_no_tools(mock_create_client: mock.Mock) -> None:
+    """Test _configure_streaming_for_resolved_model method with no-tools streaming."""
+    mock_bedrock_client = mock.Mock()
+    mock_runtime_client = mock.Mock()
+
+    # Mock the get_inference_profile response for a model with no-tools streaming support
+    mock_bedrock_client.get_inference_profile.return_value = {
+        'models': [
+            {
+                'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/amazon.titan-text-express-v1'
+            }
+        ]
+    }
+
+    def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
+        if service_name == "bedrock":
+            return mock_bedrock_client
+        elif service_name == "bedrock-runtime":
+            return mock_runtime_client
+        return mock.Mock()
+
+    mock_create_client.side_effect = side_effect
+
+    chat_model = ChatBedrockConverse(
+        model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile",
+        region_name="us-west-2",
+        provider="amazon"
+    )
+
+    # The streaming should be configured as "tool_calling" for no-tools models
+    assert chat_model.disable_streaming == "tool_calling"
+
+
+@mock.patch("langchain_aws.chat_models.bedrock_converse.create_aws_client")
+def test_configure_streaming_for_resolved_model_no_streaming(mock_create_client: mock.Mock) -> None:
+    """Test _configure_streaming_for_resolved_model method with no streaming support."""
+    mock_bedrock_client = mock.Mock()
+    mock_runtime_client = mock.Mock()
+
+    # Mock the get_inference_profile response for a model with no streaming support
+    mock_bedrock_client.get_inference_profile.return_value = {
+        'models': [
+            {
+                'modelArn': 'arn:aws:bedrock:us-east-1::foundation-model/stability.stable-image-core-v1:0'
+            }
+        ]
+    }
+
+    def side_effect(service_name: str, **kwargs: Any) -> mock.Mock:
+        if service_name == "bedrock":
+            return mock_bedrock_client
+        elif service_name == "bedrock-runtime":
+            return mock_runtime_client
+        return mock.Mock()
+
+    mock_create_client.side_effect = side_effect
+
+    chat_model = ChatBedrockConverse(
+        model="arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile",
+        region_name="us-west-2",
+        provider="stability"
+    )
+
+    # The streaming should be disabled for models with no streaming support
+    assert chat_model.disable_streaming is True
```
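Note that all of these tests patch `create_aws_client`, so no AWS credentials or network calls are needed: the mocked `get_inference_profile` response drives the same `_get_base_model` and `_configure_streaming_for_resolved_model` path exercised at runtime. To run just the new tests, a `pytest -k` selector such as `"inference_profile or bedrock_client or configure_streaming"` against libs/aws/tests/unit_tests/chat_models/test_bedrock_converse.py should pick them up.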
