Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import anyio
import anyio.to_thread
from botocore.exceptions import ClientError
from typing_extensions import ParamSpec, assert_never

from pydantic_ai import (
Expand All @@ -21,6 +22,7 @@
DocumentUrl,
FinishReason,
ImageUrl,
ModelHTTPError,
ModelMessage,
ModelProfileSpec,
ModelRequest,
Expand Down Expand Up @@ -408,10 +410,16 @@ async def _messages_create(
if prompt_variables := model_settings.get('bedrock_prompt_variables', None):
params['promptVariables'] = prompt_variables

if stream:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse_stream, **params))
else:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
try:
if stream:
model_response = await anyio.to_thread.run_sync(
functools.partial(self.client.converse_stream, **params)
)
else:
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
except ClientError as e:
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
return model_response

@staticmethod
Expand Down
31 changes: 31 additions & 0 deletions tests/models/cassettes/test_bedrock/test_bedrock_error.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": [{"text": "hello"}]}], "system": [], "inferenceConfig": {}}'
headers:
amz-sdk-invocation-id:
- !!binary |
ZWIzNTA0MWYtOTNhZi00YTFmLTk3YjEtMzE0MTFiNjA4ZjU5
amz-sdk-request:
- !!binary |
YXR0ZW1wdD0x
content-length:
- '101'
content-type:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
method: POST
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.does-not-exist-model-v1%3A0/converse
response:
headers:
connection:
- keep-alive
content-length:
- '55'
content-type:
- application/json
parsed_body:
message: The provided model identifier is invalid.
status:
code: 400
message: Bad Request
version: 1
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
interactions:
- request:
body: '{"messages": [{"role": "user", "content": [{"text": "hello"}]}], "system": [], "inferenceConfig": {}}'
headers:
amz-sdk-invocation-id:
- !!binary |
ZGQ5YWNhZjAtNjM4Mi00NjI5LTkwMWMtOGY4MWY1Yjc5OGYz
amz-sdk-request:
- !!binary |
YXR0ZW1wdD0x
content-length:
- '101'
content-type:
- !!binary |
YXBwbGljYXRpb24vanNvbg==
method: POST
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.does-not-exist-model-v1%3A0/converse-stream
response:
headers:
connection:
- keep-alive
content-length:
- '55'
content-type:
- application/json
parsed_body:
message: The provided model identifier is invalid.
status:
code: 400
message: Bad Request
version: 1
32 changes: 32 additions & 0 deletions tests/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
FunctionToolCallEvent,
FunctionToolResultEvent,
ImageUrl,
ModelHTTPError,
ModelRequest,
ModelResponse,
PartDeltaEvent,
Expand Down Expand Up @@ -1376,3 +1377,34 @@ async def test_bedrock_model_stream_empty_text_delta(allow_model_requests: None,
PartEndEvent(index=1, part=TextPart(content='Hello! How can I help you today?')),
]
)


@pytest.mark.vcr()
async def test_bedrock_error(allow_model_requests: None, bedrock_provider: BedrockProvider):
"""Test that errors convert to ModelHTTPError."""
model_id = 'us.does-not-exist-model-v1:0'
model = BedrockConverseModel(model_id, provider=bedrock_provider)
agent = Agent(model)

with pytest.raises(ModelHTTPError) as exc_info:
await agent.run('hello')

assert exc_info.value.status_code == 400
assert exc_info.value.model_name == model_id
assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr]


@pytest.mark.vcr()
async def test_bedrock_streaming_error(allow_model_requests: None, bedrock_provider: BedrockProvider):
"""Test that errors during streaming convert to ModelHTTPError."""
model_id = 'us.does-not-exist-model-v1:0'
model = BedrockConverseModel(model_id, provider=bedrock_provider)
agent = Agent(model)

with pytest.raises(ModelHTTPError) as exc_info:
async with agent.run_stream('hello'):
pass

assert exc_info.value.status_code == 400
assert exc_info.value.model_name == model_id
assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr]