Skip to content

Commit b22f95c

Browse files
mscherrmannMoritz ScherrmanndmontaguKludex
authored
Feature/add openai strict mode (#1304)
Co-authored-by: Moritz Scherrmann <[email protected]> Co-authored-by: David Montague <[email protected]> Co-authored-by: Marcelo Trylesinski <[email protected]>
1 parent 1a78d7e commit b22f95c

File tree

14 files changed

+817
-51
lines changed

14 files changed

+817
-51
lines changed

docs/tools.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ print(test_model.last_model_request_parameters.function_tools)
353353
'type': 'object',
354354
},
355355
outer_typed_dict_key=None,
356+
strict=None,
356357
)
357358
]
358359
"""
@@ -456,6 +457,7 @@ print(test_model.last_model_request_parameters.function_tools)
456457
'type': 'object',
457458
},
458459
outer_typed_dict_key=None,
460+
strict=None,
459461
)
460462
]
461463
"""

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ async def _make_request(
311311
return self._result
312312

313313
model_settings, model_request_parameters = await self._prepare_request(ctx)
314+
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)
314315
model_response, request_usage = await ctx.deps.model.request(
315316
ctx.state.message_history, model_settings, model_request_parameters
316317
)

pydantic_ai_slim/pydantic_ai/_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@ def check_object_json_schema(schema: JsonSchemaValue) -> ObjectJsonSchema:
5050
if schema.get('type') == 'object':
5151
return schema
5252
elif schema.get('$ref') is not None:
53-
return schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/".
53+
maybe_result = schema.get('$defs', {}).get(schema['$ref'][8:]) # This removes the initial "#/$defs/".
54+
55+
if "'$ref': '#/$defs/" in str(maybe_result):
56+
return schema # We can't remove the $defs because the schema contains other references
57+
return maybe_result
5458
else:
5559
raise UserError('Schema must be an object')
5660

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,7 @@ def tool(
940940
docstring_format: DocstringFormat = 'auto',
941941
require_parameter_descriptions: bool = False,
942942
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
943+
strict: bool | None = None,
943944
) -> Callable[[ToolFuncContext[AgentDepsT, ToolParams]], ToolFuncContext[AgentDepsT, ToolParams]]: ...
944945

