Skip to content

Commit cef7a91

Browse files
authored
Fix FallbackModel to respect each model's model settings (#2540)
1 parent c8bcc6e commit cef7a91

File tree

2 files changed

+75
-2
lines changed

2 files changed

+75
-2
lines changed

pydantic_ai_slim/pydantic_ai/models/fallback.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from pydantic_ai.models.instrumented import InstrumentedModel
1212

1313
from ..exceptions import FallbackExceptionGroup, ModelHTTPError
14+
from ..settings import merge_model_settings
1415
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
1516

1617
if TYPE_CHECKING:
@@ -65,8 +66,9 @@ async def request(
6566

6667
for model in self.models:
6768
customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
69+
merged_settings = merge_model_settings(model.settings, model_settings)
6870
try:
69-
response = await model.request(messages, model_settings, customized_model_request_parameters)
71+
response = await model.request(messages, merged_settings, customized_model_request_parameters)
7072
except Exception as exc:
7173
if self._fallback_on(exc):
7274
exceptions.append(exc)
@@ -91,10 +93,13 @@ async def request_stream(
9193

9294
for model in self.models:
9395
customized_model_request_parameters = model.customize_request_parameters(model_request_parameters)
96+
merged_settings = merge_model_settings(model.settings, model_settings)
9497
async with AsyncExitStack() as stack:
9598
try:
9699
response = await stack.enter_async_context(
97-
model.request_stream(messages, model_settings, customized_model_request_parameters, run_context)
100+
model.request_stream(
101+
messages, merged_settings, customized_model_request_parameters, run_context
102+
)
98103
)
99104
except Exception as exc:
100105
if self._fallback_on(exc):

tests/models/test_fallback.py

Lines changed: 68 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,21 @@
11
from __future__ import annotations
22

3+
import json
34
import sys
45
from collections.abc import AsyncIterator
56
from datetime import timezone
67
from typing import Any
78

89
import pytest
10+
from dirty_equals import IsJson
911
from inline_snapshot import snapshot
12+
from pydantic_core import to_json
1013

1114
from pydantic_ai import Agent, ModelHTTPError
1215
from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart
1316
from pydantic_ai.models.fallback import FallbackModel
1417
from pydantic_ai.models.function import AgentInfo, FunctionModel
18+
from pydantic_ai.settings import ModelSettings
1519
from pydantic_ai.usage import Usage
1620

1721
from ..conftest import IsNow, try_import
@@ -445,3 +449,67 @@ async def test_fallback_condition_tuple() -> None:
445449

446450
response = await agent.run('hello')
447451
assert response.output == 'success'
452+
453+
454+
async def test_fallback_model_settings_merge():
455+
"""Test that FallbackModel properly merges model settings from wrapped model and runtime settings."""
456+
457+
def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
458+
return ModelResponse(parts=[TextPart(to_json(info.model_settings).decode())])
459+
460+
base_model = FunctionModel(return_settings, settings=ModelSettings(temperature=0.1, max_tokens=1024))
461+
fallback_model = FallbackModel(base_model)
462+
463+
# Test that base model settings are preserved when no additional settings are provided
464+
agent = Agent(fallback_model)
465+
result = await agent.run('Hello')
466+
assert result.output == IsJson({'max_tokens': 1024, 'temperature': 0.1})
467+
468+
# Test that runtime model_settings are merged with base settings
469+
agent_with_settings = Agent(fallback_model, model_settings=ModelSettings(temperature=0.5, parallel_tool_calls=True))
470+
result = await agent_with_settings.run('Hello')
471+
expected = {'max_tokens': 1024, 'temperature': 0.5, 'parallel_tool_calls': True}
472+
assert result.output == IsJson(expected)
473+
474+
# Test that run-time model_settings override both base and agent settings
475+
result = await agent_with_settings.run(
476+
'Hello', model_settings=ModelSettings(temperature=0.9, extra_headers={'runtime_setting': 'runtime_value'})
477+
)
478+
expected = {
479+
'max_tokens': 1024,
480+
'temperature': 0.9,
481+
'parallel_tool_calls': True,
482+
'extra_headers': {
483+
'runtime_setting': 'runtime_value',
484+
},
485+
}
486+
assert result.output == IsJson(expected)
487+
488+
489+
async def test_fallback_model_settings_merge_streaming():
490+
"""Test that FallbackModel properly merges model settings in streaming mode."""
491+
492+
async def return_settings_stream(_: list[ModelMessage], info: AgentInfo):
493+
# Yield the merged settings as JSON to verify they were properly combined
494+
yield to_json(info.model_settings).decode()
495+
496+
base_model = FunctionModel(
497+
stream_function=return_settings_stream,
498+
settings=ModelSettings(temperature=0.1, extra_headers={'anthropic-beta': 'context-1m-2025-08-07'}),
499+
)
500+
fallback_model = FallbackModel(base_model)
501+
502+
# Test that base model settings are preserved in streaming mode
503+
agent = Agent(fallback_model)
504+
async with agent.run_stream('Hello') as result:
505+
output = await result.get_output()
506+
507+
assert json.loads(output) == {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.1}
508+
509+
# Test that runtime model_settings are merged with base settings in streaming mode
510+
agent_with_settings = Agent(fallback_model, model_settings=ModelSettings(temperature=0.5))
511+
async with agent_with_settings.run_stream('Hello') as result:
512+
output = await result.get_output()
513+
514+
expected = {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.5}
515+
assert json.loads(output) == expected

0 commit comments

Comments (0)