Skip to content

Commit 46a58ea

Browse files
mkultraWasHereMichael Kouremetis
andauthored
fix: tool call counting (#282)
* tool call tracking * tool call tracking * better placement for tool counting, will now fail on the exact tool call that is past max * added test case * linting * linting --------- Co-authored-by: Michael Kouremetis <[email protected]>
1 parent 86ca07b commit 46a58ea

File tree

4 files changed

+51
-4
lines changed

4 files changed

+51
-4
lines changed

dreadnode/agent/agent.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from rigging.message import inject_system_content
1515
from ulid import ULID # can't access via rg
1616

17-
from dreadnode.agent.error import MaxStepsError
17+
from dreadnode.agent.error import MaxStepsError, MaxToolCallsError
1818
from dreadnode.agent.events import (
1919
AgentEnd,
2020
AgentError,
@@ -89,7 +89,9 @@ class Agent(Model):
8989
)
9090
"""The agent's core instructions."""
9191
max_steps: int = Config(default=10)
92-
"""The maximum number of steps (generation + tool calls)."""
92+
"""The maximum number of steps (generations)."""
93+
max_tool_calls: int = Config(default=-1)
94+
"""The maximum number of tool calls. Defaults to infinite."""
9395
caching: rg.caching.CacheMode | None = Config(default=None, repr=False)
9496
"""How to handle cache_control entries on inference messages."""
9597

@@ -488,10 +490,16 @@ async def _dispatch(event: AgentEvent) -> t.AsyncIterator[AgentEvent]: # noqa:
488490
raise winning_reaction
489491

490492
# Tool calling
493+
tool_calls = 0
491494

492495
async def _process_tool_call(
493496
tool_call: "rg.tools.ToolCall",
494497
) -> t.AsyncGenerator[AgentEvent, None]:
498+
nonlocal tool_calls
499+
500+
if self.max_tool_calls != -1 and tool_calls >= self.max_tool_calls:
501+
raise Finish("Reached maximum allowed tool calls.")
502+
495503
async for event in _dispatch(
496504
ToolStart(
497505
session_id=session_id,
@@ -513,6 +521,7 @@ async def _process_tool_call(
513521
tool = next((t for t in self.all_tools if t.name == tool_call.name), None)
514522

515523
if tool is not None:
524+
tool_calls += 1
516525
try:
517526
message, stop = await tool.handle_tool_call(tool_call)
518527
except Reaction:
@@ -690,6 +699,9 @@ async def _process_tool_call(
690699
if step >= self.max_steps:
691700
error = MaxStepsError(max_steps=self.max_steps)
692701
stop_reason = "max_steps_reached"
702+
elif self.max_tool_calls != -1 and tool_calls >= self.max_tool_calls:
703+
error = MaxToolCallsError(max_tool_calls=self.max_tool_calls)
704+
stop_reason = "max_tool_calls_reached"
693705
elif error is not None:
694706
stop_reason = "error"
695707
elif events and isinstance(events[-1], AgentStalled):

dreadnode/agent/error.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,11 @@ class MaxStepsError(Exception):
44
def __init__(self, max_steps: int):
55
super().__init__(f"Maximum steps reached ({max_steps}).")
66
self.max_steps = max_steps
7+
8+
9+
class MaxToolCallsError(Exception):
10+
"""Raise from a hook to stop the agent's run due to reaching the maximum number of tool calls."""
11+
12+
def __init__(self, max_tool_calls: int):
13+
super().__init__(f"Maximum tool calls reached ({max_tool_calls}).")
14+
self.max_tool_calls = max_tool_calls

dreadnode/agent/result.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
if t.TYPE_CHECKING:
99
from dreadnode.agent.agent import Agent
1010

11-
AgentStopReason = t.Literal["finished", "max_steps_reached", "error", "stalled"]
11+
AgentStopReason = t.Literal[
12+
"finished", "max_steps_reached", "max_tool_calls_reached", "error", "stalled"
13+
]
1214

1315

1416
@dataclass(config=ConfigDict(arbitrary_types_allowed=True))

tests/test_agent.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from rigging.generator.base import GeneratedMessage
99

1010
from dreadnode.agent.agent import Agent, TaskAgent
11-
from dreadnode.agent.error import MaxStepsError
11+
from dreadnode.agent.error import MaxStepsError, MaxToolCallsError
1212
from dreadnode.agent.events import AgentEnd, AgentEvent, AgentStalled, Reacted, ToolStart
1313
from dreadnode.agent.hooks.base import retry_with_feedback
1414
from dreadnode.agent.reactions import RetryWithFeedback
@@ -298,6 +298,31 @@ async def test_run_stops_on_max_steps(mock_generator: MockGenerator, simple_tool
298298
assert result.steps == 1
299299

300300

301+
@pytest.mark.asyncio
302+
async def test_run_stops_on_max_tool_calls(
303+
mock_generator: MockGenerator, simple_tool: AnyTool
304+
) -> None:
305+
"""Ensure the agent run terminates with a MaxToolCallsError when exceeding max_tool_calls."""
306+
# The agent will just keep calling the tool.
307+
mock_generator._responses = [
308+
MockGenerator.tool_response("get_weather", {"city": "A"}),
309+
MockGenerator.tool_response("get_weather", {"city": "B"}),
310+
MockGenerator.tool_response("get_weather", {"city": "C"}),
311+
]
312+
313+
agent = Agent(
314+
name="MaxToolCallsAgent",
315+
model=mock_generator,
316+
tools=[simple_tool],
317+
max_tool_calls=2,
318+
)
319+
result = await agent.run("...")
320+
321+
assert result.failed
322+
assert result.stop_reason == "max_tool_calls_reached"
323+
assert isinstance(result.error, MaxToolCallsError)
324+
325+
301326
@pytest.mark.asyncio
302327
async def test_run_stops_on_stop_condition(
303328
mock_generator: MockGenerator, simple_tool: AnyTool

0 commit comments

Comments
 (0)