Skip to content

Commit 14154b7

Browse files
feat(function_schema): Add support for pydantic Field annotations in tool arguments (for tools decorated with @function_schema) (#1124)
### Summary This PR allows you to use the pydantic `Field` decorator to constrain tool arguments for tools decorated with `@function_tool` (more specifically, for tools using `function_schema` to parse tool arguments). Such constrains include, e.g. limiting integers to a certain range, but in principle, this should work for [any constrains supported by the API](https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-properties). Specifically, it enables the following syntax: ```python @function_tool def my_tool(age: int = Field(..., gt=0)) -> str: ... ``` Previously, one had to create a nested pydantic `BaseModel` to achieve this functionality. Issue #1123 explains this feature request and the previous workaround. **Example:** ```python import json from pydantic import Field from agents import function_tool @function_tool def my_tool(age: int = Field(..., gt=0)) -> str: return f"The age is {age}" print(json.dumps(my_tool.params_json_schema, indent=2)) ``` **Output:** (compare to #1123) ``` { "properties": { "age": { "exclusiveMinimum": 0, "title": "Age", "type": "integer" } }, "required": [ "age" ], "title": "my_tool_args", "type": "object", "additionalProperties": false } ``` ### Test plan I added unit tests in `tests/test_function_schema.py`. ### Issue number Closes #1123. ### Checks - [x] I've added new tests (if relevant) - [ ] I've added/updated the relevant documentation - [x] I've run `make lint` and `make format` - [x] I've made sure tests pass **Note:** I am happy to add documentation for this; please point me to where I should do so:)
1 parent 1f2dc81 commit 14154b7

File tree

2 files changed

+227
-1
lines changed

2 files changed

+227
-1
lines changed

src/agents/function_schema.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from griffe import Docstring, DocstringSectionKind
1111
from pydantic import BaseModel, Field, create_model
12+
from pydantic.fields import FieldInfo
1213

1314
from .exceptions import UserError
1415
from .run_context import RunContextWrapper
@@ -319,6 +320,14 @@ def function_schema(
319320
ann,
320321
Field(..., description=field_description),
321322
)
323+
elif isinstance(default, FieldInfo):
324+
# Parameter with a default value that is a Field(...)
325+
fields[name] = (
326+
ann,
327+
FieldInfo.merge_field_infos(
328+
default, description=field_description or default.description
329+
),
330+
)
322331
else:
323332
# Parameter with a default value
324333
fields[name] = (

tests/test_function_schema.py

Lines changed: 218 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Literal
44

55
import pytest
6-
from pydantic import BaseModel, ValidationError
6+
from pydantic import BaseModel, Field, ValidationError
77
from typing_extensions import TypedDict
88

99
from agents import RunContextWrapper
@@ -451,3 +451,220 @@ def foo(x: int) -> int:
451451

452452
assert fs.name == "custom"
453453
assert fs.params_json_schema.get("title") == "custom_args"
454+
455+
456+
def test_function_with_field_required_constraints():
457+
"""Test function with required Field parameter that has constraints."""
458+
459+
def func_with_field_constraints(my_number: int = Field(..., gt=10, le=100)) -> int:
460+
return my_number * 2
461+
462+
fs = function_schema(func_with_field_constraints, use_docstring_info=False)
463+
464+
# Check that the schema includes the constraints
465+
properties = fs.params_json_schema.get("properties", {})
466+
my_number_schema = properties.get("my_number", {})
467+
assert my_number_schema.get("type") == "integer"
468+
assert my_number_schema.get("exclusiveMinimum") == 10 # gt=10
469+
assert my_number_schema.get("maximum") == 100 # le=100
470+
471+
# Valid input should work
472+
valid_input = {"my_number": 50}
473+
parsed = fs.params_pydantic_model(**valid_input)
474+
args, kwargs_dict = fs.to_call_args(parsed)
475+
result = func_with_field_constraints(*args, **kwargs_dict)
476+
assert result == 100
477+
478+
# Invalid input: too small (should violate gt=10)
479+
with pytest.raises(ValidationError):
480+
fs.params_pydantic_model(**{"my_number": 5})
481+
482+
# Invalid input: too large (should violate le=100)
483+
with pytest.raises(ValidationError):
484+
fs.params_pydantic_model(**{"my_number": 150})
485+
486+
487+
def test_function_with_field_optional_with_default():
488+
"""Test function with optional Field parameter that has default and constraints."""
489+
490+
def func_with_optional_field(
491+
required_param: str,
492+
optional_param: float = Field(default=5.0, ge=0.0),
493+
) -> str:
494+
return f"{required_param}: {optional_param}"
495+
496+
fs = function_schema(func_with_optional_field, use_docstring_info=False)
497+
498+
# Check that the schema includes the constraints and description
499+
properties = fs.params_json_schema.get("properties", {})
500+
optional_schema = properties.get("optional_param", {})
501+
assert optional_schema.get("type") == "number"
502+
assert optional_schema.get("minimum") == 0.0 # ge=0.0
503+
assert optional_schema.get("default") == 5.0
504+
505+
# Valid input with default
506+
valid_input = {"required_param": "test"}
507+
parsed = fs.params_pydantic_model(**valid_input)
508+
args, kwargs_dict = fs.to_call_args(parsed)
509+
result = func_with_optional_field(*args, **kwargs_dict)
510+
assert result == "test: 5.0"
511+
512+
# Valid input with explicit value
513+
valid_input2 = {"required_param": "test", "optional_param": 10.5}
514+
parsed2 = fs.params_pydantic_model(**valid_input2)
515+
args2, kwargs_dict2 = fs.to_call_args(parsed2)
516+
result2 = func_with_optional_field(*args2, **kwargs_dict2)
517+
assert result2 == "test: 10.5"
518+
519+
# Invalid input: negative value (should violate ge=0.0)
520+
with pytest.raises(ValidationError):
521+
fs.params_pydantic_model(**{"required_param": "test", "optional_param": -1.0})
522+
523+
524+
def test_function_with_field_description_merge():
525+
"""Test that Field descriptions are merged with docstring descriptions."""
526+
527+
def func_with_field_and_docstring(
528+
param_with_field_desc: int = Field(..., description="Field description"),
529+
param_with_both: str = Field(default="hello", description="Field description"),
530+
) -> str:
531+
"""
532+
Function with both field and docstring descriptions.
533+
534+
Args:
535+
param_with_field_desc: Docstring description
536+
param_with_both: Docstring description
537+
"""
538+
return f"{param_with_field_desc}: {param_with_both}"
539+
540+
fs = function_schema(func_with_field_and_docstring, use_docstring_info=True)
541+
542+
# Check that docstring description takes precedence when both exist
543+
properties = fs.params_json_schema.get("properties", {})
544+
param1_schema = properties.get("param_with_field_desc", {})
545+
param2_schema = properties.get("param_with_both", {})
546+
547+
# The docstring description should be used when both are present
548+
assert param1_schema.get("description") == "Docstring description"
549+
assert param2_schema.get("description") == "Docstring description"
550+
551+
552+
def func_with_field_desc_only(
553+
param_with_field_desc: int = Field(..., description="Field description only"),
554+
param_without_desc: str = Field(default="hello"),
555+
) -> str:
556+
return f"{param_with_field_desc}: {param_without_desc}"
557+
558+
559+
def test_function_with_field_description_only():
560+
"""Test that Field descriptions are used when no docstring info."""
561+
562+
fs = function_schema(func_with_field_desc_only)
563+
564+
# Check that field description is used when no docstring
565+
properties = fs.params_json_schema.get("properties", {})
566+
param1_schema = properties.get("param_with_field_desc", {})
567+
param2_schema = properties.get("param_without_desc", {})
568+
569+
assert param1_schema.get("description") == "Field description only"
570+
assert param2_schema.get("description") is None
571+
572+
573+
def test_function_with_field_string_constraints():
574+
"""Test function with Field parameter that has string-specific constraints."""
575+
576+
def func_with_string_field(
577+
name: str = Field(..., min_length=3, max_length=20, pattern=r"^[A-Za-z]+$"),
578+
) -> str:
579+
return f"Hello, {name}!"
580+
581+
fs = function_schema(func_with_string_field, use_docstring_info=False)
582+
583+
# Check that the schema includes string constraints
584+
properties = fs.params_json_schema.get("properties", {})
585+
name_schema = properties.get("name", {})
586+
assert name_schema.get("type") == "string"
587+
assert name_schema.get("minLength") == 3
588+
assert name_schema.get("maxLength") == 20
589+
assert name_schema.get("pattern") == r"^[A-Za-z]+$"
590+
591+
# Valid input
592+
valid_input = {"name": "Alice"}
593+
parsed = fs.params_pydantic_model(**valid_input)
594+
args, kwargs_dict = fs.to_call_args(parsed)
595+
result = func_with_string_field(*args, **kwargs_dict)
596+
assert result == "Hello, Alice!"
597+
598+
# Invalid input: too short
599+
with pytest.raises(ValidationError):
600+
fs.params_pydantic_model(**{"name": "Al"})
601+
602+
# Invalid input: too long
603+
with pytest.raises(ValidationError):
604+
fs.params_pydantic_model(**{"name": "A" * 25})
605+
606+
# Invalid input: doesn't match pattern (contains numbers)
607+
with pytest.raises(ValidationError):
608+
fs.params_pydantic_model(**{"name": "Alice123"})
609+
610+
611+
def test_function_with_field_multiple_constraints():
612+
"""Test function with multiple Field parameters having different constraint types."""
613+
614+
def func_with_multiple_field_constraints(
615+
score: int = Field(..., ge=0, le=100, description="Score from 0 to 100"),
616+
name: str = Field(default="Unknown", min_length=1, max_length=50),
617+
factor: float = Field(default=1.0, gt=0.0, description="Positive multiplier"),
618+
) -> str:
619+
final_score = score * factor
620+
return f"{name} scored {final_score}"
621+
622+
fs = function_schema(func_with_multiple_field_constraints, use_docstring_info=False)
623+
624+
# Check schema structure
625+
properties = fs.params_json_schema.get("properties", {})
626+
627+
# Check score field
628+
score_schema = properties.get("score", {})
629+
assert score_schema.get("type") == "integer"
630+
assert score_schema.get("minimum") == 0
631+
assert score_schema.get("maximum") == 100
632+
assert score_schema.get("description") == "Score from 0 to 100"
633+
634+
# Check name field
635+
name_schema = properties.get("name", {})
636+
assert name_schema.get("type") == "string"
637+
assert name_schema.get("minLength") == 1
638+
assert name_schema.get("maxLength") == 50
639+
assert name_schema.get("default") == "Unknown"
640+
641+
# Check factor field
642+
factor_schema = properties.get("factor", {})
643+
assert factor_schema.get("type") == "number"
644+
assert factor_schema.get("exclusiveMinimum") == 0.0
645+
assert factor_schema.get("default") == 1.0
646+
assert factor_schema.get("description") == "Positive multiplier"
647+
648+
# Valid input with defaults
649+
valid_input = {"score": 85}
650+
parsed = fs.params_pydantic_model(**valid_input)
651+
args, kwargs_dict = fs.to_call_args(parsed)
652+
result = func_with_multiple_field_constraints(*args, **kwargs_dict)
653+
assert result == "Unknown scored 85.0"
654+
655+
# Valid input with all parameters
656+
valid_input2 = {"score": 90, "name": "Alice", "factor": 1.5}
657+
parsed2 = fs.params_pydantic_model(**valid_input2)
658+
args2, kwargs_dict2 = fs.to_call_args(parsed2)
659+
result2 = func_with_multiple_field_constraints(*args2, **kwargs_dict2)
660+
assert result2 == "Alice scored 135.0"
661+
662+
# Test various validation errors
663+
with pytest.raises(ValidationError): # score too high
664+
fs.params_pydantic_model(**{"score": 150})
665+
666+
with pytest.raises(ValidationError): # empty name
667+
fs.params_pydantic_model(**{"score": 50, "name": ""})
668+
669+
with pytest.raises(ValidationError): # zero factor
670+
fs.params_pydantic_model(**{"score": 50, "factor": 0.0})

0 commit comments

Comments
 (0)