@@ -1,17 +1,21 @@
 from __future__ import annotations
 
+import json
 import sys
 from collections.abc import AsyncIterator
 from datetime import timezone
 from typing import Any
 
 import pytest
+from dirty_equals import IsJson
 from inline_snapshot import snapshot
+from pydantic_core import to_json
 
 from pydantic_ai import Agent, ModelHTTPError
 from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, UserPromptPart
 from pydantic_ai.models.fallback import FallbackModel
 from pydantic_ai.models.function import AgentInfo, FunctionModel
+from pydantic_ai.settings import ModelSettings
 from pydantic_ai.usage import Usage
 
 from ..conftest import IsNow, try_import
@@ -445,3 +449,67 @@ async def test_fallback_condition_tuple() -> None:
 
     response = await agent.run('hello')
     assert response.output == 'success'
+
+
+async def test_fallback_model_settings_merge():
+    """Test that FallbackModel properly merges model settings from the wrapped model, the agent, and run-time overrides."""
+
+    def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        return ModelResponse(parts=[TextPart(to_json(info.model_settings).decode())])
+
+    base_model = FunctionModel(return_settings, settings=ModelSettings(temperature=0.1, max_tokens=1024))
+    fallback_model = FallbackModel(base_model)
+
+    # Test that base model settings are preserved when no additional settings are provided
+    agent = Agent(fallback_model)
+    result = await agent.run('Hello')
+    assert result.output == IsJson({'max_tokens': 1024, 'temperature': 0.1})
+
+    # Test that agent-level model_settings are merged with base settings
+    agent_with_settings = Agent(fallback_model, model_settings=ModelSettings(temperature=0.5, parallel_tool_calls=True))
+    result = await agent_with_settings.run('Hello')
+    expected = {'max_tokens': 1024, 'temperature': 0.5, 'parallel_tool_calls': True}
+    assert result.output == IsJson(expected)
+
+    # Test that run-time model_settings override both base and agent settings
+    result = await agent_with_settings.run(
+        'Hello', model_settings=ModelSettings(temperature=0.9, extra_headers={'runtime_setting': 'runtime_value'})
+    )
+    expected = {
+        'max_tokens': 1024,
+        'temperature': 0.9,
+        'parallel_tool_calls': True,
+        'extra_headers': {
+            'runtime_setting': 'runtime_value',
+        },
+    }
+    assert result.output == IsJson(expected)
+
+
+async def test_fallback_model_settings_merge_streaming():
+    """Test that FallbackModel properly merges model settings in streaming mode."""
+
+    async def return_settings_stream(_: list[ModelMessage], info: AgentInfo):
+        # Yield the merged settings as JSON to verify they were properly combined
+        yield to_json(info.model_settings).decode()
+
+    base_model = FunctionModel(
+        stream_function=return_settings_stream,
+        settings=ModelSettings(temperature=0.1, extra_headers={'anthropic-beta': 'context-1m-2025-08-07'}),
+    )
+    fallback_model = FallbackModel(base_model)
+
+    # Test that base model settings are preserved in streaming mode
+    agent = Agent(fallback_model)
+    async with agent.run_stream('Hello') as result:
+        output = await result.get_output()
+
+    assert json.loads(output) == {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.1}
+
+    # Test that agent-level model_settings are merged with base settings in streaming mode
+    agent_with_settings = Agent(fallback_model, model_settings=ModelSettings(temperature=0.5))
+    async with agent_with_settings.run_stream('Hello') as result:
+        output = await result.get_output()
+
+    expected = {'extra_headers': {'anthropic-beta': 'context-1m-2025-08-07'}, 'temperature': 0.5}
+    assert json.loads(output) == expected
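
For context, the precedence these tests pin down is: the wrapped model's own settings form the base, agent-level model_settings are layered on top, and per-run model_settings win last. Below is a minimal sketch of that merge rule, assuming ModelSettings behaves as the plain TypedDict exported from pydantic_ai.settings so a shallow dict merge applies; merge_settings is an illustrative helper for this sketch, not the library's API.

from pydantic_ai.settings import ModelSettings


def merge_settings(base: ModelSettings | None, overrides: ModelSettings | None) -> ModelSettings | None:
    # Illustrative helper, not pydantic_ai's API: later settings win key-by-key.
    # ModelSettings is a TypedDict (a dict at runtime), so a shallow `|` merge
    # reproduces the precedence asserted above: base model < agent < per-run.
    if base and overrides:
        return base | overrides
    return base or overrides


# The three layers exercised by test_fallback_model_settings_merge:
base = ModelSettings(temperature=0.1, max_tokens=1024)  # wrapped model's settings
agent_level = ModelSettings(temperature=0.5, parallel_tool_calls=True)
run_time = ModelSettings(temperature=0.9)

merged = merge_settings(merge_settings(base, agent_level), run_time)
assert merged == {'max_tokens': 1024, 'temperature': 0.9, 'parallel_tool_calls': True}

Note the merge is shallow: in the run-time override test above, extra_headers arrives only from the per-run settings, and in the streaming test the base model's extra_headers survive untouched because no later layer sets that key.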