Skip to content

Commit 43e507d

Browse files
GlockPLKonrad Czarnota
andauthored
feat: force tool calling support (#751)
Co-authored-by: Konrad Czarnota <[email protected]> Co-authored-by: GlockPL <[email protected]> Co-authored-by: GlockPL <[email protected]>
1 parent 7a20886 commit 43e507d

File tree

12 files changed

+380
-21
lines changed

12 files changed

+380
-21
lines changed

docs/how-to/agents/define_and_use_agents.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ The result is an [AgentResult][ragbits.agents.AgentResult], which includes the m
4949

5050
You can find the complete code example in the Ragbits repository [here](https://github.com/deepsense-ai/ragbits/blob/main/examples/agents/tool_use.py).
5151

52+
## Tool choice
53+
To control which tool is used at the first call, you can use the `tool_choice` parameter. There are the following options:
54+
- "auto": let model decide if tool call is needed
55+
- "none": do not call tool
56+
- "required": enforce tool usage (model decides which one)
57+
- Callable: one of provided tools
58+
59+
5260
## Conversation history
5361
[`Agent`][ragbits.agents.Agent]s can retain conversation context across multiple interactions by enabling the `keep_history` flag when initializing the agent. This is useful when you want the agent to understand follow-up questions without needing the user to repeat earlier details.
5462

examples/agents/tool_use.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def main() -> None:
8282
tools=[get_weather],
8383
default_options=AgentOptions(max_total_tokens=500, max_turns=5),
8484
)
85-
response = await agent.run(WeatherPromptInput(location="Paris"))
85+
response = await agent.run(WeatherPromptInput(location="Paris"), tool_choice=get_weather)
8686
print(response)
8787

8888

packages/ragbits-agents/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# CHANGELOG
22

33
## Unreleased
4+
- Add tool_choice parameter to agent interface (#738)
45

56
## 1.2.1 (2025-08-04)
67

@@ -13,7 +14,6 @@
1314
### Changed
1415

1516
- ragbits-core updated to version v1.2.0
16-
1717
- Add native openai tools support (#621)
1818
- add Context to Agents (#715)
1919

packages/ragbits-agents/src/ragbits/agents/_main.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from ragbits.agents.mcp.server import MCPServer
2424
from ragbits.agents.mcp.utils import get_tools
25-
from ragbits.agents.tool import Tool, ToolCallResult
25+
from ragbits.agents.tool import Tool, ToolCallResult, ToolChoice
2626
from ragbits.core.audit.traces import trace
2727
from ragbits.core.llms.base import LLM, LLMClientOptionsT, LLMResponseWithMetadata, ToolCall, Usage
2828
from ragbits.core.options import Options
@@ -192,6 +192,7 @@ async def run(
192192
input: str | None = None,
193193
options: AgentOptions[LLMClientOptionsT] | None = None,
194194
context: AgentRunContext | None = None,
195+
tool_choice: ToolChoice | None = None,
195196
) -> AgentResult[PromptOutputT]: ...
196197

197198
@overload
@@ -200,13 +201,15 @@ async def run(
200201
input: PromptInputT,
201202
options: AgentOptions[LLMClientOptionsT] | None = None,
202203
context: AgentRunContext | None = None,
204+
tool_choice: ToolChoice | None = None,
203205
) -> AgentResult[PromptOutputT]: ...
204206

205207
async def run(
206208
self,
207209
input: str | PromptInputT | None = None,
208210
options: AgentOptions[LLMClientOptionsT] | None = None,
209211
context: AgentRunContext | None = None,
212+
tool_choice: ToolChoice | None = None,
210213
) -> AgentResult[PromptOutputT]:
211214
"""
212215
Run the agent. The method is experimental, inputs and outputs may change in the future.
@@ -218,6 +221,11 @@ async def run(
218221
- None: No input. Only valid when a string prompt was provided during initialization.
219222
options: The options for the agent run.
220223
context: The context for the agent run.
224+
tool_choice: Parameter that allows to control what tool is used at first call. Can be one of:
225+
- "auto": let model decide if tool call is needed
226+
- "none": do not call tool
227+
- "required": enforce tool usage (model decides which one)
228+
- Callable: one of provided tools
221229
222230
Returns:
223231
The result of the agent run.
@@ -251,6 +259,7 @@ async def run(
251259
await self.llm.generate_with_metadata(
252260
prompt=prompt_with_history,
253261
tools=[tool.to_function_schema() for tool in tools_mapping.values()],
262+
tool_choice=tool_choice if tool_choice and turn_count == 0 else None,
254263
options=self._get_llm_options(llm_options, merged_options, context.usage),
255264
),
256265
)
@@ -294,6 +303,7 @@ def run_streaming(
294303
input: str | None = None,
295304
options: AgentOptions[LLMClientOptionsT] | None = None,
296305
context: AgentRunContext | None = None,
306+
tool_choice: ToolChoice | None = None,
297307
) -> AgentResultStreaming: ...
298308

299309
@overload
@@ -302,13 +312,15 @@ def run_streaming(
302312
input: PromptInputT,
303313
options: AgentOptions[LLMClientOptionsT] | None = None,
304314
context: AgentRunContext | None = None,
315+
tool_choice: ToolChoice | None = None,
305316
) -> AgentResultStreaming: ...
306317

307318
def run_streaming(
308319
self,
309320
input: str | PromptInputT | None = None,
310321
options: AgentOptions[LLMClientOptionsT] | None = None,
311322
context: AgentRunContext | None = None,
323+
tool_choice: ToolChoice | None = None,
312324
) -> AgentResultStreaming:
313325
"""
314326
This method returns an `AgentResultStreaming` object that can be asynchronously
@@ -318,6 +330,11 @@ def run_streaming(
318330
input: The input for the agent run.
319331
options: The options for the agent run.
320332
context: The context for the agent run.
333+
tool_choice: Parameter that allows to control what tool is used at first call. Can be one of:
334+
- "auto": let model decide if tool call is needed
335+
- "none": do not call tool
336+
- "required: enforce tool usage (model decides which one)
337+
- Callable: one of provided tools
321338
322339
Returns:
323340
A `StreamingResult` object for iteration and collection.
@@ -329,14 +346,15 @@ def run_streaming(
329346
AgentInvalidPromptInputError: If the prompt/input combination is invalid.
330347
AgentMaxTurnsExceededError: If the maximum number of turns is exceeded.
331348
"""
332-
generator = self._stream_internal(input, options, context)
349+
generator = self._stream_internal(input, options, context, tool_choice)
333350
return AgentResultStreaming(generator)
334351

335352
async def _stream_internal(
336353
self,
337354
input: str | PromptInputT | None = None,
338355
options: AgentOptions[LLMClientOptionsT] | None = None,
339356
context: AgentRunContext | None = None,
357+
tool_choice: ToolChoice | None = None,
340358
) -> AsyncGenerator[str | ToolCall | ToolCallResult | SimpleNamespace | BasePrompt | Usage]:
341359
if context is None:
342360
context = AgentRunContext()
@@ -357,6 +375,7 @@ async def _stream_internal(
357375
streaming_result = self.llm.generate_streaming(
358376
prompt=prompt_with_history,
359377
tools=[tool.to_function_schema() for tool in tools_mapping.values()],
378+
tool_choice=tool_choice if tool_choice and turn_count == 0 else None,
360379
options=self._get_llm_options(llm_options, merged_options, context.usage),
361380
)
362381
async for chunk in streaming_result:

packages/ragbits-agents/src/ragbits/agents/tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Callable
22
from dataclasses import dataclass
3-
from typing import Any
3+
from typing import Any, Literal
44

55
from typing_extensions import Self
66

@@ -76,3 +76,6 @@ def to_function_schema(self) -> dict[str, Any]:
7676
"parameters": self.parameters,
7777
},
7878
}
79+
80+
81+
ToolChoice = Literal["auto", "none", "required"] | Callable

packages/ragbits-agents/tests/unit/test_agent.py

Lines changed: 184 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pydantic import BaseModel
77

88
from ragbits.agents import Agent, AgentRunContext
9-
from ragbits.agents._main import AgentOptions, AgentResult, AgentResultStreaming, ToolCallResult
9+
from ragbits.agents._main import AgentOptions, AgentResult, AgentResultStreaming, ToolCallResult, ToolChoice
1010
from ragbits.agents.exceptions import (
1111
AgentInvalidPromptInputError,
1212
AgentMaxTurnsExceededError,
@@ -143,22 +143,92 @@ def llm_with_tool_call_context() -> MockLLM:
143143
return MockLLM(default_options=options)
144144

145145

146+
def get_time() -> str:
147+
"""
148+
Returns the current time.
149+
150+
Returns:
151+
The current time as a string.
152+
"""
153+
return "12:00 PM"
154+
155+
156+
@pytest.fixture
157+
def llm_no_tool_call_when_none() -> MockLLM:
158+
"""LLM that doesn't call tools when tool_choice is 'none'."""
159+
options = MockLLMOptions(response="I cannot call tools right now.")
160+
return MockLLM(default_options=options)
161+
162+
163+
@pytest.fixture
164+
def llm_auto_tool_call() -> MockLLM:
165+
"""LLM that automatically decides to call a tool."""
166+
options = MockLLMOptions(
167+
response="Let me check the weather for you.",
168+
tool_calls=[
169+
{
170+
"name": "get_weather",
171+
"arguments": '{"location": "New York"}',
172+
"id": "auto_test",
173+
"type": "function",
174+
}
175+
],
176+
)
177+
return MockLLM(default_options=options)
178+
179+
180+
@pytest.fixture
181+
def llm_required_tool_call() -> MockLLM:
182+
"""LLM that is forced to call a tool when tool_choice is 'required'."""
183+
options = MockLLMOptions(
184+
response="",
185+
tool_calls=[
186+
{
187+
"name": "get_weather",
188+
"arguments": '{"location": "Boston"}',
189+
"id": "required_test",
190+
"type": "function",
191+
}
192+
],
193+
)
194+
return MockLLM(default_options=options)
195+
196+
197+
@pytest.fixture
198+
def llm_specific_tool_call() -> MockLLM:
199+
"""LLM that calls a specific tool when tool_choice is a specific function."""
200+
options = MockLLMOptions(
201+
response="",
202+
tool_calls=[
203+
{
204+
"name": "get_time",
205+
"arguments": "{}",
206+
"id": "specific_test",
207+
"type": "function",
208+
}
209+
],
210+
)
211+
return MockLLM(default_options=options)
212+
213+
146214
async def _run(
147215
agent: Agent,
148216
input: str | BaseModel | None = None,
149217
options: AgentOptions | None = None,
150218
context: AgentRunContext | None = None,
219+
tool_choice: ToolChoice | None = None,
151220
) -> AgentResult:
152-
return await agent.run(input, options=options, context=context)
221+
return await agent.run(input, options=options, context=context, tool_choice=tool_choice)
153222

154223

155224
async def _run_streaming(
156225
agent: Agent,
157226
input: str | BaseModel | None = None,
158227
options: AgentOptions | None = None,
159228
context: AgentRunContext | None = None,
229+
tool_choice: ToolChoice | None = None,
160230
) -> AgentResultStreaming:
161-
result = agent.run_streaming(input, options=options, context=context)
231+
result = agent.run_streaming(input, options=options, context=context, tool_choice=tool_choice)
162232
async for _chunk in result:
163233
pass
164234
return result
@@ -588,3 +658,114 @@ async def test_max_turns_not_exeeded_with_many_tool_calls(llm_multiple_tool_call
588658

589659
assert result.content == "Final response after multiple tool calls"
590660
assert len(result.tool_calls) == 3
661+
662+
663+
@pytest.mark.parametrize("method", [_run, _run_streaming])
664+
async def test_agent_run_with_tool_choice_none(llm_no_tool_call_when_none: MockLLM, method: Callable):
665+
"""Test agent run with tool_choice set to 'none'."""
666+
agent = Agent(
667+
llm=llm_no_tool_call_when_none,
668+
prompt=CustomPrompt,
669+
tools=[get_weather],
670+
)
671+
result = await method(agent, tool_choice="none")
672+
673+
assert result.content == "I cannot call tools right now."
674+
assert result.tool_calls is None
675+
676+
677+
@pytest.mark.parametrize("method", [_run, _run_streaming])
678+
async def test_agent_run_with_auto_tool_call(llm_auto_tool_call: MockLLM, method: Callable):
679+
"""Test agent run with automatic tool call."""
680+
agent = Agent(
681+
llm=llm_auto_tool_call,
682+
prompt=CustomPrompt,
683+
tools=[get_weather],
684+
)
685+
result = await method(agent)
686+
687+
assert result.content == "Let me check the weather for you."
688+
assert len(result.tool_calls) == 1
689+
assert result.tool_calls[0].id == "auto_test"
690+
691+
692+
@pytest.mark.parametrize("method", [_run, _run_streaming])
693+
async def test_agent_run_with_required_tool_call(llm_required_tool_call: MockLLM, method: Callable):
694+
"""Test agent run with required tool call."""
695+
agent = Agent(
696+
llm=llm_required_tool_call,
697+
prompt=CustomPrompt,
698+
tools=[get_weather],
699+
)
700+
result = await method(agent, tool_choice="required")
701+
702+
assert result.content == ""
703+
assert len(result.tool_calls) == 1
704+
assert result.tool_calls[0].id == "required_test"
705+
706+
707+
@pytest.mark.parametrize("method", [_run, _run_streaming])
708+
async def test_agent_run_with_specific_tool_call(llm_specific_tool_call: MockLLM, method: Callable):
709+
"""Test agent run with specific tool call."""
710+
agent = Agent(
711+
llm=llm_specific_tool_call,
712+
prompt=CustomPrompt,
713+
tools=[get_weather, get_time],
714+
)
715+
result = await method(agent, tool_choice=get_time)
716+
717+
assert result.content == ""
718+
assert len(result.tool_calls) == 1
719+
assert result.tool_calls[0].id == "specific_test"
720+
assert result.tool_calls[0].name == "get_time"
721+
assert result.tool_calls[0].result == "12:00 PM"
722+
723+
724+
@pytest.mark.parametrize("method", [_run, _run_streaming])
725+
async def test_agent_run_with_tool_choice_auto_explicit(llm_auto_tool_call: MockLLM, method: Callable):
726+
"""Test agent run with tool_choice explicitly set to 'auto'."""
727+
agent = Agent(
728+
llm=llm_auto_tool_call,
729+
prompt=CustomPrompt,
730+
tools=[get_weather],
731+
)
732+
result = await method(agent, tool_choice="auto")
733+
734+
assert result.content == "Let me check the weather for you."
735+
assert len(result.tool_calls) == 1
736+
assert result.tool_calls[0].name == "get_weather"
737+
assert result.tool_calls[0].arguments == {"location": "New York"}
738+
739+
740+
@pytest.mark.parametrize("method", [_run, _run_streaming])
741+
async def test_tool_choice_with_multiple_tools_available(llm_auto_tool_call: MockLLM, method: Callable):
742+
"""Test tool_choice behavior when multiple tools are available."""
743+
agent = Agent(
744+
llm=llm_auto_tool_call,
745+
prompt=CustomPrompt,
746+
tools=[get_weather, get_time], # Multiple tools available
747+
)
748+
749+
result = await method(agent, tool_choice="auto")
750+
751+
assert result.content == "Let me check the weather for you."
752+
assert len(result.tool_calls) == 1
753+
# The LLM chose to call get_weather based on its configuration
754+
assert result.tool_calls[0].name == "get_weather"
755+
756+
757+
@pytest.mark.parametrize("method", [_run, _run_streaming])
758+
async def test_tool_choice_history_preservation(llm_with_tool_call: MockLLM, method: Callable):
759+
"""Test that tool_choice works correctly with history preservation."""
760+
agent: Agent = Agent(
761+
llm=llm_with_tool_call,
762+
prompt="You are a helpful assistant",
763+
tools=[get_weather],
764+
keep_history=True,
765+
)
766+
767+
await method(agent, input="Check weather", tool_choice="auto")
768+
assert len(agent.history) >= 3 # At least system, user, assistant messages
769+
# Should include tool call in history
770+
tool_call_messages = [msg for msg in agent.history if msg.get("role") == "tool"]
771+
assert len(tool_call_messages) >= 1

packages/ragbits-core/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Add tool_choice parameter to LLM interface (#738)
6+
57
## 1.2.1 (2025-08-04)
68

79
## 1.2.0 (2025-08-01)

0 commit comments

Comments
 (0)