Skip to content

Commit 396c83e

Browse files
committed
- implement new ModelAPIError superclass
- handle connection errors for bedrock and openai
1 parent a25612e commit 396c83e

File tree

10 files changed

+117
-28
lines changed

10 files changed

+117
-28
lines changed

docs/models/overview.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ contains all the exceptions encountered during the `run` execution.
180180
=== "Python >=3.11"
181181

182182
```python {title="fallback_model_failure.py" py="3.11"}
183-
from pydantic_ai import Agent, ModelHTTPError
183+
from pydantic_ai import Agent, ModelAPIError
184184
from pydantic_ai.models.anthropic import AnthropicModel
185185
from pydantic_ai.models.fallback import FallbackModel
186186
from pydantic_ai.models.openai import OpenAIChatModel
@@ -192,7 +192,7 @@ contains all the exceptions encountered during the `run` execution.
192192
agent = Agent(fallback_model)
193193
try:
194194
response = agent.run_sync('What is the capital of France?')
195-
except* ModelHTTPError as exc_group:
195+
except* ModelAPIError as exc_group:
196196
for exc in exc_group.exceptions:
197197
print(exc)
198198
```
@@ -206,7 +206,7 @@ contains all the exceptions encountered during the `run` execution.
206206
```python {title="fallback_model_failure.py" noqa="F821" test="skip"}
207207
from exceptiongroup import catch
208208

209-
from pydantic_ai import Agent, ModelHTTPError
209+
from pydantic_ai import Agent, ModelAPIError
210210
from pydantic_ai.models.anthropic import AnthropicModel
211211
from pydantic_ai.models.fallback import FallbackModel
212212
from pydantic_ai.models.openai import OpenAIChatModel
@@ -222,10 +222,11 @@ contains all the exceptions encountered during the `run` execution.
222222
fallback_model = FallbackModel(openai_model, anthropic_model)
223223

224224
agent = Agent(fallback_model)
225-
with catch({ModelHTTPError: model_status_error_handler}):
225+
with catch({ModelAPIError: model_status_error_handler}):
226226
response = agent.run_sync('What is the capital of France?')
227227
```
228228

229229
By default, the `FallbackModel` only moves on to the next model if the current model raises a
230+
[`ModelAPIError`][pydantic_ai.exceptions.ModelAPIError], which includes
230231
[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError]. You can customize this behavior by
231232
passing a custom `fallback_on` argument to the `FallbackModel` constructor.

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
CallDeferred,
2525
FallbackExceptionGroup,
2626
IncompleteToolCall,
27+
ModelAPIError,
2728
ModelHTTPError,
2829
ModelRetry,
2930
UnexpectedModelBehavior,
@@ -126,6 +127,7 @@
126127
'CallDeferred',
127128
'ApprovalRequired',
128129
'ModelRetry',
130+
'ModelAPIError',
129131
'ModelHTTPError',
130132
'FallbackExceptionGroup',
131133
'IncompleteToolCall',

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
'AgentRunError',
2323
'UnexpectedModelBehavior',
2424
'UsageLimitExceeded',
25+
'ModelAPIError',
2526
'ModelHTTPError',
2627
'IncompleteToolCall',
2728
'FallbackExceptionGroup',
@@ -151,27 +152,31 @@ def __str__(self) -> str:
151152
return self.message
152153

153154

154-
class ModelHTTPError(AgentRunError):
155-
"""Raised when a model provider response has a status code of 4xx or 5xx."""
156-
157-
status_code: int
158-
"""The HTTP status code returned by the API."""
155+
class ModelAPIError(AgentRunError):
156+
"""Raised when a model provider API request fails."""
159157

160158
model_name: str
161159
"""The name of the model associated with the error."""
162160

163161
body: object | None
164162
"""The body of the response, if available."""
165163

166-
message: str
167-
"""The error message with the status code and response body, if available."""
164+
def __init__(self, model_name: str, body: object | None = None, message: str | None = None):
165+
self.model_name = model_name
166+
self.body = body
167+
super().__init__(message or f'model_name: {model_name}, body: {body}')
168+
169+
170+
class ModelHTTPError(ModelAPIError):
171+
"""Raised when a model provider response has a status code of 4xx or 5xx."""
172+
173+
status_code: int
174+
"""The HTTP status code returned by the API."""
168175

169176
def __init__(self, status_code: int, model_name: str, body: object | None = None):
170177
self.status_code = status_code
171-
self.model_name = model_name
172-
self.body = body
173178
message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
174-
super().__init__(message)
179+
super().__init__(model_name=model_name, body=body, message=message)
175180

176181

