Skip to content

Commit 5a68e01

Browse files
committed
Wrapping AWS Bedrock ClientError with ModelHTTPError
1 parent 6a153c8 commit 5a68e01

File tree

2 files changed

+101
-4
lines changed

2 files changed

+101
-4
lines changed

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import anyio
1313
import anyio.to_thread
14+
from botocore.exceptions import ClientError
1415
from typing_extensions import ParamSpec, assert_never
1516

1617
from pydantic_ai import (
@@ -21,6 +22,7 @@
2122
DocumentUrl,
2223
FinishReason,
2324
ImageUrl,
25+
ModelHTTPError,
2426
ModelMessage,
2527
ModelProfileSpec,
2628
ModelRequest,
@@ -408,10 +410,18 @@ async def _messages_create(
408410
if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
409411
params['promptVariables'] = prompt_variables
410412

411-
if stream:
412-
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
413-
else:
414-
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
413+
try:
414+
if stream:
415+
model_response = await anyio.to_thread.run_sync(
416+
functools.partial(self.client.converse_stream, **params)
417+
)
418+
else:
419+
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
420+
except ClientError as e:
421+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
422+
raise ModelHTTPError(
423+
status_code=status_code, model_name=self.model_name, body=e.response.get('Error')
424+
) from e
415425
return model_response
416426

417427
@staticmethod

tests/models/test_bedrock.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import datetime
44
from typing import Any
5+
from unittest.mock import Mock
56

67
import pytest
8+
from botocore.exceptions import ClientError
79
from inline_snapshot import snapshot
810
from typing_extensions import TypedDict
911

@@ -14,6 +16,7 @@
1416
FunctionToolCallEvent,
1517
FunctionToolResultEvent,
1618
ImageUrl,
19+
ModelHTTPError,
1720
ModelRequest,
1821
ModelResponse,
1922
PartDeltaEvent,
@@ -1376,3 +1379,87 @@ async def test_bedrock_model_stream_empty_text_delta(allow_model_requests: None,
13761379
PartEndEvent(index=1, part=TextPart(content='Hello! How can I help you today?')),
13771380
]
13781381
)
1382+
1383+
1384+
def test_bedrock_throttling_error(bedrock_provider: BedrockProvider):
1385+
"""Test that ThrottlingException converts to ModelHTTPError with 429 status."""
1386+
expected_error = {'Code': 'ThrottlingException', 'Message': 'Rate exceeded'}
1387+
expected_status = 429
1388+
model_id = 'us.amazon.nova.micro-v1:0'
1389+
1390+
# Mock the client to raise throttling exception
1391+
error_response = {'Error': expected_error, 'ResponseMetadata': {'HTTPStatusCode': expected_status}}
1392+
mock_client = Mock()
1393+
mock_client.converse.side_effect = ClientError(error_response, 'Converse') # type: ignore[reportArgumentType]
1394+
mock_client.meta.endpoint_url = 'https://bedrock-runtime.us-east-1.amazonaws.com'
1395+
1396+
# Create provider with mocked client
1397+
from pydantic_ai.providers.bedrock import BedrockProvider
1398+
1399+
provider = BedrockProvider(bedrock_client=mock_client)
1400+
model = BedrockConverseModel(model_id, provider=provider)
1401+
agent = Agent(model)
1402+
1403+
with pytest.raises(ModelHTTPError) as exc_info:
1404+
agent.run_sync('hello')
1405+
1406+
assert exc_info.value.status_code == expected_status
1407+
assert exc_info.value.model_name == model_id
1408+
assert exc_info.value.body == expected_error
1409+
1410+
1411+
def test_bedrock_server_error():
1412+
"""Test that 5xx errors convert to ModelHTTPError."""
1413+
expected_error = {'Code': 'InternalServerError', 'Message': 'Internal error'}
1414+
expected_status = 500
1415+
model_id = 'us.amazon.nova.micro-v1:0'
1416+
1417+
# Mock the client to raise throttling exception
1418+
error_response = {'Error': expected_error, 'ResponseMetadata': {'HTTPStatusCode': expected_status}}
1419+
mock_client = Mock()
1420+
mock_client.converse.side_effect = ClientError(error_response, 'Converse') # type: ignore[reportArgumentType]
1421+
1422+
mock_client.meta.endpoint_url = 'https://bedrock-runtime.us-east-1.amazonaws.com'
1423+
1424+
# Create provider with mocked client
1425+
from pydantic_ai.providers.bedrock import BedrockProvider
1426+
1427+
provider = BedrockProvider(bedrock_client=mock_client)
1428+
model = BedrockConverseModel(model_id, provider=provider)
1429+
agent = Agent(model)
1430+
1431+
with pytest.raises(ModelHTTPError) as exc_info:
1432+
agent.run_sync('hello')
1433+
1434+
assert exc_info.value.status_code == expected_status
1435+
assert exc_info.value.model_name == model_id
1436+
assert exc_info.value.body == expected_error
1437+
1438+
1439+
async def test_bedrock_streaming_error():
1440+
"""Test that errors during streaming convert to ModelHTTPError."""
1441+
expected_error = {'Code': 'ThrottlingException', 'Message': 'Rate exceeded'}
1442+
expected_status = 429
1443+
model_id = 'us.amazon.nova.micro-v1:0'
1444+
1445+
# Mock the client to raise throttling exception
1446+
error_response = {'Error': expected_error, 'ResponseMetadata': {'HTTPStatusCode': expected_status}}
1447+
mock_client = Mock()
1448+
mock_client.converse_stream.side_effect = ClientError(error_response, 'ConverseStream') # type: ignore[reportArgumentType]
1449+
1450+
mock_client.meta.endpoint_url = 'https://bedrock-runtime.us-east-1.amazonaws.com'
1451+
1452+
# Create provider with mocked client
1453+
from pydantic_ai.providers.bedrock import BedrockProvider
1454+
1455+
provider = BedrockProvider(bedrock_client=mock_client)
1456+
model = BedrockConverseModel(model_id, provider=provider)
1457+
agent = Agent(model)
1458+
1459+
with pytest.raises(ModelHTTPError) as exc_info:
1460+
async with agent.run_stream('hello'):
1461+
pass
1462+
1463+
assert exc_info.value.status_code == expected_status
1464+
assert exc_info.value.model_name == model_id
1465+
assert exc_info.value.body == expected_error

0 commit comments

Comments
 (0)