Skip to content

Commit 6d4eb06

Browse files
authored
Automatically use OpenAI strict mode for strict-compatible native output types (#2447)
1 parent 13b712f commit 6d4eb06

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -790,12 +790,18 @@ def get_user_agent() -> str:
790790
def _customize_tool_def(transformer: type[JsonSchemaTransformer], t: ToolDefinition):
791791
schema_transformer = transformer(t.parameters_json_schema, strict=t.strict)
792792
parameters_json_schema = schema_transformer.walk()
793-
if t.strict is None:
794-
t = replace(t, strict=schema_transformer.is_strict_compatible)
795-
return replace(t, parameters_json_schema=parameters_json_schema)
793+
return replace(
794+
t,
795+
parameters_json_schema=parameters_json_schema,
796+
strict=schema_transformer.is_strict_compatible if t.strict is None else t.strict,
797+
)
796798

797799

798800
def _customize_output_object(transformer: type[JsonSchemaTransformer], o: OutputObjectDefinition):
799-
schema_transformer = transformer(o.json_schema, strict=True)
800-
son_schema = schema_transformer.walk()
801-
return replace(o, json_schema=son_schema)
801+
schema_transformer = transformer(o.json_schema, strict=o.strict)
802+
json_schema = schema_transformer.walk()
803+
return replace(
804+
o,
805+
json_schema=json_schema,
806+
strict=schema_transformer.is_strict_compatible if o.strict is None else o.strict,
807+
)

tests/models/test_openai.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pytest
1313
from dirty_equals import IsListOrTuple
1414
from inline_snapshot import snapshot
15-
from pydantic import AnyUrl, BaseModel, Discriminator, Field, Tag
15+
from pydantic import AnyUrl, BaseModel, ConfigDict, Discriminator, Field, Tag
1616
from typing_extensions import NotRequired, TypedDict
1717

1818
from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
@@ -1777,6 +1777,44 @@ class MyModel(BaseModel):
17771777
)
17781778

17791779

1780+
def test_native_output_strict_mode(allow_model_requests: None):
1781+
class CityLocation(BaseModel):
1782+
city: str
1783+
country: str
1784+
1785+
c = completion_message(
1786+
ChatCompletionMessage(content='{"city": "Mexico City", "country": "Mexico"}', role='assistant'),
1787+
)
1788+
mock_client = MockOpenAI.create_mock(c)
1789+
model = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
1790+
1791+
# Explicit strict=True
1792+
agent = Agent(model, output_type=NativeOutput(CityLocation, strict=True))
1793+
1794+
agent.run_sync('What is the capital of Mexico?')
1795+
assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is True
1796+
1797+
# Explicit strict=False
1798+
agent = Agent(model, output_type=NativeOutput(CityLocation, strict=False))
1799+
1800+
agent.run_sync('What is the capital of Mexico?')
1801+
assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is False
1802+
1803+
# Strict-compatible
1804+
agent = Agent(model, output_type=NativeOutput(CityLocation))
1805+
1806+
agent.run_sync('What is the capital of Mexico?')
1807+
assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is True
1808+
1809+
# Strict-incompatible
1810+
CityLocation.model_config = ConfigDict(extra='allow')
1811+
1812+
agent = Agent(model, output_type=NativeOutput(CityLocation))
1813+
1814+
agent.run_sync('What is the capital of Mexico?')
1815+
assert get_mock_chat_completion_kwargs(mock_client)[-1]['response_format']['json_schema']['strict'] is False
1816+
1817+
17801818
async def test_openai_instructions(allow_model_requests: None, openai_api_key: str):
17811819
m = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key))
17821820
agent = Agent(m, instructions='You are a helpful assistant.')

0 commit comments

Comments
 (0)