 from pydantic import BaseModel
 
 from ragbits.agents import Agent, AgentRunContext
-from ragbits.agents._main import AgentOptions, AgentResult, AgentResultStreaming, ToolCallResult
+from ragbits.agents._main import AgentOptions, AgentResult, AgentResultStreaming, ToolCallResult, ToolChoice
 from ragbits.agents.exceptions import (
     AgentInvalidPromptInputError,
     AgentMaxTurnsExceededError,
@@ -143,22 +143,92 @@ def llm_with_tool_call_context() -> MockLLM:
     return MockLLM(default_options=options)
 
 
+def get_time() -> str:
+    """
+    Returns the current time.
+
+    Returns:
+        The current time as a string.
+    """
+    return "12:00 PM"
+
+
+@pytest.fixture
+def llm_no_tool_call_when_none() -> MockLLM:
+    """LLM that doesn't call tools when tool_choice is 'none'."""
+    options = MockLLMOptions(response="I cannot call tools right now.")
+    return MockLLM(default_options=options)
+
+
+@pytest.fixture
+def llm_auto_tool_call() -> MockLLM:
+    """LLM that automatically decides to call a tool."""
+    options = MockLLMOptions(
+        response="Let me check the weather for you.",
+        tool_calls=[
+            {
+                "name": "get_weather",
+                "arguments": '{"location": "New York"}',
+                "id": "auto_test",
+                "type": "function",
+            }
+        ],
+    )
+    return MockLLM(default_options=options)
+
+
+@pytest.fixture
+def llm_required_tool_call() -> MockLLM:
+    """LLM that is forced to call a tool when tool_choice is 'required'."""
+    options = MockLLMOptions(
+        response="",
+        tool_calls=[
+            {
+                "name": "get_weather",
+                "arguments": '{"location": "Boston"}',
+                "id": "required_test",
+                "type": "function",
+            }
+        ],
+    )
+    return MockLLM(default_options=options)
+
+
+@pytest.fixture
+def llm_specific_tool_call() -> MockLLM:
+    """LLM that calls a specific tool when tool_choice is a specific function."""
+    options = MockLLMOptions(
+        response="",
+        tool_calls=[
+            {
+                "name": "get_time",
+                "arguments": "{}",
+                "id": "specific_test",
+                "type": "function",
+            }
+        ],
+    )
+    return MockLLM(default_options=options)
+
+
 async def _run(
     agent: Agent,
     input: str | BaseModel | None = None,
     options: AgentOptions | None = None,
     context: AgentRunContext | None = None,
+    tool_choice: ToolChoice | None = None,
 ) -> AgentResult:
-    return await agent.run(input, options=options, context=context)
+    return await agent.run(input, options=options, context=context, tool_choice=tool_choice)
 
 
 async def _run_streaming(
     agent: Agent,
     input: str | BaseModel | None = None,
     options: AgentOptions | None = None,
     context: AgentRunContext | None = None,
+    tool_choice: ToolChoice | None = None,
 ) -> AgentResultStreaming:
-    result = agent.run_streaming(input, options=options, context=context)
+    result = agent.run_streaming(input, options=options, context=context, tool_choice=tool_choice)
     async for _chunk in result:
         pass
     return result
@@ -588,3 +658,114 @@ async def test_max_turns_not_exeeded_with_many_tool_calls(llm_multiple_tool_call
 
     assert result.content == "Final response after multiple tool calls"
     assert len(result.tool_calls) == 3
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_agent_run_with_tool_choice_none(llm_no_tool_call_when_none: MockLLM, method: Callable):
+    """Test agent run with tool_choice set to 'none'."""
+    agent = Agent(
+        llm=llm_no_tool_call_when_none,
+        prompt=CustomPrompt,
+        tools=[get_weather],
+    )
+    result = await method(agent, tool_choice="none")
+
+    assert result.content == "I cannot call tools right now."
+    assert result.tool_calls is None
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_agent_run_with_auto_tool_call(llm_auto_tool_call: MockLLM, method: Callable):
+    """Test agent run with automatic tool call."""
+    agent = Agent(
+        llm=llm_auto_tool_call,
+        prompt=CustomPrompt,
+        tools=[get_weather],
+    )
+    result = await method(agent)
+
+    assert result.content == "Let me check the weather for you."
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "auto_test"
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_agent_run_with_required_tool_call(llm_required_tool_call: MockLLM, method: Callable):
+    """Test agent run with required tool call."""
+    agent = Agent(
+        llm=llm_required_tool_call,
+        prompt=CustomPrompt,
+        tools=[get_weather],
+    )
+    result = await method(agent, tool_choice="required")
+
+    assert result.content == ""
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "required_test"
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_agent_run_with_specific_tool_call(llm_specific_tool_call: MockLLM, method: Callable):
+    """Test agent run with specific tool call."""
+    agent = Agent(
+        llm=llm_specific_tool_call,
+        prompt=CustomPrompt,
+        tools=[get_weather, get_time],
+    )
+    result = await method(agent, tool_choice=get_time)
+
+    assert result.content == ""
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].id == "specific_test"
+    assert result.tool_calls[0].name == "get_time"
+    assert result.tool_calls[0].result == "12:00 PM"
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_agent_run_with_tool_choice_auto_explicit(llm_auto_tool_call: MockLLM, method: Callable):
+    """Test agent run with tool_choice explicitly set to 'auto'."""
+    agent = Agent(
+        llm=llm_auto_tool_call,
+        prompt=CustomPrompt,
+        tools=[get_weather],
+    )
+    result = await method(agent, tool_choice="auto")
+
+    assert result.content == "Let me check the weather for you."
+    assert len(result.tool_calls) == 1
+    assert result.tool_calls[0].name == "get_weather"
+    assert result.tool_calls[0].arguments == {"location": "New York"}
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_tool_choice_with_multiple_tools_available(llm_auto_tool_call: MockLLM, method: Callable):
+    """Test tool_choice behavior when multiple tools are available."""
+    agent = Agent(
+        llm=llm_auto_tool_call,
+        prompt=CustomPrompt,
+        tools=[get_weather, get_time],  # Multiple tools available
+    )
+
+    result = await method(agent, tool_choice="auto")
+
+    assert result.content == "Let me check the weather for you."
+    assert len(result.tool_calls) == 1
+    # The LLM chose to call get_weather based on its configuration
+    assert result.tool_calls[0].name == "get_weather"
+
+
+@pytest.mark.parametrize("method", [_run, _run_streaming])
+async def test_tool_choice_history_preservation(llm_with_tool_call: MockLLM, method: Callable):
+    """Test that tool_choice works correctly with history preservation."""
+    agent: Agent = Agent(
+        llm=llm_with_tool_call,
+        prompt="You are a helpful assistant",
+        tools=[get_weather],
+        keep_history=True,
+    )
+
+    await method(agent, input="Check weather", tool_choice="auto")
+    assert len(agent.history) >= 3  # At least system, user, assistant messages
+    # Should include tool call in history
+    tool_call_messages = [msg for msg in agent.history if msg.get("role") == "tool"]
+    assert len(tool_call_messages) >= 1