Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
from .lifecycle import AgentHooks
from .lifecycle import AgentHooks, RunHooks
from .mcp import MCPServer
from .memory.session import Session
from .result import RunResult
from .run import RunConfig


@dataclass
Expand Down Expand Up @@ -384,6 +386,12 @@ def as_tool(
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
run_config: RunConfig | None = None,
max_turns: int | None = None,
hooks: RunHooks[TContext] | None = None,
previous_response_id: str | None = None,
conversation_id: str | None = None,
session: Session | None = None,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand All @@ -410,12 +418,20 @@ def as_tool(
is_enabled=is_enabled,
)
async def run_agent(context: RunContextWrapper, input: str) -> str:
from .run import Runner
from .run import DEFAULT_MAX_TURNS, Runner

resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS

output = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
run_config=run_config,
max_turns=resolved_max_turns,
hooks=hooks,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
session=session,
)
if custom_output_extractor:
return await custom_output_extractor(output)
Expand Down
175 changes: 174 additions & 1 deletion tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
from __future__ import annotations

from typing import Any

import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from pydantic import BaseModel

from agents import Agent, AgentBase, FunctionTool, RunContextWrapper
from agents import (
Agent,
AgentBase,
FunctionTool,
MessageOutputItem,
RunConfig,
RunContextWrapper,
RunHooks,
Runner,
Session,
TResponseInputItem,
)
from agents.tool_context import ToolContext


class BoolCtx(BaseModel):
Expand Down Expand Up @@ -205,3 +222,159 @@ async def custom_extractor(result):
tools = await orchestrator.get_all_tools(context)
assert len(tools) == 1
assert tools[0].name == "custom_tool_name"


@pytest.mark.asyncio
async def test_agent_as_tool_returns_concatenated_text(monkeypatch: pytest.MonkeyPatch) -> None:
"""Agent tool should use default text aggregation when no custom extractor is provided."""

agent = Agent(name="storyteller")

message = ResponseOutputMessage(
id="msg_1",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Hello world",
type="output_text",
logprobs=None,
)
],
)

result = type(
"DummyResult",
(),
{"new_items": [MessageOutputItem(agent=agent, raw_item=message)]},
)()

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
return result

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = agent.as_tool(
tool_name="story_tool",
tool_description="Tell a short story",
is_enabled=True,
)

assert isinstance(tool, FunctionTool)
tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1")
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')

assert output == "Hello world"


@pytest.mark.asyncio
async def test_agent_as_tool_custom_output_extractor(monkeypatch: pytest.MonkeyPatch) -> None:
"""Custom output extractors should receive the RunResult from Runner.run."""

agent = Agent(name="summarizer")

message = ResponseOutputMessage(
id="msg_2",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Original text",
type="output_text",
logprobs=None,
)
],
)

class DummySession(Session):
session_id = "sess_123"

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
return []

async def add_items(self, items: list[TResponseInputItem]) -> None:
return None

async def pop_item(self) -> TResponseInputItem | None:
return None

async def clear_session(self) -> None:
return None

dummy_session = DummySession()

class DummyResult:
def __init__(self, items: list[MessageOutputItem]) -> None:
self.new_items = items

run_result = DummyResult([MessageOutputItem(agent=agent, raw_item=message)])

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "summarize this"
assert context is None
assert max_turns == 7
assert hooks is hooks_obj
assert run_config is run_config_obj
assert previous_response_id == "resp_1"
assert conversation_id == "conv_1"
assert session is dummy_session
return run_result

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

async def extractor(result) -> str:
assert result is run_result
return "custom output"

hooks_obj = RunHooks[Any]()
run_config_obj = RunConfig(model="gpt-4.1-mini")

tool = agent.as_tool(
tool_name="summary_tool",
tool_description="Summarize input",
custom_output_extractor=extractor,
is_enabled=True,
run_config=run_config_obj,
max_turns=7,
hooks=hooks_obj,
previous_response_id="resp_1",
conversation_id="conv_1",
session=dummy_session,
)

assert isinstance(tool, FunctionTool)
tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2")
output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}')

assert output == "custom output"