Skip to content

Commit 073983c

Browse files
Make non-required parameters not required (#742)
Co-authored-by: Sydney Runkle <[email protected]>
1 parent e6af102 commit 073983c

File tree

2 files changed

+99
-4
lines changed

2 files changed

+99
-4
lines changed

pydantic_ai_slim/pydantic_ai/_pydantic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def function_schema( # noqa: C901
115115
field_name,
116116
field_info,
117117
decorators,
118+
required=p.default is Parameter.empty,
118119
)
119120
# noinspection PyTypeChecker
120121
td_schema.setdefault('metadata', {})['is_model_like'] = is_model_like(annotation)

tests/test_tools.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,15 @@
1010
from pydantic_core import PydanticSerializationError
1111

1212
from pydantic_ai import Agent, RunContext, Tool, UserError
13-
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart
13+
from pydantic_ai.messages import (
14+
ArgsDict,
15+
ModelMessage,
16+
ModelRequest,
17+
ModelResponse,
18+
TextPart,
19+
ToolCallPart,
20+
ToolReturnPart,
21+
)
1422
from pydantic_ai.models.function import AgentInfo, FunctionModel
1523
from pydantic_ai.models.test import TestModel
1624
from pydantic_ai.tools import ToolDefinition
@@ -73,9 +81,11 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover
7381

7482

7583
async def get_json_schema(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
76-
assert len(info.function_tools) == 1
77-
r = info.function_tools[0]
78-
return ModelResponse(parts=[TextPart(pydantic_core.to_json(r).decode())])
84+
if len(info.function_tools) == 1:
85+
r = info.function_tools[0]
86+
return ModelResponse(parts=[TextPart(pydantic_core.to_json(r).decode())])
87+
else:
88+
return ModelResponse(parts=[TextPart(pydantic_core.to_json(info.function_tools).decode())])
7989

8090

8191
@pytest.mark.parametrize('docstring_format', ['google', 'auto'])
@@ -591,3 +601,87 @@ def test_enforce_parameter_descriptions() -> None:
591601
'bar',
592602
]
593603
assert all(err_part in error_reason for err_part in error_parts)
604+
605+
606+
def test_json_schema_required_parameters(set_event_loop: None):
607+
agent = Agent(FunctionModel(get_json_schema))
608+
609+
@agent.tool
610+
def my_tool(ctx: RunContext[None], a: int, b: int = 1) -> int:
611+
raise NotImplementedError
612+
613+
@agent.tool_plain
614+
def my_tool_plain(*, a: int = 1, b: int) -> int:
615+
raise NotImplementedError
616+
617+
result = agent.run_sync('Hello')
618+
json_schema = json.loads(result.data)
619+
assert json_schema == snapshot(
620+
[
621+
{
622+
'description': '',
623+
'name': 'my_tool',
624+
'outer_typed_dict_key': None,
625+
'parameters_json_schema': {
626+
'additionalProperties': False,
627+
'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}},
628+
'required': ['a'],
629+
'type': 'object',
630+
},
631+
},
632+
{
633+
'description': '',
634+
'name': 'my_tool_plain',
635+
'outer_typed_dict_key': None,
636+
'parameters_json_schema': {
637+
'additionalProperties': False,
638+
'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}},
639+
'required': ['b'],
640+
'type': 'object',
641+
},
642+
},
643+
]
644+
)
645+
646+
647+
def test_call_tool_without_unrequired_parameters(set_event_loop: None):
648+
async def call_tools_first(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
649+
if len(messages) == 1:
650+
return ModelResponse(
651+
parts=[
652+
ToolCallPart(tool_name='my_tool', args=ArgsDict({'a': 13})),
653+
ToolCallPart(tool_name='my_tool', args=ArgsDict({'a': 13, 'b': 4})),
654+
ToolCallPart(tool_name='my_tool_plain', args=ArgsDict({'b': 17})),
655+
ToolCallPart(tool_name='my_tool_plain', args=ArgsDict({'a': 4, 'b': 17})),
656+
]
657+
)
658+
else:
659+
return ModelResponse(parts=[TextPart('finished')])
660+
661+
agent = Agent(FunctionModel(call_tools_first))
662+
663+
@agent.tool
664+
def my_tool(ctx: RunContext[None], a: int, b: int = 2) -> int:
665+
return a + b
666+
667+
@agent.tool_plain
668+
def my_tool_plain(*, a: int = 3, b: int) -> int:
669+
return a * b
670+
671+
result = agent.run_sync('Hello')
672+
all_messages = result.all_messages()
673+
first_response = all_messages[1]
674+
second_request = all_messages[2]
675+
assert isinstance(first_response, ModelResponse)
676+
assert isinstance(second_request, ModelRequest)
677+
tool_call_args = [p.args for p in first_response.parts if isinstance(p, ToolCallPart)]
678+
tool_returns = [p.content for p in second_request.parts if isinstance(p, ToolReturnPart)]
679+
assert tool_call_args == snapshot(
680+
[
681+
ArgsDict(args_dict={'a': 13}),
682+
ArgsDict(args_dict={'a': 13, 'b': 4}),
683+
ArgsDict(args_dict={'b': 17}),
684+
ArgsDict(args_dict={'a': 4, 'b': 17}),
685+
]
686+
)
687+
assert tool_returns == snapshot([15, 17, 51, 68])

0 commit comments

Comments
 (0)