Skip to content

Commit 1a72ca5

Browse files
authored
allow tools to return any (#136)
1 parent 3f0234f commit 1a72ca5

File tree

5 files changed

+73
-39
lines changed

5 files changed

+73
-39
lines changed

docs/agents.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ agent.run_sync('hello', model=FunctionModel(print_schema))
489489

490490
_(This example is complete, it can be run "as is")_
491491

492-
The return type of tool can be any valid JSON object ([`JsonData`][pydantic_ai.tools.JsonData]) as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
492+
The return type of tool can be anything which Pydantic can serialize to JSON as some models (e.g. Gemini) support semi-structured return values, some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON.
493493

494494
If a tool has a single parameter that can be represented as an object in JSON schema (e.g. dataclass, TypedDict, pydantic model), the schema for the tool is simplified to be just that object. (TODO example)
495495

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from __future__ import annotations as _annotations
22

33
import json
4-
from collections.abc import Mapping, Sequence
54
from dataclasses import dataclass, field
65
from datetime import datetime
7-
from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
6+
from typing import Annotated, Any, Literal, Union
87

98
import pydantic
109
import pydantic_core
1110
from pydantic import TypeAdapter
12-
from typing_extensions import TypeAlias, TypeAliasType
1311

1412
from . import _pydantic
1513
from ._utils import now_utc as _now_utc
@@ -44,13 +42,7 @@ class UserPrompt:
4442
"""Message type identifier, this type is available on all message as a discriminator."""
4543

4644

47-
JsonData: TypeAlias = 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]'
48-
if not TYPE_CHECKING:
49-
# work around for https://github.com/pydantic/pydantic/issues/10873
50-
# this is need for pydantic to work both `json_ta` and `MessagesTypeAdapter` at the bottom of this file
51-
JsonData = TypeAliasType('JsonData', 'Union[str, int, float, None, Sequence[JsonData], Mapping[str, JsonData]]')
52-
53-
json_ta: TypeAdapter[JsonData] = TypeAdapter(JsonData)
45+
tool_return_ta: TypeAdapter[Any] = TypeAdapter(Any)
5446

5547

5648
@dataclass
@@ -59,7 +51,7 @@ class ToolReturn:
5951

6052
tool_name: str
6153
"""The name of the "tool" was called."""
62-
content: JsonData
54+
content: Any
6355
"""The return value."""
6456
tool_id: str | None = None
6557
"""Optional tool identifier, this is used by some models including OpenAI."""
@@ -72,15 +64,14 @@ def model_response_str(self) -> str:
7264
if isinstance(self.content, str):
7365
return self.content
7466
else:
75-
content = json_ta.validate_python(self.content)
76-
return json_ta.dump_json(content).decode()
67+
return tool_return_ta.dump_json(self.content).decode()
7768

78-
def model_response_object(self) -> dict[str, JsonData]:
69+
def model_response_object(self) -> dict[str, Any]:
7970
# gemini supports JSON dict return values, but no other JSON types, hence we wrap anything else in a dict
8071
if isinstance(self.content, dict):
81-
return json_ta.validate_python(self.content) # pyright: ignore[reportReturnType]
72+
return tool_return_ta.dump_python(self.content, mode='json') # pyright: ignore[reportUnknownMemberType]
8273
else:
83-
return {'return_value': json_ta.validate_python(self.content)}
74+
return {'return_value': tool_return_ta.dump_python(self.content, mode='json')}
8475

8576

8677
@dataclass

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations as _annotations
22

33
import inspect
4-
from collections.abc import Awaitable, Mapping, Sequence
4+
from collections.abc import Awaitable
55
from dataclasses import dataclass, field
66
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
77

88
from pydantic import ValidationError
99
from pydantic_core import SchemaValidator
10-
from typing_extensions import Concatenate, ParamSpec, TypeAlias, final
10+
from typing_extensions import Concatenate, ParamSpec, final
1111

1212
from . import _pydantic, _utils, messages
1313
from .exceptions import ModelRetry, UnexpectedModelBehavior
@@ -23,12 +23,10 @@
2323
'RunContext',
2424
'ResultValidatorFunc',
2525
'SystemPromptFunc',
26-
'ToolReturnValue',
2726
'ToolFuncContext',
2827
'ToolFuncPlain',
2928
'ToolFuncEither',
3029
'ToolParams',
31-
'JsonData',
3230
'Tool',
3331
)
3432

