|
10 | 10 | from pydantic_core import PydanticSerializationError
|
11 | 11 |
|
12 | 12 | from pydantic_ai import Agent, RunContext, Tool, UserError
|
13 |
| -from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart |
| 13 | +from pydantic_ai.messages import ( |
| 14 | + ArgsDict, |
| 15 | + ModelMessage, |
| 16 | + ModelRequest, |
| 17 | + ModelResponse, |
| 18 | + TextPart, |
| 19 | + ToolCallPart, |
| 20 | + ToolReturnPart, |
| 21 | +) |
14 | 22 | from pydantic_ai.models.function import AgentInfo, FunctionModel
|
15 | 23 | from pydantic_ai.models.test import TestModel
|
16 | 24 | from pydantic_ai.tools import ToolDefinition
|
@@ -73,9 +81,11 @@ async def google_style_docstring(foo: int, bar: str) -> str: # pragma: no cover
|
73 | 81 |
|
74 | 82 |
|
75 | 83 | async def get_json_schema(_messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
|
76 |
| - assert len(info.function_tools) == 1 |
77 |
| - r = info.function_tools[0] |
78 |
| - return ModelResponse(parts=[TextPart(pydantic_core.to_json(r).decode())]) |
| 84 | + if len(info.function_tools) == 1: |
| 85 | + r = info.function_tools[0] |
| 86 | + return ModelResponse(parts=[TextPart(pydantic_core.to_json(r).decode())]) |
| 87 | + else: |
| 88 | + return ModelResponse(parts=[TextPart(pydantic_core.to_json(info.function_tools).decode())]) |
79 | 89 |
|
80 | 90 |
|
81 | 91 | @pytest.mark.parametrize('docstring_format', ['google', 'auto'])
|
@@ -591,3 +601,87 @@ def test_enforce_parameter_descriptions() -> None:
|
591 | 601 | 'bar',
|
592 | 602 | ]
|
593 | 603 | assert all(err_part in error_reason for err_part in error_parts)
|
| 604 | + |
| 605 | + |
| 606 | +def test_json_schema_required_parameters(set_event_loop: None): |
| 607 | + agent = Agent(FunctionModel(get_json_schema)) |
| 608 | + |
| 609 | + @agent.tool |
| 610 | + def my_tool(ctx: RunContext[None], a: int, b: int = 1) -> int: |
| 611 | + raise NotImplementedError |
| 612 | + |
| 613 | + @agent.tool_plain |
| 614 | + def my_tool_plain(*, a: int = 1, b: int) -> int: |
| 615 | + raise NotImplementedError |
| 616 | + |
| 617 | + result = agent.run_sync('Hello') |
| 618 | + json_schema = json.loads(result.data) |
| 619 | + assert json_schema == snapshot( |
| 620 | + [ |
| 621 | + { |
| 622 | + 'description': '', |
| 623 | + 'name': 'my_tool', |
| 624 | + 'outer_typed_dict_key': None, |
| 625 | + 'parameters_json_schema': { |
| 626 | + 'additionalProperties': False, |
| 627 | + 'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}, |
| 628 | + 'required': ['a'], |
| 629 | + 'type': 'object', |
| 630 | + }, |
| 631 | + }, |
| 632 | + { |
| 633 | + 'description': '', |
| 634 | + 'name': 'my_tool_plain', |
| 635 | + 'outer_typed_dict_key': None, |
| 636 | + 'parameters_json_schema': { |
| 637 | + 'additionalProperties': False, |
| 638 | + 'properties': {'a': {'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'integer'}}, |
| 639 | + 'required': ['b'], |
| 640 | + 'type': 'object', |
| 641 | + }, |
| 642 | + }, |
| 643 | + ] |
| 644 | + ) |
| 645 | + |
| 646 | + |
| 647 | +def test_call_tool_without_unrequired_parameters(set_event_loop: None): |
| 648 | + async def call_tools_first(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: |
| 649 | + if len(messages) == 1: |
| 650 | + return ModelResponse( |
| 651 | + parts=[ |
| 652 | + ToolCallPart(tool_name='my_tool', args=ArgsDict({'a': 13})), |
| 653 | + ToolCallPart(tool_name='my_tool', args=ArgsDict({'a': 13, 'b': 4})), |
| 654 | + ToolCallPart(tool_name='my_tool_plain', args=ArgsDict({'b': 17})), |
| 655 | + ToolCallPart(tool_name='my_tool_plain', args=ArgsDict({'a': 4, 'b': 17})), |
| 656 | + ] |
| 657 | + ) |
| 658 | + else: |
| 659 | + return ModelResponse(parts=[TextPart('finished')]) |
| 660 | + |
| 661 | + agent = Agent(FunctionModel(call_tools_first)) |
| 662 | + |
| 663 | + @agent.tool |
| 664 | + def my_tool(ctx: RunContext[None], a: int, b: int = 2) -> int: |
| 665 | + return a + b |
| 666 | + |
| 667 | + @agent.tool_plain |
| 668 | + def my_tool_plain(*, a: int = 3, b: int) -> int: |
| 669 | + return a * b |
| 670 | + |
| 671 | + result = agent.run_sync('Hello') |
| 672 | + all_messages = result.all_messages() |
| 673 | + first_response = all_messages[1] |
| 674 | + second_request = all_messages[2] |
| 675 | + assert isinstance(first_response, ModelResponse) |
| 676 | + assert isinstance(second_request, ModelRequest) |
| 677 | + tool_call_args = [p.args for p in first_response.parts if isinstance(p, ToolCallPart)] |
| 678 | + tool_returns = [p.content for p in second_request.parts if isinstance(p, ToolReturnPart)] |
| 679 | + assert tool_call_args == snapshot( |
| 680 | + [ |
| 681 | + ArgsDict(args_dict={'a': 13}), |
| 682 | + ArgsDict(args_dict={'a': 13, 'b': 4}), |
| 683 | + ArgsDict(args_dict={'b': 17}), |
| 684 | + ArgsDict(args_dict={'a': 4, 'b': 17}), |
| 685 | + ] |
| 686 | + ) |
| 687 | + assert tool_returns == snapshot([15, 17, 51, 68]) |
0 commit comments