Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import anyio
import anyio.to_thread
from botocore.exceptions import ClientError
from typing_extensions import ParamSpec, assert_never

from pydantic_ai import (
Expand All @@ -21,6 +22,7 @@
DocumentUrl,
FinishReason,
ImageUrl,
ModelHTTPError,
ModelMessage,
ModelProfileSpec,
ModelRequest,
Expand Down Expand Up @@ -408,10 +410,18 @@ async def _messages_create(
if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
params['promptVariables'] = prompt_variables

if stream:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
else:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
try:
if stream:
model_response = await anyio.to_thread.run_sync(
functools.partial(self.client.converse_stream, **params)
)
else:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
except ClientError as e:
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
raise ModelHTTPError(
status_code=status_code, model_name=self.model_name, body=e.response.get('Error')
) from e
return model_response

@staticmethod
Expand Down
87 changes: 87 additions & 0 deletions tests/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import datetime
from typing import Any
from unittest.mock import Mock

import pytest
from botocore.exceptions import ClientError
from inline_snapshot import snapshot
from typing_extensions import TypedDict

Expand All @@ -14,6 +16,7 @@
FunctionToolCallEvent,
FunctionToolResultEvent,
ImageUrl,
ModelHTTPError,
ModelRequest,
ModelResponse,
PartDeltaEvent,
Expand Down Expand Up @@ -1376,3 +1379,87 @@ async def test_bedrock_model_stream_empty_text_delta(allow_model_requests: None,
PartEndEvent(index=1, part=TextPart(content='Hello! How can I help you today?')),
]
)


def test_bedrock_throttling_error(bedrock_provider: BedrockProvider):
    """Test that ThrottlingException converts to ModelHTTPError with 429 status."""
    error_body = {'Code': 'ThrottlingException', 'Message': 'Rate exceeded'}
    status = 429
    model_id = 'us.amazon.nova.micro-v1:0'

    # Stub boto3 client whose converse call always raises a throttling ClientError.
    response = {'Error': error_body, 'ResponseMetadata': {'HTTPStatusCode': status}}
    stub_client = Mock()
    stub_client.converse.side_effect = ClientError(response, 'Converse')  # type: ignore[reportArgumentType]
    stub_client.meta.endpoint_url = 'https://bedrock-runtime.us-east-1.amazonaws.com'

    # Wire the stubbed client into a provider/model/agent stack.
    from pydantic_ai.providers.bedrock import BedrockProvider

    stub_provider = BedrockProvider(bedrock_client=stub_client)
    agent = Agent(BedrockConverseModel(model_id, provider=stub_provider))

    with pytest.raises(ModelHTTPError) as exc_info:
        agent.run_sync('hello')

    err = exc_info.value
    assert err.status_code == status
    assert err.model_name == model_id
    assert err.body == error_body


def test_bedrock_server_error():
    """Test that 5xx errors convert to ModelHTTPError."""
    expected_error = {'Code': 'InternalServerError', 'Message': 'Internal error'}
    expected_status = 500
    model_id = 'us.amazon.nova.micro-v1:0'

    # Mock the client to raise an internal server error (5xx), not a throttling error
    error_response = {'Error': expected_error, 'ResponseMetadata': {'HTTPStatusCode': expected_status}}
    mock_client = Mock()
    mock_client.converse.side_effect = ClientError(error_response, 'Converse')  # type: ignore[reportArgumentType]

    mock_client.meta.endpoint_url = 'https://bedrock-runtime.us-east-1.amazonaws.com'

    # Create provider with mocked client
    from pydantic_ai.providers.bedrock import BedrockProvider

    provider = BedrockProvider(bedrock_client=mock_client)
    model = BedrockConverseModel(model_id, provider=provider)
    agent = Agent(model)

    with pytest.raises(ModelHTTPError) as exc_info:
        agent.run_sync('hello')

    # The ClientError's metadata should be surfaced verbatim on ModelHTTPError
    assert exc_info.value.status_code == expected_status
    assert exc_info.value.model_name == model_id
    assert exc_info.value.body == expected_error


async def test_bedrock_streaming_error():
    """Test that errors during streaming convert to ModelHTTPError."""
    error_body = {'Code': 'ThrottlingException', 'Message': 'Rate exceeded'}
    status = 429
    model_id = 'us.amazon.nova.micro-v1:0'

    # Stub boto3 client whose streaming entry point raises a throttling ClientError.
    response = {'Error': error_body, 'ResponseMetadata': {'HTTPStatusCode': status}}
    stub_client = Mock()
    stub_client.converse_stream.side_effect = ClientError(response, 'ConverseStream')  # type: ignore[reportArgumentType]
    stub_client.meta.endpoint_url = 'https://bedrock-runtime.us-east-1.amazonaws.com'

    # Wire the stubbed client into a provider/model/agent stack.
    from pydantic_ai.providers.bedrock import BedrockProvider

    stub_provider = BedrockProvider(bedrock_client=stub_client)
    agent = Agent(BedrockConverseModel(model_id, provider=stub_provider))

    with pytest.raises(ModelHTTPError) as exc_info:
        async with agent.run_stream('hello'):
            pass

    err = exc_info.value
    assert err.status_code == status
    assert err.model_name == model_id
    assert err.body == error_body