diff --git a/src/agents/agent.py b/src/agents/agent.py index b64a6ea1d..a061926b1 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -30,9 +30,11 @@ from .util._types import MaybeAwaitable if TYPE_CHECKING: - from .lifecycle import AgentHooks + from .lifecycle import AgentHooks, RunHooks from .mcp import MCPServer + from .memory.session import Session from .result import RunResult + from .run import RunConfig @dataclass @@ -384,6 +386,12 @@ def as_tool( custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True, + run_config: RunConfig | None = None, + max_turns: int | None = None, + hooks: RunHooks[TContext] | None = None, + previous_response_id: str | None = None, + conversation_id: str | None = None, + session: Session | None = None, ) -> Tool: """Transform this agent into a tool, callable by other agents. @@ -410,12 +418,20 @@ def as_tool( is_enabled=is_enabled, ) async def run_agent(context: RunContextWrapper, input: str) -> str: - from .run import Runner + from .run import DEFAULT_MAX_TURNS, Runner + + resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS output = await Runner.run( starting_agent=self, input=input, context=context.context, + run_config=run_config, + max_turns=resolved_max_turns, + hooks=hooks, + previous_response_id=previous_response_id, + conversation_id=conversation_id, + session=session, ) if custom_output_extractor: return await custom_output_extractor(output) diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 3307c7a1a..813f72c28 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -1,7 +1,24 @@ +from __future__ import annotations + +from typing import Any + import pytest +from openai.types.responses import ResponseOutputMessage, ResponseOutputText from pydantic import BaseModel -from agents import Agent, AgentBase, FunctionTool, RunContextWrapper +from agents import ( + Agent, + AgentBase, + FunctionTool, + MessageOutputItem, + RunConfig, + RunContextWrapper, + RunHooks, + Runner, + Session, + TResponseInputItem, +) +from agents.tool_context import ToolContext class BoolCtx(BaseModel): @@ -205,3 +222,159 @@ async def custom_extractor(result): tools = await orchestrator.get_all_tools(context) assert len(tools) == 1 assert tools[0].name == "custom_tool_name" + + +@pytest.mark.asyncio +async def test_agent_as_tool_returns_concatenated_text(monkeypatch: pytest.MonkeyPatch) -> None: + """Agent tool should use default text aggregation when no custom extractor is provided.""" + + agent = Agent(name="storyteller") + + message = ResponseOutputMessage( + id="msg_1", + role="assistant", + status="completed", + type="message", + content=[ + ResponseOutputText( + annotations=[], + text="Hello world", + type="output_text", + logprobs=None, + ) + ], + ) + + result = type( + "DummyResult", + (), + {"new_items": [MessageOutputItem(agent=agent, raw_item=message)]}, + )() + + async def fake_run( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + conversation_id, + session, + ): + assert starting_agent is agent + assert input == "hello" + return result + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + + tool = agent.as_tool( + tool_name="story_tool", + tool_description="Tell a short story", + is_enabled=True, + ) + + assert isinstance(tool, FunctionTool) + tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1") + output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}') + + assert output == "Hello world" + + +@pytest.mark.asyncio +async def test_agent_as_tool_custom_output_extractor(monkeypatch: pytest.MonkeyPatch) -> None: + """Custom output extractors should receive the RunResult from Runner.run.""" + + agent = Agent(name="summarizer") + + message = ResponseOutputMessage( + id="msg_2", + role="assistant", + status="completed", + type="message", + content=[ + ResponseOutputText( + annotations=[], + text="Original text", + type="output_text", + logprobs=None, + ) + ], + ) + + class DummySession(Session): + session_id = "sess_123" + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + return [] + + async def add_items(self, items: list[TResponseInputItem]) -> None: + return None + + async def pop_item(self) -> TResponseInputItem | None: + return None + + async def clear_session(self) -> None: + return None + + dummy_session = DummySession() + + class DummyResult: + def __init__(self, items: list[MessageOutputItem]) -> None: + self.new_items = items + + run_result = DummyResult([MessageOutputItem(agent=agent, raw_item=message)]) + + async def fake_run( + cls, + starting_agent, + input, + *, + context, + max_turns, + hooks, + run_config, + previous_response_id, + conversation_id, + session, + ): + assert starting_agent is agent + assert input == "summarize this" + assert context is None + assert max_turns == 7 + assert hooks is hooks_obj + assert run_config is run_config_obj + assert previous_response_id == "resp_1" + assert conversation_id == "conv_1" + assert session is dummy_session + return run_result + + monkeypatch.setattr(Runner, "run", classmethod(fake_run)) + + async def extractor(result) -> str: + assert result is run_result + return "custom output" + + hooks_obj = RunHooks[Any]() + run_config_obj = RunConfig(model="gpt-4.1-mini") + + tool = agent.as_tool( + tool_name="summary_tool", + tool_description="Summarize input", + custom_output_extractor=extractor, + is_enabled=True, + run_config=run_config_obj, + max_turns=7, + hooks=hooks_obj, + previous_response_id="resp_1", + conversation_id="conv_1", + session=dummy_session, + ) + + assert isinstance(tool, FunctionTool) + tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2") + output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}') + + assert output == "custom output"