diff --git a/src/agents/agent.py b/src/agents/agent.py index b64a6ea1d..f45090e07 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -384,6 +384,7 @@ def as_tool( custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, + max_turns: int | None = None, ) -> Tool: """Transform this agent into a tool, callable by other agents. @@ -402,6 +403,8 @@ def as_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + max_turns: The maximum number of turns the agent can take when running as a tool. + If not provided, the default value will be used. """ @function_tool( @@ -413,9 +416,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> str: from .run import Runner output = await Runner.run( - starting_agent=self, - input=input, - context=context.context, + starting_agent=self, input=input, context=context.context, max_turns=max_turns ) if custom_output_extractor: return await custom_output_extractor(output) diff --git a/src/agents/run.py b/src/agents/run.py index 5056758fb..fc679fabf 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -236,7 +236,7 @@ async def run( input: str | list[TResponseInputItem], *, context: TContext | None = None, - max_turns: int = DEFAULT_MAX_TURNS, + max_turns: int | None = DEFAULT_MAX_TURNS, hooks: RunHooks[TContext] | None = None, run_config: RunConfig | None = None, previous_response_id: str | None = None, @@ -277,6 +277,8 @@ async def run( agent. Agents may perform handoffs, so we don't know the specific type of the output. """ runner = DEFAULT_AGENT_RUNNER + if max_turns is None: + max_turns = DEFAULT_MAX_TURNS return await runner.run( starting_agent, input, diff --git a/tests/test_agent_as_tool_max_turns.py b/tests/test_agent_as_tool_max_turns.py new file mode 100644 index 000000000..63768fbd9 --- /dev/null +++ b/tests/test_agent_as_tool_max_turns.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import json + +import pytest + +from agents import Agent, MaxTurnsExceeded, Runner +from agents.items import TResponseOutputItem +from agents.run import DEFAULT_MAX_TURNS + +from .fake_model import FakeModel +from .test_responses import get_function_tool, get_function_tool_call, get_text_message + + +@pytest.mark.asyncio +async def test_runner_run_max_turns_none_defaults_to_constant(): + model = FakeModel() + agent = Agent( + name="test_runner_max_turns_none", + model=model, + tools=[get_function_tool("tool", "ok")], + ) + + # Prepare 11 turns (DEFAULT_MAX_TURNS is 10) to ensure exceeding default. + func_output = json.dumps({"a": "b"}) + turns: list[list[TResponseOutputItem] | Exception] = [] + for i in range(1, DEFAULT_MAX_TURNS + 1): + turns.append([get_text_message(str(i)), get_function_tool_call("tool", func_output)]) + model.add_multiple_turn_outputs(turns) + + # Passing None should make Runner default to DEFAULT_MAX_TURNS (10), so 11th turn exceeds. + with pytest.raises(MaxTurnsExceeded): + await Runner.run(agent, input="go", max_turns=None) + + +@pytest.mark.asyncio +async def test_agent_as_tool_forwards_max_turns(): + # Inner agent will exceed when limited to 1 turn. + inner_model = FakeModel() + inner_agent = Agent( + name="inner", + model=inner_model, + tools=[get_function_tool("some_function", "ok")], + ) + + # Make inner agent require more than 1 turn. + func_output = json.dumps({"x": 1}) + inner_turns: list[list[TResponseOutputItem] | Exception] = [ + [get_text_message("t1"), get_function_tool_call("some_function", func_output)], + [get_text_message("t2"), get_function_tool_call("some_function", func_output)], + ] + inner_model.add_multiple_turn_outputs(inner_turns) + + # Wrap inner agent as a tool with max_turns=1. + wrapped_tool = inner_agent.as_tool( + tool_name="inner_tool", + tool_description="runs inner agent", + max_turns=1, + ) + + # Orchestrator will call the wrapped tool twice, causing inner to exceed its max_turns. + outer_model = FakeModel() + orchestrator = Agent( + name="orchestrator", + model=outer_model, + tools=[wrapped_tool], + ) + + # Outer model asks to call the tool once; + # exceeding happens inside the tool call when inner runs. + outer_turns: list[list[TResponseOutputItem] | Exception] = [ + [get_function_tool_call("inner_tool")], + ] + outer_model.add_multiple_turn_outputs(outer_turns) + + # Since tool default error handling returns a string on error, + # the run should not raise here. + result = await Runner.run(orchestrator, input="start") + + # The tool call error should be surfaced as a message back to the model; + # ensure we have some output. + # We don't assert exact message text to avoid brittleness; + # just ensure the run completed with items. + assert len(result.new_items) >= 1