177182
class FallbackExceptionGroup(ExceptionGroup[Any]):

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
usage,
4242
)
4343
from pydantic_ai._run_context import RunContext
44-
from pydantic_ai.exceptions import ModelHTTPError, UserError
44+
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
4545
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
4646
from pydantic_ai.providers import Provider, infer_provider
4747
from pydantic_ai.providers.bedrock import BedrockModelProfile
@@ -312,8 +312,10 @@ async def count_tokens(
312312
try:
313313
response = await anyio.to_thread.run_sync(functools.partial(self.client.count_tokens, **params))
314314
except ClientError as e:
315-
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
316-
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
315+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode')
316+
if isinstance(status_code, int):
317+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
318+
raise ModelAPIError(model_name=self.model_name, body=e.response) from e
317319
return usage.RequestUsage(input_tokens=response['inputTokens'])
318320

319321
@asynccontextmanager
@@ -459,8 +461,10 @@ async def _messages_create(
459461
else:
460462
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
461463
except ClientError as e:
462-
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
463-
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
464+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode')
465+
if isinstance(status_code, int):
466+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
467+
raise ModelAPIError(model_name=self.model_name, body=e.response) from e
464468
return model_response
465469

466470
@staticmethod

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic_ai._run_context import RunContext
1212
from pydantic_ai.models.instrumented import InstrumentedModel
1313

14-
from ..exceptions import FallbackExceptionGroup, ModelHTTPError
14+
from ..exceptions import FallbackExceptionGroup, ModelAPIError
1515
from ..profiles import ModelProfile
1616
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
1717

@@ -36,7 +36,7 @@ def __init__(
3636
self,
3737
default_model: Model | KnownModelName | str,
3838
*fallback_models: Model | KnownModelName | str,
39-
fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
39+
fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelAPIError,),
4040
):
4141
"""Initialize a fallback model instance.
4242

pydantic_ai_slim/pydantic_ai/models/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pydantic_core import to_json
1414
from typing_extensions import assert_never, deprecated
1515

16-
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
16+
from .. import ModelAPIError, ModelHTTPError, UnexpectedModelBehavior, _utils, usage
1717
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
1818
from .._run_context import RunContext
1919
from .._thinking_part import split_content_into_text_and_thinking
@@ -547,7 +547,7 @@ async def _completions_create(
547547
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
548548
raise # pragma: lax no cover
549549
except APIConnectionError as e:
550-
raise ModelHTTPError(status_code=0, model_name=self.model_name, body=str(e)) from e
550+
raise ModelAPIError(model_name=self.model_name, body=str(e)) from e
551551

552552
def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
553553
"""Process a non-streamed response, and prepare a message to return."""
@@ -1255,7 +1255,7 @@ async def _responses_create(
12551255
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
12561256
raise # pragma: lax no cover
12571257
except APIConnectionError as e:
1258-
raise ModelHTTPError(status_code=0, model_name=self.model_name, body=str(e)) from e
1258+
raise ModelAPIError(model_name=self.model_name, body=str(e)) from e
12591259

12601260
def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | Omit:
12611261
reasoning_effort = model_settings.get('openai_reasoning_effort', None)

tests/models/test_bedrock.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations as _annotations
22

33
import datetime
4+
from types import SimpleNamespace
45
from typing import Any
56

67
import pytest
8+
from botocore.exceptions import ClientError
79
from inline_snapshot import snapshot
810
from typing_extensions import TypedDict
911

@@ -33,11 +35,13 @@
3335
VideoUrl,
3436
)
3537
from pydantic_ai.agent import Agent
36-
from pydantic_ai.exceptions import ModelHTTPError, ModelRetry, UsageLimitExceeded
38+
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UsageLimitExceeded
3739
from pydantic_ai.messages import AgentStreamEvent
3840
from pydantic_ai.models import ModelRequestParameters
3941
from pydantic_ai.models.bedrock import BedrockConverseModel, BedrockModelSettings
4042
from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings
43+
from pydantic_ai.profiles import DEFAULT_PROFILE
44+
from pydantic_ai.providers import Provider
4145
from pydantic_ai.providers.bedrock import BedrockProvider
4246
from pydantic_ai.providers.openai import OpenAIProvider
4347
from pydantic_ai.run import AgentRunResult, AgentRunResultEvent
@@ -56,6 +60,53 @@
5660
]
5761

5862

63+
class _StubBedrockClient:
64+
"""Minimal Bedrock client that always raises the provided error."""
65+
66+
def __init__(self, error: ClientError):
67+
self._error = error
68+
self.meta = SimpleNamespace(endpoint_url='https://bedrock.stub')
69+
70+
def converse(self, **_: Any) -> None:
71+
raise self._error
72+
73+
def converse_stream(self, **_: Any) -> None:
74+
raise self._error
75+
76+
def count_tokens(self, **_: Any) -> None:
77+
raise self._error
78+
79+
80+
class _StubBedrockProvider(Provider[Any]):
81+
"""Provider implementation backed by the stub client."""
82+
83+
def __init__(self, client: _StubBedrockClient):
84+
self._client = client
85+
86+
@property
87+
def name(self) -> str:
88+
return 'bedrock-stub'
89+
90+
@property
91+
def base_url(self) -> str:
92+
return 'https://bedrock.stub'
93+
94+
@property
95+
def client(self) -> _StubBedrockClient:
96+
return self._client
97+
98+
def model_profile(self, model_name: str):
99+
return DEFAULT_PROFILE
100+
101+
102+
def _bedrock_model_with_client_error(error: ClientError) -> BedrockConverseModel:
103+
"""Instantiate a BedrockConverseModel wired to always raise the given error."""
104+
return BedrockConverseModel(
105+
'us.amazon.nova-micro-v1:0',
106+
provider=_StubBedrockProvider(_StubBedrockClient(error)),
107+
)
108+
109+
59110
async def test_bedrock_model(allow_model_requests: None, bedrock_provider: BedrockProvider):
60111
model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
61112
assert model.base_url == 'https://bedrock-runtime.us-east-1.amazonaws.com'
@@ -153,6 +204,28 @@ async def test_bedrock_count_tokens_error(allow_model_requests: None, bedrock_pr
153204
assert exc_info.value.body.get('Error', {}).get('Message') == 'The provided model identifier is invalid.' # type: ignore[union-attr]
154205

155206

207+
async def test_bedrock_request_non_http_error():
208+
error = ClientError({'Error': {'Code': 'TestException', 'Message': 'broken connection'}}, 'converse')
209+
model = _bedrock_model_with_client_error(error)
210+
params = ModelRequestParameters()
211+
212+
with pytest.raises(ModelAPIError) as exc_info:
213+
await model.request([ModelRequest.user_text_prompt('hi')], None, params)
214+
215+
assert exc_info.value.body == error.response
216+
217+
218+
async def test_bedrock_count_tokens_non_http_error():
219+
error = ClientError({'Error': {'Code': 'TestException', 'Message': 'broken connection'}}, 'count_tokens')
220+
model = _bedrock_model_with_client_error(error)
221+
params = ModelRequestParameters()
222+
223+
with pytest.raises(ModelAPIError) as exc_info:
224+
await model.count_tokens([ModelRequest.user_text_prompt('hi')], None, params)
225+
226+
assert exc_info.value.body == error.response
227+
228+
156229
@pytest.mark.parametrize(
157230
('model_name', 'expected'),
158231
[

tests/models/test_fallback.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pydantic_ai import (
1717
Agent,
18+
ModelAPIError,
1819
ModelHTTPError,
1920
ModelMessage,
2021
ModelProfile,
@@ -572,7 +573,7 @@ async def test_fallback_condition_tuple() -> None:
572573

573574
async def test_fallback_connection_error() -> None:
574575
def connection_error_response(_model_messages: list[ModelMessage], _agent_info: AgentInfo) -> ModelResponse:
575-
raise ModelHTTPError(status_code=0, model_name='test-connection-model', body='Connection timed out')
576+
raise ModelAPIError(model_name='test-connection-model', body='Connection timed out')
576577

577578
connection_error_model = FunctionModel(connection_error_response)
578579
fallback_model = FallbackModel(connection_error_model, success_model)

tests/models/test_openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
CachePoint,
2121
DocumentUrl,
2222
ImageUrl,
23+
ModelAPIError,
2324
ModelHTTPError,
2425
ModelProfile,
2526
ModelRequest,
@@ -1155,9 +1156,8 @@ def test_model_connection_error(allow_model_requests: None) -> None:
11551156
)
11561157
m = OpenAIChatModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
11571158
agent = Agent(m)
1158-
with pytest.raises(ModelHTTPError) as exc_info:
1159+
with pytest.raises(ModelAPIError) as exc_info:
11591160
agent.run_sync('hello')
1160-
assert exc_info.value.status_code == 0
11611161
assert exc_info.value.model_name == 'gpt-4o'
11621162
assert 'Connection to http://localhost:11434/v1 timed out' in str(exc_info.value.body)
11631163

tests/test_exceptions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
ApprovalRequired,
1212
CallDeferred,
1313
IncompleteToolCall,
14+
ModelAPIError,
1415
ModelHTTPError,
1516
UnexpectedModelBehavior,
1617
UsageLimitExceeded,
@@ -28,6 +29,7 @@
2829
lambda: AgentRunError('test'),
2930
lambda: UnexpectedModelBehavior('test'),
3031
lambda: UsageLimitExceeded('test'),
32+
lambda: ModelAPIError('model'),
3133
lambda: ModelHTTPError(500, 'model'),
3234
lambda: IncompleteToolCall('test'),
3335
],
@@ -39,6 +41,7 @@
3941
'AgentRunError',
4042
'UnexpectedModelBehavior',
4143
'UsageLimitExceeded',
44+
'ModelAPIError',
4245
'ModelHTTPError',
4346
'IncompleteToolCall',
4447
],

0 commit comments

Comments
 (0)