
Commit 2f2713b

sydney-runkle authored
Add FallbackModel support (#894)
Co-authored-by: Alex Hall <[email protected]>
Co-authored-by: David Montague <[email protected]>
Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 8c5cd10 commit 2f2713b

32 files changed: +942, -312 lines

docs/agents.md

Lines changed: 4 additions & 4 deletions
@@ -136,7 +136,7 @@ async def main():
 HandleResponseNode(
     model_response=ModelResponse(
         parts=[TextPart(content='Paris', part_kind='text')],
-        model_name='function:model_logic',
+        model_name='gpt-4o',
         timestamp=datetime.datetime(...),
         kind='response',
     )
@@ -197,7 +197,7 @@ async def main():
 HandleResponseNode(
     model_response=ModelResponse(
         parts=[TextPart(content='Paris', part_kind='text')],
-        model_name='function:model_logic',
+        model_name='gpt-4o',
         timestamp=datetime.datetime(...),
         kind='response',
     )
@@ -612,7 +612,7 @@ with capture_run_messages() as messages: # (2)!
             part_kind='tool-call',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -637,7 +637,7 @@ with capture_run_messages() as messages: # (2)!
             part_kind='tool-call',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),

docs/api/models/fallback.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# pydantic_ai.models.fallback
+
+::: pydantic_ai.models.fallback

docs/message-history.md

Lines changed: 6 additions & 6 deletions
@@ -62,7 +62,7 @@ print(result.all_messages())
             part_kind='text',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -136,7 +136,7 @@ async def main():
             part_kind='text',
         )
     ],
-    model_name='function:stream_model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -193,7 +193,7 @@ print(result2.all_messages())
             part_kind='text',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -214,7 +214,7 @@ print(result2.all_messages())
             part_kind='text',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -273,7 +273,7 @@ print(result2.all_messages())
             part_kind='text',
        )
     ],
-    model_name='function:model_logic',
+    model_name='gpt-4o',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -294,7 +294,7 @@ print(result2.all_messages())
             part_kind='text',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gemini-1.5-pro',
     timestamp=datetime.datetime(...),
     kind='response',
 ),

docs/models.md

Lines changed: 112 additions & 0 deletions
@@ -653,3 +653,115 @@ For streaming, you'll also need to implement the following abstract base class:
 The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py).
 
 For details on when we'll accept contributions adding new models to PydanticAI, see the [contributing guidelines](contributing.md#new-model-rules).
+
+
+## Fallback
+
+You can use [`FallbackModel`][pydantic_ai.models.fallback.FallbackModel] to attempt multiple models
+in sequence until one returns a successful result. Under the hood, PydanticAI automatically switches
+from one model to the next if the current model returns a 4xx or 5xx status code.
+
+In the following example, the agent first makes a request to the OpenAI model (which fails due to an invalid API key),
+and then falls back to the Anthropic model.
+
+```python {title="fallback_model.py"}
+from pydantic_ai import Agent
+from pydantic_ai.models.anthropic import AnthropicModel
+from pydantic_ai.models.fallback import FallbackModel
+from pydantic_ai.models.openai import OpenAIModel
+
+openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
+anthropic_model = AnthropicModel('claude-3-5-sonnet-latest')
+fallback_model = FallbackModel(openai_model, anthropic_model)
+
+agent = Agent(fallback_model)
+response = agent.run_sync('What is the capital of France?')
+print(response.data)
+#> Paris
+
+print(response.all_messages())
+"""
+[
+    ModelRequest(
+        parts=[
+            UserPromptPart(
+                content='What is the capital of France?',
+                timestamp=datetime.datetime(...),
+                part_kind='user-prompt',
+            )
+        ],
+        kind='request',
+    ),
+    ModelResponse(
+        parts=[TextPart(content='Paris', part_kind='text')],
+        model_name='claude-3-5-sonnet-latest',
+        timestamp=datetime.datetime(...),
+        kind='response',
+    ),
+]
+"""
+```
+
+The `ModelResponse` message above indicates in the `model_name` field that the result was returned by the Anthropic model, which is the second model specified in the `FallbackModel`.
+
+!!! note
+    Each model's options should be configured individually. For example, `base_url`, `api_key`, and custom clients should be set on each model itself, not on the `FallbackModel`.
+
+In this next example, we demonstrate the exception-handling capabilities of `FallbackModel`.
+If all models fail, a [`FallbackExceptionGroup`][pydantic_ai.exceptions.FallbackExceptionGroup] is raised, which
+contains all the exceptions encountered during the `run` execution.
+
+=== "Python >=3.11"
+
+    ```python {title="fallback_model_failure.py" py="3.11"}
+    from pydantic_ai import Agent
+    from pydantic_ai.exceptions import ModelHTTPError
+    from pydantic_ai.models.anthropic import AnthropicModel
+    from pydantic_ai.models.fallback import FallbackModel
+    from pydantic_ai.models.openai import OpenAIModel
+
+    openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
+    anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='not-valid')
+    fallback_model = FallbackModel(openai_model, anthropic_model)
+
+    agent = Agent(fallback_model)
+    try:
+        response = agent.run_sync('What is the capital of France?')
+    except* ModelHTTPError as exc_group:
+        for exc in exc_group.exceptions:
+            print(exc)
+    ```
+
+=== "Python <3.11"
+
+    Since [`except*`](https://docs.python.org/3/reference/compound_stmts.html#except-star) is only supported
+    in Python 3.11+, we use the [`exceptiongroup`](https://github.com/agronholm/exceptiongroup) backport
+    package for earlier Python versions:
+
+    ```python {title="fallback_model_failure.py" noqa="F821" test="skip"}
+    from exceptiongroup import catch
+
+    from pydantic_ai import Agent
+    from pydantic_ai.exceptions import ModelHTTPError
+    from pydantic_ai.models.anthropic import AnthropicModel
+    from pydantic_ai.models.fallback import FallbackModel
+    from pydantic_ai.models.openai import OpenAIModel
+
+
+    def model_status_error_handler(exc_group: BaseExceptionGroup) -> None:
+        for exc in exc_group.exceptions:
+            print(exc)
+
+
+    openai_model = OpenAIModel('gpt-4o', api_key='not-valid')
+    anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='not-valid')
+    fallback_model = FallbackModel(openai_model, anthropic_model)
+
+    agent = Agent(fallback_model)
+    with catch({ModelHTTPError: model_status_error_handler}):
+        response = agent.run_sync('What is the capital of France?')
+    ```
+
+By default, the `FallbackModel` only moves on to the next model if the current model raises a
+[`ModelHTTPError`][pydantic_ai.exceptions.ModelHTTPError]. You can customize this behavior by
+passing a custom `fallback_on` argument to the `FallbackModel` constructor.
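
The `fallback_on` hook mentioned in the closing paragraph of this docs addition is not exercised by the examples above, and the `pydantic_ai/models/fallback.py` source is not among the hunks shown here. Below is a minimal, hedged sketch of how it might be used, assuming `fallback_on` accepts a predicate over the raised exception (or a tuple of exception types); it also shows the per-model configuration of `api_key`/`base_url` recommended by the note in the docs. The key names and status-code choices are illustrative, not part of the commit.

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import ModelHTTPError
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIModel


def should_fall_back(exc: Exception) -> bool:
    # Hypothetical predicate: only move to the next model on rate limits and
    # server-side errors, not on client-side errors such as a malformed request.
    return isinstance(exc, ModelHTTPError) and exc.status_code in (429, 500, 502, 503)


# Each model carries its own settings; nothing is configured on the FallbackModel itself.
openai_model = OpenAIModel('gpt-4o', api_key='my-openai-key', base_url='https://api.openai.com/v1')
anthropic_model = AnthropicModel('claude-3-5-sonnet-latest', api_key='my-anthropic-key')

fallback_model = FallbackModel(
    openai_model,
    anthropic_model,
    fallback_on=should_fall_back,  # assumed signature; see the FallbackModel API docs
)
agent = Agent(fallback_model)
result = agent.run_sync('What is the capital of France?')
print(result.data)
```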

docs/tools.md

Lines changed: 3 additions & 3 deletions
@@ -89,7 +89,7 @@ print(dice_result.all_messages())
             tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call'
         )
     ],
-    model_name='function:model_logic',
+    model_name='gemini-1.5-flash',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -114,7 +114,7 @@ print(dice_result.all_messages())
             part_kind='tool-call',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gemini-1.5-flash',
     timestamp=datetime.datetime(...),
     kind='response',
 ),
@@ -137,7 +137,7 @@ print(dice_result.all_messages())
             part_kind='text',
         )
     ],
-    model_name='function:model_logic',
+    model_name='gemini-1.5-flash',
     timestamp=datetime.datetime(...),
     kind='response',
 ),

mkdocs.yml

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ nav:
     - api/models/mistral.md
     - api/models/test.md
    - api/models/function.md
+    - api/models/fallback.md
     - api/pydantic_graph/graph.md
     - api/pydantic_graph/nodes.md
     - api/pydantic_graph/state.md

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 11 additions & 1 deletion
@@ -1,7 +1,15 @@
 from importlib.metadata import version
 
 from .agent import Agent, EndStrategy, HandleResponseNode, ModelRequestNode, UserPromptNode, capture_run_messages
-from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError
+from .exceptions import (
+    AgentRunError,
+    FallbackExceptionGroup,
+    ModelHTTPError,
+    ModelRetry,
+    UnexpectedModelBehavior,
+    UsageLimitExceeded,
+    UserError,
+)
 from .messages import AudioUrl, BinaryContent, ImageUrl
 from .tools import RunContext, Tool
 
@@ -17,6 +25,8 @@
     # exceptions
     'AgentRunError',
     'ModelRetry',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
     'UnexpectedModelBehavior',
     'UsageLimitExceeded',
     'UserError',
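
As a quick illustration of what this re-export widening enables (my example, not part of the commit), the new exception types become importable from the package root rather than only from `pydantic_ai.exceptions`:

```python
from pydantic_ai import Agent, FallbackExceptionGroup, ModelHTTPError

# Equivalent to the longer form used in the docs examples:
# from pydantic_ai.exceptions import FallbackExceptionGroup, ModelHTTPError
```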

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 3 additions & 3 deletions
@@ -365,7 +365,7 @@ async def main():
 HandleResponseNode(
     model_response=ModelResponse(
         parts=[TextPart(content='Paris', part_kind='text')],
-        model_name='function:model_logic',
+        model_name='gpt-4o',
         timestamp=datetime.datetime(...),
         kind='response',
     )
@@ -1214,7 +1214,7 @@ async def main():
 HandleResponseNode(
     model_response=ModelResponse(
         parts=[TextPart(content='Paris', part_kind='text')],
-        model_name='function:model_logic',
+        model_name='gpt-4o',
         timestamp=datetime.datetime(...),
         kind='response',
     )
@@ -1357,7 +1357,7 @@ async def main():
 HandleResponseNode(
     model_response=ModelResponse(
         parts=[TextPart(content='Paris', part_kind='text')],
-        model_name='function:model_logic',
+        model_name='gpt-4o',
         timestamp=datetime.datetime(...),
         kind='response',
     )

pydantic_ai_slim/pydantic_ai/exceptions.py

Lines changed: 42 additions & 1 deletion
@@ -1,8 +1,22 @@
 from __future__ import annotations as _annotations
 
 import json
+import sys
 
-__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded'
+if sys.version_info < (3, 11):
+    from exceptiongroup import ExceptionGroup
+else:
+    ExceptionGroup = ExceptionGroup
+
+__all__ = (
+    'ModelRetry',
+    'UserError',
+    'AgentRunError',
+    'UnexpectedModelBehavior',
+    'UsageLimitExceeded',
+    'ModelHTTPError',
+    'FallbackExceptionGroup',
+)
 
 
 class ModelRetry(Exception):
@@ -72,3 +86,30 @@ def __str__(self) -> str:
             return f'{self.message}, body:\n{self.body}'
         else:
             return self.message
+
+
+class ModelHTTPError(AgentRunError):
+    """Raised when a model provider response has a status code of 4xx or 5xx."""
+
+    status_code: int
+    """The HTTP status code returned by the API."""
+
+    model_name: str
+    """The name of the model associated with the error."""
+
+    body: object | None
+    """The body of the response, if available."""
+
+    message: str
+    """The error message with the status code and response body, if available."""
+
+    def __init__(self, status_code: int, model_name: str, body: object | None = None):
+        self.status_code = status_code
+        self.model_name = model_name
+        self.body = body
+        message = f'status_code: {status_code}, model_name: {model_name}, body: {body}'
+        super().__init__(message)
+
+
+class FallbackExceptionGroup(ExceptionGroup):
+    """A group of exceptions that can be raised when all fallback models fail."""
