Skip to content

Commit ab6de27

Browse files
dsfaccini and DouweM authored
Make FallbackModel fall back on all model API errors, not just HTTP status 400+ (#3494)
Co-authored-by: Douwe Maan <[email protected]>
1 parent 2685b2c commit ab6de27

File tree

20 files changed

+317
-42
lines changed

20 files changed

+317
-42
lines changed

docs/models/overview.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ contains all the exceptions encountered during the `run` execution.
180180
=== "Python >=3.11"
181181

182182
```python {title="fallback_model_failure.py" py="3.11"}
183-
from pydantic_ai import Agent, ModelHTTPError
183+
from pydantic_ai import Agent, ModelAPIError
184184
from pydantic_ai.models.anthropic import AnthropicModel
185185
from pydantic_ai.models.fallback import FallbackModel
186186
from pydantic_ai.models.openai import OpenAIChatModel
@@ -192,7 +192,7 @@ contains all the exceptions encountered during the `run` execution.
192192
agent = Agent(fallback_model)
193193
try:
194194
response = agent.run_sync('What is the capital of France?')
195-
except* ModelHTTPError as exc_group:
195+
except* ModelAPIError as exc_group:
196196
for exc in exc_group.exceptions:
197197
print(exc)
198198
```
@@ -206,7 +206,7 @@ contains all the exceptions encountered during the `run` execution.
206206
```python {title="fallback_model_failure.py" noqa="F821" test="skip"}
207207
from exceptiongroup import catch
208208

209-
from pydantic_ai import Agent, ModelHTTPError
209+
from pydantic_ai import Agent, ModelAPIError
210210
from pydantic_ai.models.anthropic import AnthropicModel
211211
from pydantic_ai.models.fallback import FallbackModel
212212
from pydantic_ai.models.openai import OpenAIChatModel
@@ -222,10 +222,11 @@ contains all the exceptions encountered during the `run` execution.
222222
fallback_model = FallbackModel(openai_model, anthropic_model)
223223

224224
agent = Agent(fallback_model)
225-
with catch({ModelHTTPError: model_status_error_handler}):
225+
with catch({ModelAPIError: model_status_error_handler}):
226226
response = agent.run_sync('What is the capital of France?')
227227
```
228228

229229
By default, the `FallbackModel` only moves on to the next model if the current model raises a
230+
[`ModelAPIError`][pydantic_ai.exceptions.ModelAPIError], which includes
230231
[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError]. You can customize this behavior by
231232
passing a custom `fallback_on` argument to the `FallbackModel` constructor.

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
CallDeferred,
2525
FallbackExceptionGroup,
2626
IncompleteToolCall,
27+
ModelAPIError,
2728
ModelHTTPError,
2829
ModelRetry,
2930
UnexpectedModelBehavior,
@@ -126,6 +127,7 @@
126127
'CallDeferred',
127128
'ApprovalRequired',
128129
'ModelRetry',
130+
'ModelAPIError',
129131
'ModelHTTPError',
130132
'FallbackExceptionGroup',
131133
'IncompleteToolCall',

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
'AgentRunError',
2323
'UnexpectedModelBehavior',
2424
'UsageLimitExceeded',
25+
'ModelAPIError',
2526
'ModelHTTPError',
2627
'IncompleteToolCall',
2728
'FallbackExceptionGroup',
@@ -151,27 +152,31 @@ def __str__(self) -> str:
151152
return self.message
152153

153154

154-
class ModelHTTPError(AgentRunError):
155+
class ModelAPIError(AgentRunError):
156+
"""Raised when a model provider API request fails."""
157+
158+
model_name: str
159+
"""The name of the model associated with the error."""
160+
161+
def __init__(self, model_name: str, message: str):
162+
self.model_name = model_name
163+
super().__init__(message)
164+
165+
166+
class ModelHTTPError(ModelAPIError):
155167
"""Raised when a model provider response has a status code of 4xx or 5xx."""
156168

157169
status_code: int
158170
"""The HTTP status code returned by the API."""
159171

160-
model_name: str
161-
"""The name of the model associated with the error."""
162-
163172
body: object | None
164173
"""The body of the response, if available."""
165174

166-
message: str
167-
"""The error message with the status code and response body, if available."""
168-
169175
def __init__(self, status_code: int, model_name: str, body: object | None = None):
170176
self.status_code = status_code
171-
self.model_name = model_name
172177
self.body = body
173178
message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
174-
super().__init__(message)
179+
super().__init__(model_name=model_name, message=message)
175180

176181

177182
class FallbackExceptionGroup(ExceptionGroup[Any]):

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .._run_context import RunContext
1515
from .._utils import guard_tool_call_id as _guard_tool_call_id
1616
from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebSearchTool
17-
from ..exceptions import UserError
17+
from ..exceptions import ModelAPIError, UserError
1818
from ..messages import (
1919
BinaryContent,
2020
BuiltinToolCallPart,
@@ -55,7 +55,14 @@
5555

5656

5757
try:
58-
from anthropic import NOT_GIVEN, APIStatusError, AsyncAnthropicBedrock, AsyncStream, omit as OMIT
58+
from anthropic import (
59+
NOT_GIVEN,
60+
APIConnectionError,
61+
APIStatusError,
62+
AsyncAnthropicBedrock,
63+
AsyncStream,
64+
omit as OMIT,
65+
)
5966
from anthropic.types.beta import (
6067
BetaBase64PDFBlockParam,
6168
BetaBase64PDFSourceParam,
@@ -358,7 +365,9 @@ async def _messages_create(
358365
except APIStatusError as e:
359366
if (status_code := e.status_code) >= 400:
360367
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
361-
raise # pragma: lax no cover
368+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: lax no cover
369+
except APIConnectionError as e:
370+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e
362371

363372
async def _messages_count_tokens(
364373
self,
@@ -395,7 +404,9 @@ async def _messages_count_tokens(
395404
except APIStatusError as e:
396405
if (status_code := e.status_code) >= 400:
397406
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
398-
raise # pragma: lax no cover
407+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: lax no cover
408+
except APIConnectionError as e:
409+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e
399410

400411
def _process_response(self, response: BetaMessage) -> ModelResponse:
401412
"""Process a non-streamed response, and prepare a message to return."""

pydantic_ai_slim/pydantic_ai/models/bedrock.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
usage,
4242
)
4343
from pydantic_ai._run_context import RunContext
44-
from pydantic_ai.exceptions import ModelHTTPError, UserError
44+
from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, UserError
4545
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse, download_item
4646
from pydantic_ai.providers import Provider, infer_provider
4747
from pydantic_ai.providers.bedrock import BedrockModelProfile
@@ -312,8 +312,10 @@ async def count_tokens(
312312
try:
313313
response = await anyio.to_thread.run_sync(functools.partial(self.client.count_tokens, **params))
314314
except ClientError as e:
315-
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
316-
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
315+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode')
316+
if isinstance(status_code, int):
317+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
318+
raise ModelAPIError(model_name=self.model_name, message=str(e)) from e
317319
return usage.RequestUsage(input_tokens=response['inputTokens'])
318320

319321
@asynccontextmanager
@@ -459,8 +461,10 @@ async def _messages_create(
459461
else:
460462
model_response = await anyio.to_thread.run_sync(functools.partial(self.client.converse, **params))
461463
except ClientError as e:
462-
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500)
463-
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
464+
status_code = e.response.get('ResponseMetadata', {}).get('HTTPStatusCode')
465+
if isinstance(status_code, int):
466+
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.response) from e
467+
raise ModelAPIError(model_name=self.model_name, message=str(e)) from e
464468
return model_response
465469

466470
@staticmethod

pydantic_ai_slim/pydantic_ai/models/cohere.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing_extensions import assert_never
88

9-
from pydantic_ai.exceptions import UserError
9+
from pydantic_ai.exceptions import ModelAPIError, UserError
1010

1111
from .. import ModelHTTPError, usage
1212
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
@@ -195,7 +195,7 @@ async def _chat(
195195
except ApiError as e:
196196
if (status_code := e.status_code) and status_code >= 400:
197197
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
198-
raise # pragma: lax no cover
198+
raise ModelAPIError(model_name=self.model_name, message=str(e)) from e
199199

200200
def _process_response(self, response: V2ChatResponse) -> ModelResponse:
201201
"""Process a non-streamed response, and prepare a message to return."""

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pydantic_ai._run_context import RunContext
1212
from pydantic_ai.models.instrumented import InstrumentedModel
1313

14-
from ..exceptions import FallbackExceptionGroup, ModelHTTPError
14+
from ..exceptions import FallbackExceptionGroup, ModelAPIError
1515
from ..profiles import ModelProfile
1616
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
1717

@@ -36,7 +36,7 @@ def __init__(
3636
self,
3737
default_model: Model | KnownModelName | str,
3838
*fallback_models: Model | KnownModelName | str,
39-
fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelHTTPError,),
39+
fallback_on: Callable[[Exception], bool] | tuple[type[Exception], ...] = (ModelAPIError,),
4040
):
4141
"""Initialize a fallback model instance.
4242

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .._output import OutputObjectDefinition
1515
from .._run_context import RunContext
1616
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, UrlContextTool, WebSearchTool
17-
from ..exceptions import ModelHTTPError, UserError
17+
from ..exceptions import ModelAPIError, ModelHTTPError, UserError
1818
from ..messages import (
1919
BinaryContent,
2020
BuiltinToolCallPart,
@@ -410,7 +410,7 @@ async def _generate_content(
410410
model_name=self._model_name,
411411
body=cast(Any, e.details), # pyright: ignore[reportUnknownMemberType]
412412
) from e
413-
raise # pragma: lax no cover
413+
raise ModelAPIError(model_name=self._model_name, message=str(e)) from e
414414

415415
async def _build_content_and_config(
416416
self,

pydantic_ai_slim/pydantic_ai/models/groq.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .._thinking_part import split_content_into_text_and_thinking
1717
from .._utils import generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id, number_to_datetime
1818
from ..builtin_tools import WebSearchTool
19-
from ..exceptions import UserError
19+
from ..exceptions import ModelAPIError, UserError
2020
from ..messages import (
2121
BinaryContent,
2222
BuiltinToolCallPart,
@@ -52,7 +52,7 @@
5252
)
5353

5454
try:
55-
from groq import NOT_GIVEN, APIError, APIStatusError, AsyncGroq, AsyncStream
55+
from groq import NOT_GIVEN, APIConnectionError, APIError, APIStatusError, AsyncGroq, AsyncStream
5656
from groq.types import chat
5757
from groq.types.chat.chat_completion_content_part_image_param import ImageURL
5858
from groq.types.chat.chat_completion_message import ExecutedTool
@@ -314,7 +314,9 @@ async def _completions_create(
314314
except APIStatusError as e:
315315
if (status_code := e.status_code) >= 400:
316316
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
317-
raise # pragma: lax no cover
317+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e # pragma: no cover
318+
except APIConnectionError as e:
319+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e
318320

319321
def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
320322
"""Process a non-streamed response, and prepare a message to return."""

pydantic_ai_slim/pydantic_ai/models/mistral.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
1414
from .._run_context import RunContext
1515
from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
16-
from ..exceptions import UserError
16+
from ..exceptions import ModelAPIError, UserError
1717
from ..messages import (
1818
BinaryContent,
1919
BuiltinToolCallPart,
@@ -246,7 +246,7 @@ async def _completions_create(
246246
except SDKError as e:
247247
if (status_code := e.status_code) >= 400:
248248
raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
249-
raise # pragma: lax no cover
249+
raise ModelAPIError(model_name=self.model_name, message=e.message) from e
250250

251251
assert response, 'A unexpected empty response from Mistral.'
252252
return response

0 commit comments

Comments
 (0)