1414from rigging .message import inject_system_content
1515from ulid import ULID # can't access via rg
1616
17- from dreadnode .agent .error import MaxStepsError
17+ from dreadnode .agent .error import MaxStepsError , MaxToolCallsError
1818from dreadnode .agent .events import (
1919 AgentEnd ,
2020 AgentError ,
@@ -89,7 +89,9 @@ class Agent(Model):
8989 )
9090 """The agent's core instructions."""
9191 max_steps : int = Config (default = 10 )
92- """The maximum number of steps (generation + tool calls)."""
92+ """The maximum number of steps (generations)."""
93+ max_tool_calls : int = Config (default = - 1 )
94+ """The maximum number of tool calls. Defaults to infinite."""
9395 caching : rg .caching .CacheMode | None = Config (default = None , repr = False )
9496 """How to handle cache_control entries on inference messages."""
9597
@@ -488,10 +490,16 @@ async def _dispatch(event: AgentEvent) -> t.AsyncIterator[AgentEvent]: # noqa:
488490 raise winning_reaction
489491
490492 # Tool calling
493+ tool_calls = 0
491494
492495 async def _process_tool_call (
493496 tool_call : "rg.tools.ToolCall" ,
494497 ) -> t .AsyncGenerator [AgentEvent , None ]:
498+ nonlocal tool_calls
499+
500+ if self .max_tool_calls != - 1 and tool_calls >= self .max_tool_calls :
501+ raise Finish ("Reached maximum allowed tool calls." )
502+
495503 async for event in _dispatch (
496504 ToolStart (
497505 session_id = session_id ,
@@ -513,6 +521,7 @@ async def _process_tool_call(
513521 tool = next ((t for t in self .all_tools if t .name == tool_call .name ), None )
514522
515523 if tool is not None :
524+ tool_calls += 1
516525 try :
517526 message , stop = await tool .handle_tool_call (tool_call )
518527 except Reaction :
@@ -690,6 +699,9 @@ async def _process_tool_call(
690699 if step >= self .max_steps :
691700 error = MaxStepsError (max_steps = self .max_steps )
692701 stop_reason = "max_steps_reached"
702+ elif self .max_tool_calls != - 1 and tool_calls >= self .max_tool_calls :
703+ error = MaxToolCallsError (max_tool_calls = self .max_tool_calls )
704+ stop_reason = "max_tool_calls_reached"
693705 elif error is not None :
694706 stop_reason = "error"
695707 elif events and isinstance (events [- 1 ], AgentStalled ):
0 commit comments