Skip to content
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def as_tool(
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
max_turns: int | None = None,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.
Expand All @@ -402,6 +403,8 @@ def as_tool(
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the tool is enabled. Disabled tools are hidden
from the LLM at runtime.
max_turns: The maximum number of turns the agent can take when running as a tool.
If not provided, the default value will be used.
"""

@function_tool(
Expand All @@ -413,9 +416,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
from .run import Runner

output = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
starting_agent=self, input=input, context=context.context, max_turns=max_turns
)
if custom_output_extractor:
return await custom_output_extractor(output)
Expand Down
4 changes: 3 additions & 1 deletion src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def run(
input: str | list[TResponseInputItem],
*,
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
max_turns: int | None = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
Expand Down Expand Up @@ -271,6 +271,8 @@ async def run(
agent. Agents may perform handoffs, so we don't know the specific type of the output.
"""
runner = DEFAULT_AGENT_RUNNER
if max_turns is None:
max_turns = DEFAULT_MAX_TURNS
return await runner.run(
starting_agent,
input,
Expand Down
85 changes: 85 additions & 0 deletions tests/test_agent_as_tool_max_turns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import json

import pytest

from agents import Agent, MaxTurnsExceeded, Runner
from agents.run import DEFAULT_MAX_TURNS

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message


@pytest.mark.asyncio
async def test_runner_run_max_turns_none_defaults_to_constant():
model = FakeModel()
agent = Agent(
name="test_runner_max_turns_none",
model=model,
tools=[get_function_tool("tool", "ok")],
)

# Prepare 11 turns (DEFAULT_MAX_TURNS is 10) to ensure exceeding default.
func_output = json.dumps({"a": "b"})
turns: list[list[object]] = []
for i in range(1, DEFAULT_MAX_TURNS + 1):
turns.append([get_text_message(str(i)), get_function_tool_call("tool", func_output)])
model.add_multiple_turn_outputs(turns)

# Passing None should make Runner default to DEFAULT_MAX_TURNS (10), so 11th turn exceeds.
with pytest.raises(MaxTurnsExceeded):
await Runner.run(agent, input="go", max_turns=None)


@pytest.mark.asyncio
async def test_agent_as_tool_forwards_max_turns():
# Inner agent will exceed when limited to 1 turn.
inner_model = FakeModel()
inner_agent = Agent(
name="inner",
model=inner_model,
tools=[get_function_tool("some_function", "ok")],
)

# Make inner agent require more than 1 turn.
func_output = json.dumps({"x": 1})
inner_model.add_multiple_turn_outputs(
[
[get_text_message("t1"), get_function_tool_call("some_function", func_output)],
[get_text_message("t2"), get_function_tool_call("some_function", func_output)],
]
)

# Wrap inner agent as a tool with max_turns=1.
wrapped_tool = inner_agent.as_tool(
tool_name="inner_tool",
tool_description="runs inner agent",
max_turns=1,
)

# Orchestrator will call the wrapped tool twice, causing inner to exceed its max_turns.
outer_model = FakeModel()
orchestrator = Agent(
name="orchestrator",
model=outer_model,
tools=[wrapped_tool],
)

# Outer model asks to call the tool once;
# exceeding happens inside the tool call when inner runs.
outer_model.add_multiple_turn_outputs(
[
[get_function_tool_call("inner_tool")],
]
)

# Since tool default error handling returns a string on error,
# the run should not raise here.
result = await Runner.run(orchestrator, input="start")

# The tool call error should be surfaced as a message back to the model;
# ensure we have some output.
# We don't assert exact message text to avoid brittleness;
# just ensure the run completed with items.
assert len(result.new_items) >= 1
Loading