Skip to content

Commit a04d25f

Browse files
authored
Add content (e.g. files) returned by tool to FunctionToolResultEvent (#3082)
1 parent 2e4032e commit a04d25f

File tree

5 files changed

+73
-21
lines changed

5 files changed

+73
-21
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -913,13 +913,19 @@ async def _call_tools(
913913

914914
async def handle_call_or_result(
915915
coro_or_task: Awaitable[
916-
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]
916+
tuple[
917+
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
918+
]
917919
]
918-
| Task[tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]],
920+
| Task[
921+
tuple[
922+
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
923+
]
924+
],
919925
index: int,
920926
) -> _messages.HandleResponseEvent | None:
921927
try:
922-
tool_part, tool_user_part = (
928+
tool_part, tool_user_content = (
923929
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
924930
)
925931
except exceptions.CallDeferred:
@@ -928,10 +934,10 @@ async def handle_call_or_result(
928934
deferred_calls_by_index[index] = 'unapproved'
929935
else:
930936
tool_parts_by_index[index] = tool_part
931-
if tool_user_part:
932-
user_parts_by_index[index] = tool_user_part
937+
if tool_user_content:
938+
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)
933939

934-
return _messages.FunctionToolResultEvent(tool_part)
940+
return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)
935941

936942
if tool_manager.should_call_sequentially(tool_calls):
937943
for index, call in enumerate(tool_calls):
@@ -971,7 +977,7 @@ async def _call_tool(
971977
tool_manager: ToolManager[DepsT],
972978
tool_call: _messages.ToolCallPart,
973979
tool_call_result: DeferredToolResult | None,
974-
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]:
980+
) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None]:
975981
try:
976982
if tool_call_result is None:
977983
tool_result = await tool_manager.handle_call(tool_call)
@@ -1048,14 +1054,7 @@ async def _call_tool(
10481054
metadata=tool_return.metadata,
10491055
)
10501056

1051-
user_part: _messages.UserPromptPart | None = None
1052-
if tool_return.content:
1053-
user_part = _messages.UserPromptPart(
1054-
content=tool_return.content,
1055-
part_kind='user-prompt',
1056-
)
1057-
1058-
return return_part, user_part
1057+
return return_part, tool_return.content or None
10591058

10601059

10611060
@dataclasses.dataclass

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,9 @@ class FunctionToolResultEvent:
16541654

16551655
_: KW_ONLY
16561656

1657+
content: str | Sequence[UserContent] | None = None
1658+
"""The content that will be sent to the model as a UserPromptPart following the result."""
1659+
16571660
event_kind: Literal['function_tool_result'] = 'function_tool_result'
16581661
"""Event type identifier, used as a discriminator."""
16591662

tests/test_dbos.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
359359
BasicSpan(content='ctx.run_step=1'),
360360
BasicSpan(
361361
content=IsStr(
362-
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
362+
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
363363
)
364364
),
365365
],
@@ -374,7 +374,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
374374
BasicSpan(content='ctx.run_step=1'),
375375
BasicSpan(
376376
content=IsStr(
377-
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
377+
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
378378
)
379379
),
380380
],
@@ -435,7 +435,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
435435
BasicSpan(content='ctx.run_step=2'),
436436
BasicSpan(
437437
content=IsStr(
438-
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
438+
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
439439
)
440440
),
441441
],

tests/test_streaming.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FinalResultEvent,
2020
FunctionToolCallEvent,
2121
FunctionToolResultEvent,
22+
ImageUrl,
2223
ModelMessage,
2324
ModelRequest,
2425
ModelResponse,
@@ -1575,3 +1576,52 @@ async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[Agen
15751576
FinalResultEvent(tool_name=None, tool_call_id=None),
15761577
]
15771578
)
1579+
1580+
1581+
async def test_stream_tool_returning_user_content():
1582+
m = TestModel()
1583+
1584+
agent = Agent(m)
1585+
assert agent.name is None
1586+
1587+
@agent.tool_plain
1588+
async def get_image() -> ImageUrl:
1589+
return ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg')
1590+
1591+
events: list[AgentStreamEvent] = []
1592+
1593+
async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]):
1594+
async for event in stream:
1595+
events.append(event)
1596+
1597+
await agent.run('Hello', event_stream_handler=event_stream_handler)
1598+
1599+
assert events == snapshot(
1600+
[
1601+
PartStartEvent(
1602+
index=0,
1603+
part=ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()),
1604+
),
1605+
FunctionToolCallEvent(part=ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr())),
1606+
FunctionToolResultEvent(
1607+
result=ToolReturnPart(
1608+
tool_name='get_image',
1609+
content='See file bd38f5',
1610+
tool_call_id=IsStr(),
1611+
timestamp=IsNow(tz=timezone.utc),
1612+
),
1613+
content=[
1614+
'This is file bd38f5:',
1615+
ImageUrl(
1616+
url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg',
1617+
identifier='bd38f5',
1618+
),
1619+
],
1620+
),
1621+
PartStartEvent(index=0, part=TextPart(content='')),
1622+
FinalResultEvent(tool_name=None, tool_call_id=None),
1623+
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='{"get_image":"See ')),
1624+
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='file ')),
1625+
PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='bd38f5"}')),
1626+
]
1627+
)

tests/test_temporal.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ async def test_complex_agent_run_in_workflow(
424424
BasicSpan(content='ctx.run_step=1'),
425425
BasicSpan(
426426
content=IsStr(
427-
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
427+
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
428428
)
429429
),
430430
],
@@ -453,7 +453,7 @@ async def test_complex_agent_run_in_workflow(
453453
BasicSpan(content='ctx.run_step=1'),
454454
BasicSpan(
455455
content=IsStr(
456-
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
456+
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
457457
)
458458
),
459459
],
@@ -544,7 +544,7 @@ async def test_complex_agent_run_in_workflow(
544544
BasicSpan(content='ctx.run_step=2'),
545545
BasicSpan(
546546
content=IsStr(
547-
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"event_kind":"function_tool_result"}'
547+
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
548548
)
549549
),
550550
],

0 commit comments

Comments
 (0)