|
12 | 12 | import pytest
|
13 | 13 | from dirty_equals import IsListOrTuple
|
14 | 14 | from inline_snapshot import snapshot
|
15 |
| -from pydantic import AnyUrl, BaseModel, Discriminator, Field, Tag |
| 15 | +from pydantic import AnyUrl, BaseModel, ConfigDict, Discriminator, Field, Tag |
16 | 16 | from typing_extensions import NotRequired, TypedDict
|
17 | 17 |
|
18 | 18 | from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
|
@@ -1777,6 +1777,44 @@ class MyModel(BaseModel):
|
1777 | 1777 | )
|
1778 | 1778 |
|
1779 | 1779 |
|
| 1780 | +def test_native_output_strict_mode(allow_model_requests: None): |
| 1781 | + class CityLocation(BaseModel): |
| 1782 | + city: str |
| 1783 | + country: str |
| 1784 | + |
| 1785 | + c = completion_message( |
| 1786 | + ChatCompletionMessage(content='{"city": "Mexico City", "country": "Mexico"}', role='assistant'), |
| 1787 | + ) |
| 1788 | + mock_client = MockOpenAI.create_mock(c) |
| 1789 | + model = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client)) |
| 1790 | + |
| 1791 | + # Explicit strict=True |
| 1792 | + agent = Agent(model, output_type=NativeOutput(CityLocation, strict=True)) |
| 1793 | + |
| 1794 | + agent.run_sync('What is the capital of Mexico?') |
| 1795 | + assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is True |
| 1796 | + |
| 1797 | + # Explicit strict=False |
| 1798 | + agent = Agent(model, output_type=NativeOutput(CityLocation, strict=False)) |
| 1799 | + |
| 1800 | + agent.run_sync('What is the capital of Mexico?') |
| 1801 | + assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is False |
| 1802 | + |
| 1803 | + # Strict-compatible |
| 1804 | + agent = Agent(model, output_type=NativeOutput(CityLocation)) |
| 1805 | + |
| 1806 | + agent.run_sync('What is the capital of Mexico?') |
| 1807 | + assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is True |
| 1808 | + |
| 1809 | + # Strict-incompatible |
| 1810 | + CityLocation.model_config = ConfigDict(extra='allow') |
| 1811 | + |
| 1812 | + agent = Agent(model, output_type=NativeOutput(CityLocation)) |
| 1813 | + |
| 1814 | + agent.run_sync('What is the capital of Mexico?') |
| 1815 | + assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is False |
| 1816 | + |
| 1817 | + |
1780 | 1818 | async def test_openai_instructions(allow_model_requests: None, openai_api_key: str):
|
1781 | 1819 | m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
|
1782 | 1820 | agent = Agent(m, instructions='You are a helpful assistant.')
|
|
0 commit comments