Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/basic/tools.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import asyncio
from typing import Annotated

from pydantic import BaseModel
from pydantic import BaseModel, Field

from agents import Agent, Runner, function_tool


class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
city: str = Field(description="The city name")
temperature_range: str = Field(description="The temperature range in Celsius")
conditions: str = Field(description="The weather conditions")


@function_tool
def get_weather(city: str) -> Weather:
def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weather:
"""Get the current weather information for a specified city."""
print("[debug] get_weather called")
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")


agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
Expand Down
20 changes: 18 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
from .util._types import MaybeAwaitable

if TYPE_CHECKING:
from .lifecycle import AgentHooks
from .lifecycle import AgentHooks, RunHooks
from .mcp import MCPServer
from .memory.session import Session
from .result import RunResult
from .run import RunConfig


@dataclass
Expand Down Expand Up @@ -384,6 +386,12 @@ def as_tool(
custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
is_enabled: bool
| Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
run_config: RunConfig | None = None,
max_turns: int | None = None,
hooks: RunHooks[TContext] | None = None,
previous_response_id: str | None = None,
conversation_id: str | None = None,
session: Session | None = None,
) -> Tool:
"""Transform this agent into a tool, callable by other agents.

Expand All @@ -410,12 +418,20 @@ def as_tool(
is_enabled=is_enabled,
)
async def run_agent(context: RunContextWrapper, input: str) -> str:
from .run import Runner
from .run import DEFAULT_MAX_TURNS, Runner

resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS

output = await Runner.run(
starting_agent=self,
input=input,
context=context.context,
run_config=run_config,
max_turns=resolved_max_turns,
hooks=hooks,
previous_response_id=previous_response_id,
conversation_id=conversation_id,
session=session,
)
if custom_output_extractor:
return await custom_output_extractor(output)
Expand Down
48 changes: 45 additions & 3 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import re
from dataclasses import dataclass
from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints

from griffe import Docstring, DocstringSectionKind
from pydantic import BaseModel, Field, create_model
Expand Down Expand Up @@ -185,6 +185,31 @@ def generate_func_documentation(
)


def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]:
"""Returns the underlying annotation and any metadata from typing.Annotated."""

metadata: tuple[Any, ...] = ()
ann = annotation

while get_origin(ann) is Annotated:
args = get_args(ann)
if not args:
break
ann = args[0]
metadata = (*metadata, *args[1:])

return ann, metadata


def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
"""Extracts a human readable description from Annotated metadata if present."""

for item in metadata:
if isinstance(item, str):
return item
return None


def function_schema(
func: Callable[..., Any],
docstring_style: DocstringStyle | None = None,
Expand Down Expand Up @@ -219,17 +244,34 @@ def function_schema(
# 1. Grab docstring info
if use_docstring_info:
doc_info = generate_func_documentation(func, docstring_style)
param_descs = doc_info.param_descriptions or {}
param_descs = dict(doc_info.param_descriptions or {})
else:
doc_info = None
param_descs = {}

type_hints_with_extras = get_type_hints(func, include_extras=True)
type_hints: dict[str, Any] = {}
annotated_param_descs: dict[str, str] = {}

for name, annotation in type_hints_with_extras.items():
if name == "return":
continue

stripped_ann, metadata = _strip_annotated(annotation)
type_hints[name] = stripped_ann

description = _extract_description_from_metadata(metadata)
if description is not None:
annotated_param_descs[name] = description

for name, description in annotated_param_descs.items():
param_descs.setdefault(name, description)

# Ensure name_override takes precedence even if docstring info is disabled.
func_name = name_override or (doc_info.name if doc_info else func.__name__)

# 2. Inspect function signature and get type hints
sig = inspect.signature(func)
type_hints = get_type_hints(func)
params = list(sig.parameters.items())
takes_context = False
filtered_params = []
Expand Down
175 changes: 174 additions & 1 deletion tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,24 @@
from __future__ import annotations

from typing import Any

import pytest
from openai.types.responses import ResponseOutputMessage, ResponseOutputText
from pydantic import BaseModel

from agents import Agent, AgentBase, FunctionTool, RunContextWrapper
from agents import (
Agent,
AgentBase,
FunctionTool,
MessageOutputItem,
RunConfig,
RunContextWrapper,
RunHooks,
Runner,
Session,
TResponseInputItem,
)
from agents.tool_context import ToolContext


class BoolCtx(BaseModel):
Expand Down Expand Up @@ -205,3 +222,159 @@ async def custom_extractor(result):
tools = await orchestrator.get_all_tools(context)
assert len(tools) == 1
assert tools[0].name == "custom_tool_name"


@pytest.mark.asyncio
async def test_agent_as_tool_returns_concatenated_text(monkeypatch: pytest.MonkeyPatch) -> None:
"""Agent tool should use default text aggregation when no custom extractor is provided."""

agent = Agent(name="storyteller")

message = ResponseOutputMessage(
id="msg_1",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Hello world",
type="output_text",
logprobs=None,
)
],
)

result = type(
"DummyResult",
(),
{"new_items": [MessageOutputItem(agent=agent, raw_item=message)]},
)()

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "hello"
return result

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = agent.as_tool(
tool_name="story_tool",
tool_description="Tell a short story",
is_enabled=True,
)

assert isinstance(tool, FunctionTool)
tool_context = ToolContext(context=None, tool_name="story_tool", tool_call_id="call_1")
output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')

assert output == "Hello world"


@pytest.mark.asyncio
async def test_agent_as_tool_custom_output_extractor(monkeypatch: pytest.MonkeyPatch) -> None:
"""Custom output extractors should receive the RunResult from Runner.run."""

agent = Agent(name="summarizer")

message = ResponseOutputMessage(
id="msg_2",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Original text",
type="output_text",
logprobs=None,
)
],
)

class DummySession(Session):
session_id = "sess_123"

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
return []

async def add_items(self, items: list[TResponseInputItem]) -> None:
return None

async def pop_item(self) -> TResponseInputItem | None:
return None

async def clear_session(self) -> None:
return None

dummy_session = DummySession()

class DummyResult:
def __init__(self, items: list[MessageOutputItem]) -> None:
self.new_items = items

run_result = DummyResult([MessageOutputItem(agent=agent, raw_item=message)])

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
assert starting_agent is agent
assert input == "summarize this"
assert context is None
assert max_turns == 7
assert hooks is hooks_obj
assert run_config is run_config_obj
assert previous_response_id == "resp_1"
assert conversation_id == "conv_1"
assert session is dummy_session
return run_result

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

async def extractor(result) -> str:
assert result is run_result
return "custom output"

hooks_obj = RunHooks[Any]()
run_config_obj = RunConfig(model="gpt-4.1-mini")

tool = agent.as_tool(
tool_name="summary_tool",
tool_description="Summarize input",
custom_output_extractor=extractor,
is_enabled=True,
run_config=run_config_obj,
max_turns=7,
hooks=hooks_obj,
previous_response_id="resp_1",
conversation_id="conv_1",
session=dummy_session,
)

assert isinstance(tool, FunctionTool)
tool_context = ToolContext(context=None, tool_name="summary_tool", tool_call_id="call_2")
output = await tool.on_invoke_tool(tool_context, '{"input": "summarize this"}')

assert output == "custom output"
40 changes: 39 additions & 1 deletion tests/test_function_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Literal
from typing import Annotated, Any, Literal

import pytest
from pydantic import BaseModel, Field, ValidationError
Expand Down Expand Up @@ -521,6 +521,44 @@ def func_with_optional_field(
fs.params_pydantic_model(**{"required_param": "test", "optional_param": -1.0})


def test_function_uses_annotated_descriptions_without_docstring() -> None:
"""Test that Annotated metadata populates parameter descriptions when docstrings are ignored."""

def add(
a: Annotated[int, "First number to add"],
b: Annotated[int, "Second number to add"],
) -> int:
return a + b

fs = function_schema(add, use_docstring_info=False)

properties = fs.params_json_schema.get("properties", {})
assert properties["a"].get("description") == "First number to add"
assert properties["b"].get("description") == "Second number to add"


def test_function_prefers_docstring_descriptions_over_annotated_metadata() -> None:
"""Test that docstring parameter descriptions take precedence over Annotated metadata."""

def add(
a: Annotated[int, "Annotated description for a"],
b: Annotated[int, "Annotated description for b"],
) -> int:
"""Adds two integers.

Args:
a: Docstring provided description.
"""

return a + b

fs = function_schema(add)

properties = fs.params_json_schema.get("properties", {})
assert properties["a"].get("description") == "Docstring provided description."
assert properties["b"].get("description") == "Annotated description for b"


def test_function_with_field_description_merge():
"""Test that Field descriptions are merged with docstring descriptions."""

Expand Down
Loading