945946
def tool(
@@ -953,6 +954,7 @@ def tool(
953954
docstring_format: DocstringFormat = 'auto',
954955
require_parameter_descriptions: bool = False,
955956
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
957+
strict: bool | None = None,
956958
) -> Any:
957959
"""Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument.
958960
@@ -995,6 +997,8 @@ async def spam(ctx: RunContext[str], y: float) -> float:
995997
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
996998
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
997999
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1000+
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1001+
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
9981002
"""
9991003
if func is None:
10001004

@@ -1011,14 +1015,23 @@ def tool_decorator(
10111015
docstring_format,
10121016
require_parameter_descriptions,
10131017
schema_generator,
1018+
strict,
10141019
)
10151020
return func_
10161021

10171022
return tool_decorator
10181023
else:
10191024
# noinspection PyTypeChecker
10201025
self._register_function(
1021-
func, True, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
1026+
func,
1027+
True,
1028+
name,
1029+
retries,
1030+
prepare,
1031+
docstring_format,
1032+
require_parameter_descriptions,
1033+
schema_generator,
1034+
strict,
10221035
)
10231036
return func
10241037

@@ -1036,6 +1049,7 @@ def tool_plain(
10361049
docstring_format: DocstringFormat = 'auto',
10371050
require_parameter_descriptions: bool = False,
10381051
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1052+
strict: bool | None = None,
10391053
) -> Callable[[ToolFuncPlain[ToolParams]], ToolFuncPlain[ToolParams]]: ...
10401054

10411055
def tool_plain(
@@ -1049,6 +1063,7 @@ def tool_plain(
10491063
docstring_format: DocstringFormat = 'auto',
10501064
require_parameter_descriptions: bool = False,
10511065
schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema,
1066+
strict: bool | None = None,
10521067
) -> Any:
10531068
"""Decorator to register a tool function which DOES NOT take `RunContext` as an argument.
10541069
@@ -1091,6 +1106,8 @@ async def spam(ctx: RunContext[str]) -> float:
10911106
Defaults to `'auto'`, such that the format is inferred from the structure of the docstring.
10921107
require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False.
10931108
schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`.
1109+
strict: Whether to enforce JSON schema compliance (only affects OpenAI).
1110+
See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info.
10941111
"""
10951112
if func is None:
10961113

@@ -1105,13 +1122,22 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams
11051122
docstring_format,
11061123
require_parameter_descriptions,
11071124
schema_generator,
1125+
strict,
11081126
)
11091127
return func_
11101128

11111129
return tool_decorator
11121130
else:
11131131
self._register_function(
1114-
func, False, name, retries, prepare, docstring_format, require_parameter_descriptions, schema_generator
1132+
func,
1133+
False,
1134+
name,
1135+
retries,
1136+
prepare,
1137+
docstring_format,
1138+
require_parameter_descriptions,
1139+
schema_generator,
1140+
strict,
11151141
)
11161142
return func
11171143

@@ -1125,6 +1151,7 @@ def _register_function(
11251151
docstring_format: DocstringFormat,
11261152
require_parameter_descriptions: bool,
11271153
schema_generator: type[GenerateJsonSchema],
1154+
strict: bool | None,
11281155
) -> None:
11291156
"""Private utility to register a function as a tool."""
11301157
retries_ = retries if retries is not None else self._default_retries
@@ -1137,6 +1164,7 @@ def _register_function(
11371164
docstring_format=docstring_format,
11381165
require_parameter_descriptions=require_parameter_descriptions,
11391166
schema_generator=schema_generator,
1167+
strict=strict,
11401168
)
11411169
self._register_tool(tool)
11421170

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,15 @@ async def request_stream(
274274
# noinspection PyUnreachableCode
275275
yield # pragma: no cover
276276

277+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
278+
"""Customize the request parameters for the model.
279+
280+
This method can be overridden by subclasses to modify the request parameters before sending them to the model.
281+
In particular, this method can be used to make modifications to the generated tool JSON schemas if necessary
282+
for vendor/model-specific reasons.
283+
"""
284+
return model_request_parameters
285+
277286
@property
278287
@abstractmethod
279288
def model_name(self) -> str:

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterator, Sequence
66
from contextlib import asynccontextmanager
77
from copy import deepcopy
8-
from dataclasses import dataclass, field
8+
from dataclasses import dataclass, field, replace
99
from datetime import datetime
1010
from typing import Annotated, Any, Literal, Protocol, Union, cast
1111
from uuid import uuid4
@@ -152,6 +152,16 @@ async def request_stream(
152152
) as http_response:
153153
yield await self._process_streamed_response(http_response)
154154

155+
def customize_request_parameters(self, model_request_parameters: ModelRequestParameters) -> ModelRequestParameters:
156+
def _customize_tool_def(t: ToolDefinition):
157+
return replace(t, parameters_json_schema=_GeminiJsonSchema(t.parameters_json_schema).simplify())
158+
159+
return ModelRequestParameters(
160+
function_tools=[_customize_tool_def(tool) for tool in model_request_parameters.function_tools],
161+
allow_text_result=model_request_parameters.allow_text_result,
162+
result_tools=[_customize_tool_def(tool) for tool in model_request_parameters.result_tools],
163+
)
164+
155165
@property
156166
def model_name(self) -> GeminiModelName:
157167
"""The model name."""
@@ -640,7 +650,7 @@ class _GeminiFunction(TypedDict):
640650

641651

642652
def _function_from_abstract_tool(tool: ToolDefinition) -> _GeminiFunction:
643-
json_schema = _GeminiJsonSchema(tool.parameters_json_schema).simplify()
653+
json_schema = tool.parameters_json_schema
644654
f = _GeminiFunction(name=tool.name, description=tool.description)
645655
if json_schema.get('properties'):
646656
f['parameters'] = json_schema

0 commit comments

Comments
 (0)