Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/basic/tools.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import asyncio
from typing import Annotated

from pydantic import BaseModel
from pydantic import BaseModel, Field

from agents import Agent, Runner, function_tool


class Weather(BaseModel):
city: str
temperature_range: str
conditions: str
city: str = Field(description="The city name")
temperature_range: str = Field(description="The temperature range in Celsius")
conditions: str = Field(description="The weather conditions")


@function_tool
def get_weather(city: str) -> Weather:
def get_weather(city: Annotated[str, "The city to get the weather for"]) -> Weather:
"""Get the current weather information for a specified city."""
print("[debug] get_weather called")
return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")


agent = Agent(
name="Hello world",
instructions="You are a helpful agent.",
Expand Down
48 changes: 45 additions & 3 deletions src/agents/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import re
from dataclasses import dataclass
from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints

from griffe import Docstring, DocstringSectionKind
from pydantic import BaseModel, Field, create_model
Expand Down Expand Up @@ -185,6 +185,31 @@ def generate_func_documentation(
)


def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]:
"""Returns the underlying annotation and any metadata from typing.Annotated."""

metadata: tuple[Any, ...] = ()
ann = annotation

while get_origin(ann) is Annotated:
args = get_args(ann)
if not args:
break
ann = args[0]
metadata = (*metadata, *args[1:])

return ann, metadata


def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
"""Extracts a human readable description from Annotated metadata if present."""

for item in metadata:
if isinstance(item, str):
return item
return None


def function_schema(
func: Callable[..., Any],
docstring_style: DocstringStyle | None = None,
Expand Down Expand Up @@ -219,17 +244,34 @@ def function_schema(
# 1. Grab docstring info
if use_docstring_info:
doc_info = generate_func_documentation(func, docstring_style)
param_descs = doc_info.param_descriptions or {}
param_descs = dict(doc_info.param_descriptions or {})
else:
doc_info = None
param_descs = {}

type_hints_with_extras = get_type_hints(func, include_extras=True)
type_hints: dict[str, Any] = {}
annotated_param_descs: dict[str, str] = {}

for name, annotation in type_hints_with_extras.items():
if name == "return":
continue

stripped_ann, metadata = _strip_annotated(annotation)
type_hints[name] = stripped_ann

description = _extract_description_from_metadata(metadata)
if description is not None:
annotated_param_descs[name] = description

for name, description in annotated_param_descs.items():
param_descs.setdefault(name, description)

# Ensure name_override takes precedence even if docstring info is disabled.
func_name = name_override or (doc_info.name if doc_info else func.__name__)

# 2. Inspect function signature and get type hints
sig = inspect.signature(func)
type_hints = get_type_hints(func)
params = list(sig.parameters.items())
takes_context = False
filtered_params = []
Expand Down
40 changes: 39 additions & 1 deletion tests/test_function_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Literal
from typing import Annotated, Any, Literal

import pytest
from pydantic import BaseModel, Field, ValidationError
Expand Down Expand Up @@ -521,6 +521,44 @@ def func_with_optional_field(
fs.params_pydantic_model(**{"required_param": "test", "optional_param": -1.0})


def test_function_uses_annotated_descriptions_without_docstring() -> None:
"""Test that Annotated metadata populates parameter descriptions when docstrings are ignored."""

def add(
a: Annotated[int, "First number to add"],
b: Annotated[int, "Second number to add"],
) -> int:
return a + b

fs = function_schema(add, use_docstring_info=False)

properties = fs.params_json_schema.get("properties", {})
assert properties["a"].get("description") == "First number to add"
assert properties["b"].get("description") == "Second number to add"


def test_function_prefers_docstring_descriptions_over_annotated_metadata() -> None:
"""Test that docstring parameter descriptions take precedence over Annotated metadata."""

def add(
a: Annotated[int, "Annotated description for a"],
b: Annotated[int, "Annotated description for b"],
) -> int:
"""Adds two integers.

Args:
a: Docstring provided description.
"""

return a + b

fs = function_schema(add)

properties = fs.params_json_schema.get("properties", {})
assert properties["a"].get("description") == "Docstring provided description."
assert properties["b"].get("description") == "Annotated description for b"


def test_function_with_field_description_merge():
"""Test that Field descriptions are merged with docstring descriptions."""

Expand Down