Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def as_tool(
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
max_turns: int | None = None,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand All @@ -402,6 +403,8 @@ def as_tool(
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the tool is enabled. Disabled tools are hidden
from the LLM at runtime.
max_turns: The maximum number of turns the agent can take when running as a tool.
If not provided, the default value will be used.
"""

@function_tool(
Expand All @@ -413,9 +416,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
from .run import Runner

output = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
starting_agent=self, input=input, context=context.context, max_turns=max_turns
)
if custom_output_extractor:
return await custom_output_extractor(output)
Expand Down
4 changes: 3 additions & 1 deletion src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ async def run(
input: str | list[TResponseInputItem],
*,
context: TContext | None = None,
max_turns: int = DEFAULT_MAX_TURNS,
max_turns: int | None = DEFAULT_MAX_TURNS,
hooks: RunHooks[TContext] | None = None,
run_config: RunConfig | None = None,
previous_response_id: str | None = None,
Expand Down Expand Up @@ -277,6 +277,8 @@ async def run(
agent. Agents may perform handoffs, so we don't know the specific type of the output.
"""
runner = DEFAULT_AGENT_RUNNER
if max_turns is None:
max_turns = DEFAULT_MAX_TURNS
return await runner.run(
starting_agent,
input,
Expand Down
84 changes: 84 additions & 0 deletions tests/test_agent_as_tool_max_turns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

import json

import pytest

from agents import Agent, MaxTurnsExceeded, Runner
from agents.items import TResponseOutputItem
from agents.run import DEFAULT_MAX_TURNS

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message


@pytest.mark.asyncio
async def test_runner_run_max_turns_none_defaults_to_constant():
model = FakeModel()
agent = Agent(
name="test_runner_max_turns_none",
model=model,
tools=[get_function_tool("tool", "ok")],
)

# Prepare 11 turns (DEFAULT_MAX_TURNS is 10) to ensure exceeding default.
func_output = json.dumps({"a": "b"})
turns: list[list[TResponseOutputItem] | Exception] = []
for i in range(1, DEFAULT_MAX_TURNS + 1):
turns.append([get_text_message(str(i)), get_function_tool_call("tool", func_output)])
model.add_multiple_turn_outputs(turns)

# Passing None should make Runner default to DEFAULT_MAX_TURNS (10), so 11th turn exceeds.
with pytest.raises(MaxTurnsExceeded):
await Runner.run(agent, input="go", max_turns=None)


@pytest.mark.asyncio
async def test_agent_as_tool_forwards_max_turns():
# Inner agent will exceed when limited to 1 turn.
inner_model = FakeModel()
inner_agent = Agent(
name="inner",
model=inner_model,
tools=[get_function_tool("some_function", "ok")],
)

# Make inner agent require more than 1 turn.
func_output = json.dumps({"x": 1})
inner_turns: list[list[TResponseOutputItem] | Exception] = [
[get_text_message("t1"), get_function_tool_call("some_function", func_output)],
[get_text_message("t2"), get_function_tool_call("some_function", func_output)],
]
inner_model.add_multiple_turn_outputs(inner_turns)

# Wrap inner agent as a tool with max_turns=1.
wrapped_tool = inner_agent.as_tool(
tool_name="inner_tool",
tool_description="runs inner agent",
max_turns=1,
)

# Orchestrator will call the wrapped tool twice, causing inner to exceed its max_turns.
outer_model = FakeModel()
orchestrator = Agent(
name="orchestrator",
model=outer_model,
tools=[wrapped_tool],
)

# Outer model asks to call the tool once;
# exceeding happens inside the tool call when inner runs.
outer_turns: list[list[TResponseOutputItem] | Exception] = [
[get_function_tool_call("inner_tool")],
]
outer_model.add_multiple_turn_outputs(outer_turns)

# Since tool default error handling returns a string on error,
# the run should not raise here.
result = await Runner.run(orchestrator, input="start")

# The tool call error should be surfaced as a message back to the model;
# ensure we have some output.
# We don't assert exact message text to avoid brittleness;
# just ensure the run completed with items.
assert len(result.new_items) >= 1