|
1 | 1 | from __future__ import annotations as _annotations |
2 | 2 |
|
3 | 3 | import datetime |
| 4 | +from types import SimpleNamespace |
4 | 5 | from typing import Any |
5 | 6 |
|
6 | 7 | import pytest |
| 8 | +from botocore.exceptions import ClientError |
7 | 9 | from inline_snapshot import snapshot |
8 | 10 | from typing_extensions import TypedDict |
9 | 11 |
|
|
33 | 35 | VideoUrl, |
34 | 36 | ) |
35 | 37 | from pydantic_ai.agent import Agent |
36 | | -from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UsageLimitExceeded |
| 38 | +from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UsageLimitExceeded |
37 | 39 | from pydantic_ai.messages import AgentStreamEvent |
38 | 40 | from pydantic_ai.models import ModelRequestParameters |
39 | 41 | from pydantic_ai.models.bedrock import BedrockConverseModel, BedrockModelSettings |
40 | 42 | from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings |
| 43 | +from pydantic_ai.profiles import DEFAULT_PROFILE |
| 44 | +from pydantic_ai.providers import Provider |
41 | 45 | from pydantic_ai.providers.bedrock import BedrockProvider |
42 | 46 | from pydantic_ai.providers.openai import OpenAIProvider |
43 | 47 | from pydantic_ai.run import AgentRunResult, AgentRunResultEvent |
|
56 | 60 | ] |
57 | 61 |
|
58 | 62 |
|
| 63 | +class _StubBedrockClient: |
| 64 | + """Minimal Bedrock client that always raises the provided error.""" |
| 65 | + |
| 66 | + def __init__(self, error: ClientError): |
| 67 | + self._error = error |
| 68 | + self.meta = SimpleNamespace(endpoint_url='https://bedrock.stub') |
| 69 | + |
| 70 | + def converse(self, **_: Any) -> None: |
| 71 | + raise self._error |
| 72 | + |
| 73 | + def converse_stream(self, **_: Any) -> None: |
| 74 | + raise self._error |
| 75 | + |
| 76 | + def count_tokens(self, **_: Any) -> None: |
| 77 | + raise self._error |
| 78 | + |
| 79 | + |
| 80 | +class _StubBedrockProvider(Provider[Any]): |
| 81 | + """Provider implementation backed by the stub client.""" |
| 82 | + |
| 83 | + def __init__(self, client: _StubBedrockClient): |
| 84 | + self._client = client |
| 85 | + |
| 86 | + @property |
| 87 | + def name(self) -> str: |
| 88 | + return 'bedrock-stub' |
| 89 | + |
| 90 | + @property |
| 91 | + def base_url(self) -> str: |
| 92 | + return 'https://bedrock.stub' |
| 93 | + |
| 94 | + @property |
| 95 | + def client(self) -> _StubBedrockClient: |
| 96 | + return self._client |
| 97 | + |
| 98 | + def model_profile(self, model_name: str): |
| 99 | + return DEFAULT_PROFILE |
| 100 | + |
| 101 | + |
| 102 | +def _bedrock_model_with_client_error(error: ClientError) -> BedrockConverseModel: |
| 103 | + """Instantiate a BedrockConverseModel wired to always raise the given error.""" |
| 104 | + return BedrockConverseModel( |
| 105 | + 'us.amazon.nova-micro-v1:0', |
| 106 | + provider=_StubBedrockProvider(_StubBedrockClient(error)), |
| 107 | + ) |
| 108 | + |
| 109 | + |
59 | 110 | async def test_bedrock_model(allow_model_requests: None, bedrock_provider: BedrockProvider): |
60 | 111 | model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) |
61 | 112 | assert model.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com' |
@@ -153,6 +204,28 @@ async def test_bedrock_count_tokens_error(allow_model_requests: None, bedrock_pr |
153 | 204 | assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr] |
154 | 205 |
|
155 | 206 |
|
| 207 | +async def test_bedrock_request_non_http_error(): |
| 208 | + error = ClientError({'Error': {'Code': 'TestException', 'Message': 'broken connection'}}, 'converse') |
| 209 | + model = _bedrock_model_with_client_error(error) |
| 210 | + params = ModelRequestParameters() |
| 211 | + |
| 212 | + with pytest.raises(ModelAPIError) as exc_info: |
| 213 | + await model.request([ModelRequest.user_text_prompt('hi')], None, params) |
| 214 | + |
| 215 | + assert exc_info.value.body == error.response |
| 216 | + |
| 217 | + |
| 218 | +async def test_bedrock_count_tokens_non_http_error(): |
| 219 | + error = ClientError({'Error': {'Code': 'TestException', 'Message': 'broken connection'}}, 'count_tokens') |
| 220 | + model = _bedrock_model_with_client_error(error) |
| 221 | + params = ModelRequestParameters() |
| 222 | + |
| 223 | + with pytest.raises(ModelAPIError) as exc_info: |
| 224 | + await model.count_tokens([ModelRequest.user_text_prompt('hi')], None, params) |
| 225 | + |
| 226 | + assert exc_info.value.body == error.response |
| 227 | + |
| 228 | + |
156 | 229 | @pytest.mark.parametrize( |
157 | 230 | ('model_name', 'expected'), |
158 | 231 | [ |
|
0 commit comments