Skip to content

Commit d78b77e

Browse files
authored
Let more BaseModels use OpenAI strict JSON mode by defaulting to additionalProperties=False (#2419)
1 parent f36aaa6 commit d78b77e

File tree

6 files changed

+198
-20
lines changed

6 files changed

+198
-20
lines changed

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,6 @@ def _build_schema(
285285
td_schema = core_schema.typed_dict_schema(
286286
fields,
287287
config=core_config,
288-
total=var_kwargs_schema is None,
289288
extras_schema=gen_schema.generate_schema(var_kwargs_schema) if var_kwargs_schema else None,
290289
)
291290
return td_schema, None

pydantic_ai_slim/pydantic_ai/profiles/openai.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,13 @@ def transform(self, schema: JsonSchema) -> JsonSchema: # noqa C901
166166
schema['required'] = list(schema['properties'].keys())
167167

168168
elif self.strict is None:
169-
if (
170-
schema.get('additionalProperties') is not False
171-
or 'properties' not in schema
172-
or 'required' not in schema
173-
):
169+
if schema.get('additionalProperties', None) not in (None, False):
170+
self.is_strict_compatible = False
171+
else:
172+
# additional properties are disallowed by default
173+
schema['additionalProperties'] = False
174+
175+
if 'properties' not in schema or 'required' not in schema:
174176
self.is_strict_compatible = False
175177
else:
176178
required = schema['required']

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,11 +133,19 @@ async def turn_on_strict_if_openai(
133133

134134
class GenerateToolJsonSchema(GenerateJsonSchema):
135135
def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaValue:
136-
s = super().typed_dict_schema(schema)
137-
total = schema.get('total')
138-
if 'additionalProperties' not in s and (total is True or total is None):
139-
s['additionalProperties'] = False
140-
return s
136+
json_schema = super().typed_dict_schema(schema)
137+
# Workaround for https://github.com/pydantic/pydantic/issues/12123
138+
if 'additionalProperties' not in json_schema: # pragma: no branch
139+
extra = schema.get('extra_behavior') or schema.get('config', {}).get('extra_fields_behavior')
140+
if extra == 'allow':
141+
extras_schema = schema.get('extras_schema', None)
142+
if extras_schema is not None:
143+
json_schema['additionalProperties'] = self.generate_inner(extras_schema) or True
144+
else:
145+
json_schema['additionalProperties'] = True # pragma: no cover
146+
elif extra == 'forbid':
147+
json_schema['additionalProperties'] = False
148+
return json_schema
141149

142150
def _named_required_fields_schema(self, named_required_fields: Sequence[tuple[str, bool, Any]]) -> JsonSchemaValue:
143151
# Remove largely-useless property titles

tests/models/test_openai.py

Lines changed: 176 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dirty_equals import IsListOrTuple
1414
from inline_snapshot import snapshot
1515
from pydantic import AnyUrl, BaseModel, Discriminator, Field, Tag
16-
from typing_extensions import TypedDict
16+
from typing_extensions import NotRequired, TypedDict
1717

1818
from pydantic_ai import Agent, ModelHTTPError, ModelRetry, UnexpectedModelBehavior
1919
from pydantic_ai.messages import (
@@ -1082,7 +1082,37 @@ class MyDefaultRecursiveDc:
10821082
field: MyDefaultRecursiveDc | None = None
10831083

10841084

1085-
class MyModel(BaseModel, extra='allow'):
1085+
class MyModel(BaseModel):
1086+
foo: str
1087+
1088+
1089+
class MyDc(BaseModel):
1090+
foo: str
1091+
1092+
1093+
class MyOptionalDc(BaseModel):
1094+
foo: str | None
1095+
bar: str
1096+
1097+
1098+
class MyExtrasDc(BaseModel, extra='allow'):
1099+
foo: str
1100+
1101+
1102+
class MyNormalTypedDict(TypedDict):
1103+
foo: str
1104+
1105+
1106+
class MyOptionalTypedDict(TypedDict):
1107+
foo: NotRequired[str]
1108+
bar: str
1109+
1110+
1111+
class MyPartialTypedDict(TypedDict, total=False):
1112+
foo: str
1113+
1114+
1115+
class MyExtrasModel(BaseModel, extra='allow'):
10861116
pass
10871117

10881118

@@ -1106,14 +1136,46 @@ def tool_with_recursion(x: MyRecursiveDc, y: MyDefaultRecursiveDc):
11061136
return f'{x} {y}' # pragma: no cover
11071137

11081138

1109-
def tool_with_additional_properties(x: MyModel) -> str:
1139+
def tool_with_model(x: MyModel) -> str:
1140+
return f'{x}' # pragma: no cover
1141+
1142+
1143+
def tool_with_dataclass(x: MyDc) -> str:
1144+
return f'{x}' # pragma: no cover
1145+
1146+
1147+
def tool_with_optional_dataclass(x: MyOptionalDc) -> str:
1148+
return f'{x}' # pragma: no cover
1149+
1150+
1151+
def tool_with_dataclass_with_extras(x: MyExtrasDc) -> str:
1152+
return f'{x}' # pragma: no cover
1153+
1154+
1155+
def tool_with_typed_dict(x: MyNormalTypedDict) -> str:
1156+
return f'{x}' # pragma: no cover
1157+
1158+
1159+
def tool_with_optional_typed_dict(x: MyOptionalTypedDict) -> str:
1160+
return f'{x}' # pragma: no cover
1161+
1162+
1163+
def tool_with_partial_typed_dict(x: MyPartialTypedDict) -> str:
1164+
return f'{x}' # pragma: no cover
1165+
1166+
1167+
def tool_with_model_with_extras(x: MyExtrasModel) -> str:
11101168
return f'{x}' # pragma: no cover
11111169

11121170

11131171
def tool_with_kwargs(x: int, **kwargs: Any) -> str:
11141172
return f'{x} {kwargs}' # pragma: no cover
11151173

11161174

1175+
def tool_with_typed_kwargs(x: int, **kwargs: int) -> str:
1176+
return f'{x} {kwargs}' # pragma: no cover
1177+
1178+
11171179
def tool_with_union(x: int | MyDefaultDc) -> str:
11181180
return f'{x}' # pragma: no cover
11191181

@@ -1216,6 +1278,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
12161278
}
12171279
},
12181280
'type': 'object',
1281+
'additionalProperties': False,
12191282
},
12201283
'MyEnum': {'enum': ['a', 'b'], 'type': 'string'},
12211284
'MyRecursiveDc': {
@@ -1225,6 +1288,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
12251288
},
12261289
'required': ['field', 'my_enum'],
12271290
'type': 'object',
1291+
'additionalProperties': False,
12281292
},
12291293
},
12301294
'additionalProperties': False,
@@ -1275,7 +1339,97 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
12751339
snapshot(True),
12761340
),
12771341
(
1278-
tool_with_additional_properties,
1342+
tool_with_model,
1343+
None,
1344+
snapshot(
1345+
{
1346+
'additionalProperties': False,
1347+
'properties': {'foo': {'type': 'string'}},
1348+
'required': ['foo'],
1349+
'type': 'object',
1350+
}
1351+
),
1352+
snapshot(True),
1353+
),
1354+
(
1355+
tool_with_dataclass,
1356+
None,
1357+
snapshot(
1358+
{
1359+
'additionalProperties': False,
1360+
'properties': {'foo': {'type': 'string'}},
1361+
'required': ['foo'],
1362+
'type': 'object',
1363+
}
1364+
),
1365+
snapshot(True),
1366+
),
1367+
(
1368+
tool_with_optional_dataclass,
1369+
None,
1370+
snapshot(
1371+
{
1372+
'additionalProperties': False,
1373+
'properties': {'foo': {'anyOf': [{'type': 'string'}, {'type': 'null'}]}, 'bar': {'type': 'string'}},
1374+
'required': ['foo', 'bar'],
1375+
'type': 'object',
1376+
}
1377+
),
1378+
snapshot(True),
1379+
),
1380+
(
1381+
tool_with_dataclass_with_extras,
1382+
None,
1383+
snapshot(
1384+
{
1385+
'additionalProperties': True,
1386+
'properties': {'foo': {'type': 'string'}},
1387+
'required': ['foo'],
1388+
'type': 'object',
1389+
}
1390+
),
1391+
snapshot(None),
1392+
),
1393+
(
1394+
tool_with_typed_dict,
1395+
None,
1396+
snapshot(
1397+
{
1398+
'additionalProperties': False,
1399+
'properties': {'foo': {'type': 'string'}},
1400+
'required': ['foo'],
1401+
'type': 'object',
1402+
}
1403+
),
1404+
snapshot(True),
1405+
),
1406+
(
1407+
tool_with_optional_typed_dict,
1408+
None,
1409+
snapshot(
1410+
{
1411+
'additionalProperties': False,
1412+
'properties': {'foo': {'type': 'string'}, 'bar': {'type': 'string'}},
1413+
'required': ['bar'],
1414+
'type': 'object',
1415+
}
1416+
),
1417+
snapshot(None),
1418+
),
1419+
(
1420+
tool_with_partial_typed_dict,
1421+
None,
1422+
snapshot(
1423+
{
1424+
'additionalProperties': False,
1425+
'properties': {'foo': {'type': 'string'}},
1426+
'type': 'object',
1427+
}
1428+
),
1429+
snapshot(None),
1430+
),
1431+
(
1432+
tool_with_model_with_extras,
12791433
None,
12801434
snapshot(
12811435
{
@@ -1287,7 +1441,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
12871441
snapshot(None),
12881442
),
12891443
(
1290-
tool_with_additional_properties,
1444+
tool_with_model_with_extras,
12911445
True,
12921446
snapshot(
12931447
{
@@ -1304,6 +1458,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
13041458
None,
13051459
snapshot(
13061460
{
1461+
'additionalProperties': True,
13071462
'properties': {'x': {'type': 'integer'}},
13081463
'required': ['x'],
13091464
'type': 'object',
@@ -1324,6 +1479,19 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
13241479
),
13251480
snapshot(True),
13261481
),
1482+
(
1483+
tool_with_typed_kwargs,
1484+
None,
1485+
snapshot(
1486+
{
1487+
'additionalProperties': {'type': 'integer'},
1488+
'properties': {'x': {'type': 'integer'}},
1489+
'required': ['x'],
1490+
'type': 'object',
1491+
}
1492+
),
1493+
snapshot(None),
1494+
),
13271495
(
13281496
tool_with_union,
13291497
None,
@@ -1333,6 +1501,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
13331501
'MyDefaultDc': {
13341502
'properties': {'x': {'default': 1, 'type': 'integer'}},
13351503
'type': 'object',
1504+
'additionalProperties': False,
13361505
}
13371506
},
13381507
'additionalProperties': False,
@@ -1373,6 +1542,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
13731542
'MyDefaultDc': {
13741543
'properties': {'x': {'default': 1, 'type': 'integer'}},
13751544
'type': 'object',
1545+
'additionalProperties': False,
13761546
}
13771547
},
13781548
'additionalProperties': False,
@@ -1413,6 +1583,7 @@ def tool_with_tuples(x: tuple[int], y: tuple[str] = ('abc',)) -> str:
14131583
'MyDefaultDc': {
14141584
'properties': {'x': {'default': 1, 'type': 'integer'}},
14151585
'type': 'object',
1586+
'additionalProperties': False,
14161587
}
14171588
},
14181589
'additionalProperties': False,

tests/test_agent.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ def test_response_tuple():
387387
name='final_result',
388388
description='The final response which ends this conversation',
389389
parameters_json_schema={
390-
'additionalProperties': False,
391390
'properties': {
392391
'response': {
393392
'maxItems': 2,
@@ -637,7 +636,6 @@ class Bar(BaseModel):
637636
'type': 'object',
638637
},
639638
},
640-
'additionalProperties': False,
641639
'properties': {'response': {'anyOf': [{'$ref': '#/$defs/Foo'}, {'$ref': '#/$defs/Bar'}]}},
642640
'required': ['response'],
643641
'type': 'object',

tests/test_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def test_docstring_unknown():
387387
{
388388
'name': 'unknown_docstring',
389389
'description': 'Unknown style docstring.',
390-
'parameters_json_schema': {'properties': {}, 'type': 'object'},
390+
'parameters_json_schema': {'additionalProperties': {'type': 'integer'}, 'properties': {}, 'type': 'object'},
391391
'outer_typed_dict_key': None,
392392
'strict': None,
393393
'kind': 'function',
@@ -1031,6 +1031,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] =
10311031
'name': 'my_tool_1',
10321032
'outer_typed_dict_key': None,
10331033
'parameters_json_schema': {
1034+
'additionalProperties': True,
10341035
'properties': {'x': {'default': None, 'type': 'string'}},
10351036
'type': 'object',
10361037
},
@@ -1071,7 +1072,6 @@ def get_score(data: Data) -> int: ... # pragma: no branch
10711072
'name': 'get_score',
10721073
'description': None,
10731074
'parameters_json_schema': {
1074-
'additionalProperties': False,
10751075
'properties': {
10761076
'a': {'description': 'The first parameter', 'type': 'integer'},
10771077
'b': {'description': 'The second parameter', 'type': 'integer'},

0 commit comments

Comments
 (0)