|
18 | 18 | AgentRunResult, |
19 | 19 | AgentRunResultEvent, |
20 | 20 | AgentStreamEvent, |
| 21 | + ExternalToolset, |
21 | 22 | FinalResultEvent, |
22 | 23 | FunctionToolCallEvent, |
23 | 24 | FunctionToolResultEvent, |
@@ -819,6 +820,182 @@ def another_tool(y: int) -> int: # pragma: no cover |
819 | 820 | ) |
820 | 821 |
|
821 | 822 |
|
| 823 | +async def test_early_strategy_with_external_tool_call(): |
| 824 | + tool_called: list[str] = [] |
| 825 | + |
| 826 | + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: |
| 827 | + assert info.output_tools is not None |
| 828 | + yield {1: DeltaToolCall('external_tool')} |
| 829 | + yield {2: DeltaToolCall('final_result', '{"value": "final"}')} |
| 830 | + yield {3: DeltaToolCall('regular_tool', '{"x": 1}')} |
| 831 | + |
| 832 | + agent = Agent( |
| 833 | + FunctionModel(stream_function=sf), |
| 834 | + output_type=[OutputType, DeferredToolRequests], |
| 835 | + toolsets=[ |
| 836 | + ExternalToolset( |
| 837 | + tool_defs=[ |
| 838 | + ToolDefinition( |
| 839 | + name='external_tool', |
| 840 | + kind='external', |
| 841 | + ) |
| 842 | + ] |
| 843 | + ) |
| 844 | + ], |
| 845 | + end_strategy='early', |
| 846 | + ) |
| 847 | + |
| 848 | + @agent.tool_plain |
| 849 | + def regular_tool(x: int) -> int: # pragma: no cover |
| 850 | + """A regular tool that should not be called.""" |
| 851 | + tool_called.append('regular_tool') |
| 852 | + return x |
| 853 | + |
| 854 | + async with agent.run_stream('test early strategy with external tool call') as result: |
| 855 | + response = await result.get_output() |
| 856 | + assert response == snapshot( |
| 857 | + DeferredToolRequests( |
| 858 | + calls=[ |
| 859 | + ToolCallPart( |
| 860 | + tool_name='external_tool', |
| 861 | + tool_call_id=IsStr(), |
| 862 | + ) |
| 863 | + ] |
| 864 | + ) |
| 865 | + ) |
| 866 | + messages = result.all_messages() |
| 867 | + |
| 868 | + # Verify no tools were called |
| 869 | + assert tool_called == [] |
| 870 | + |
| 871 | + # Verify we got appropriate tool returns |
| 872 | + assert messages == snapshot( |
| 873 | + [ |
| 874 | + ModelRequest( |
| 875 | + parts=[ |
| 876 | + UserPromptPart( |
| 877 | + content='test early strategy with external tool call', |
| 878 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 879 | + part_kind='user-prompt', |
| 880 | + ) |
| 881 | + ], |
| 882 | + kind='request', |
| 883 | + ), |
| 884 | + ModelResponse( |
| 885 | + parts=[ |
| 886 | + ToolCallPart(tool_name='external_tool', tool_call_id=IsStr()), |
| 887 | + ToolCallPart( |
| 888 | + tool_name='final_result', |
| 889 | + args='{"value": "final"}', |
| 890 | + tool_call_id=IsStr(), |
| 891 | + ), |
| 892 | + ToolCallPart( |
| 893 | + tool_name='regular_tool', |
| 894 | + args='{"x": 1}', |
| 895 | + tool_call_id=IsStr(), |
| 896 | + ), |
| 897 | + ], |
| 898 | + usage=RequestUsage(input_tokens=50, output_tokens=7), |
| 899 | + model_name='function::sf', |
| 900 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 901 | + kind='response', |
| 902 | + ), |
| 903 | + ModelRequest( |
| 904 | + parts=[ |
| 905 | + ToolReturnPart( |
| 906 | + tool_name='final_result', |
| 907 | + content='Output tool not used - a final result was already processed.', |
| 908 | + tool_call_id=IsStr(), |
| 909 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 910 | + ), |
| 911 | + ToolReturnPart( |
| 912 | + tool_name='regular_tool', |
| 913 | + content='Tool not executed - a final result was already processed.', |
| 914 | + tool_call_id=IsStr(), |
| 915 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 916 | + ), |
| 917 | + ], |
| 918 | + kind='request', |
| 919 | + ), |
| 920 | + ] |
| 921 | + ) |
| 922 | + |
| 923 | + |
| 924 | +async def test_early_strategy_with_deferred_tool_call(): |
| 925 | + tool_called: list[str] = [] |
| 926 | + |
| 927 | + async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]: |
| 928 | + assert info.output_tools is not None |
| 929 | + yield {1: DeltaToolCall('deferred_tool')} |
| 930 | + yield {2: DeltaToolCall('regular_tool', '{"x": 1}')} |
| 931 | + |
| 932 | + agent = Agent( |
| 933 | + FunctionModel(stream_function=sf), |
| 934 | + output_type=[str, DeferredToolRequests], |
| 935 | + end_strategy='early', |
| 936 | + ) |
| 937 | + |
| 938 | + @agent.tool_plain |
| 939 | + def deferred_tool() -> int: |
| 940 | + raise CallDeferred |
| 941 | + |
| 942 | + @agent.tool_plain |
| 943 | + def regular_tool(x: int) -> int: |
| 944 | + tool_called.append('regular_tool') |
| 945 | + return x |
| 946 | + |
| 947 | + async with agent.run_stream('test early strategy with external tool call') as result: |
| 948 | + response = await result.get_output() |
| 949 | + assert response == snapshot( |
| 950 | + DeferredToolRequests(calls=[ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr())]) |
| 951 | + ) |
| 952 | + messages = result.all_messages() |
| 953 | + |
| 954 | + # Verify no tools were called |
| 955 | + assert tool_called == ['regular_tool'] |
| 956 | + |
| 957 | + # Verify we got appropriate tool returns |
| 958 | + assert messages == snapshot( |
| 959 | + [ |
| 960 | + ModelRequest( |
| 961 | + parts=[ |
| 962 | + UserPromptPart( |
| 963 | + content='test early strategy with external tool call', |
| 964 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 965 | + part_kind='user-prompt', |
| 966 | + ) |
| 967 | + ], |
| 968 | + kind='request', |
| 969 | + ), |
| 970 | + ModelResponse( |
| 971 | + parts=[ |
| 972 | + ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr()), |
| 973 | + ToolCallPart( |
| 974 | + tool_name='regular_tool', |
| 975 | + args='{"x": 1}', |
| 976 | + tool_call_id=IsStr(), |
| 977 | + ), |
| 978 | + ], |
| 979 | + usage=RequestUsage(input_tokens=50, output_tokens=3), |
| 980 | + model_name='function::sf', |
| 981 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 982 | + kind='response', |
| 983 | + ), |
| 984 | + ModelRequest( |
| 985 | + parts=[ |
| 986 | + ToolReturnPart( |
| 987 | + tool_name='regular_tool', |
| 988 | + content=1, |
| 989 | + tool_call_id=IsStr(), |
| 990 | + timestamp=IsNow(tz=datetime.timezone.utc), |
| 991 | + ) |
| 992 | + ], |
| 993 | + kind='request', |
| 994 | + ), |
| 995 | + ] |
| 996 | + ) |
| 997 | + |
| 998 | + |
822 | 999 | async def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool(): |
823 | 1000 | """Test that 'early' strategy does not apply to tool calls without final tool.""" |
824 | 1001 | tool_called: list[str] = [] |
|
0 commit comments