@@ -75,17 +73,12 @@ class RunContext(Generic[AgentDeps]):
7573
Usage `ResultValidator[AgentDeps, ResultData]`.
7674
"""
7775

78-
JsonData: TypeAlias = 'None | str | int | float | Sequence[JsonData] | Mapping[str, JsonData]'
79-
"""Type representing any JSON data."""
80-
81-
ToolReturnValue = Union[JsonData, Awaitable[JsonData]]
82-
"""Return value of a tool function."""
83-
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], ToolReturnValue]
76+
ToolFuncContext = Callable[Concatenate[RunContext[AgentDeps], ToolParams], Any]
8477
"""A tool function that takes `RunContext` as the first argument.
8578
8679
Usage `ToolContextFunc[AgentDeps, ToolParams]`.
8780
"""
88-
ToolFuncPlain = Callable[ToolParams, ToolReturnValue]
81+
ToolFuncPlain = Callable[ToolParams, Any]
8982
"""A tool function that does not take `RunContext` as the first argument.
9083
9184
Usage `ToolPlainFunc[ToolParams]`.
@@ -146,8 +139,8 @@ async def my_tool(ctx: RunContext[int], x: int, y: int) -> str:
146139
function: The Python function to call as the tool.
147140
takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] first argument.
148141
max_retries: Maximum number of retries allowed for this tool, set to the agent default if `None`.
149-
name: Name of the tool, inferred from the function if left blank.
150-
description: Description of the tool, inferred from the function if left blank.
142+
name: Name of the tool, inferred from the function if `None`.
143+
description: Description of the tool, inferred from the function if `None`.
151144
"""
152145
f = _pydantic.function_schema(function, takes_ctx)
153146
self.function = function

tests/test_tools.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from inline_snapshot import snapshot
66
from pydantic import BaseModel, Field
7+
from pydantic_core import PydanticSerializationError
78

89
from pydantic_ai import Agent, RunContext, Tool, UserError
910
from pydantic_ai.messages import Message, ModelAnyResponse, ModelTextResponse
@@ -209,12 +210,14 @@ def test_docstring_google_no_body(set_event_loop: None):
209210
)
210211

211212

213+
class Foo(BaseModel):
214+
x: int
215+
y: str
216+
217+
212218
def test_takes_just_model(set_event_loop: None):
213219
agent = Agent()
214220

215-
class Foo(BaseModel):
216-
x: int
217-
y: str
218221

219222
@agent.tool_plain
220223
def takes_just_model(model: Foo) -> str:
@@ -343,3 +346,50 @@ def plain_tool(x: int) -> int:
343346
def test_init_plain_tool_invalid():
344347
with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'):
345348
Tool(ctx_tool, False)
349+
350+
351+
def test_return_pydantic_model(set_event_loop: None):
352+
agent = Agent('test')
353+
354+
@agent.tool_plain
355+
def return_pydantic_model(x: int) -> Foo:
356+
return Foo(x=x, y='a')
357+
358+
result = agent.run_sync('')
359+
assert result.data == snapshot('{"return_pydantic_model":{"x":0,"y":"a"}}')
360+
361+
362+
def test_return_bytes(set_event_loop: None):
363+
agent = Agent('test')
364+
365+
@agent.tool_plain
366+
def return_pydantic_model() -> bytes:
367+
return '🐈 Hello'.encode()
368+
369+
result = agent.run_sync('')
370+
assert result.data == snapshot('{"return_pydantic_model":"🐈 Hello"}')
371+
372+
373+
def test_return_bytes_invalid(set_event_loop: None):
374+
agent = Agent('test')
375+
376+
@agent.tool_plain
377+
def return_pydantic_model() -> bytes:
378+
return b'\00 \x81'
379+
380+
with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'):
381+
agent.run_sync('')
382+
383+
384+
def test_return_unknown(set_event_loop: None):
385+
agent = Agent('test')
386+
387+
class Foobar:
388+
pass
389+
390+
@agent.tool_plain
391+
def return_pydantic_model() -> Foobar:
392+
return Foobar()
393+
394+
with pytest.raises(PydanticSerializationError, match='Unable to serialize unknown type:'):
395+
agent.run_sync('')

tests/typed_agent.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,15 @@ def ok_tool_plain(x: str) -> dict[str, str]:
6161

6262

6363
@typed_agent.tool_plain
64-
def ok_json_list(x: str) -> list[Union[str, int]]:
64+
async def ok_json_list(x: str) -> list[Union[str, int]]:
6565
return [x, 1]
6666

6767

68+
@typed_agent.tool
69+
async def ok_ctx(ctx: RunContext[MyDeps], x: str) -> list[int | str]:
70+
return [ctx.deps.foo, ctx.deps.bar, x]
71+
72+
6873
@typed_agent.tool
6974
async def bad_tool1(ctx: RunContext[MyDeps], x: str) -> str:
7075
total = ctx.deps.foo + ctx.deps.spam # type: ignore[attr-defined]
@@ -76,11 +81,6 @@ async def bad_tool2(ctx: RunContext[int], x: str) -> str:
7681
return f'{x} {ctx.deps}'
7782

7883

79-
@typed_agent.tool_plain # type: ignore[arg-type]
80-
async def bad_tool_return(x: int) -> list[MyDeps]:
81-
return [MyDeps(1, x)]
82-
83-
8484
with expect_error(ValueError):
8585

8686
@typed_agent.tool # type: ignore[arg-type]

0 commit comments

Comments
 (0)