
Commit 3e1f634

Fix parallel tool calling with tools returning ToolReturn with content (#2365)
1 parent: c7a3591

File tree: 3 files changed, +123 -23 lines

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 21 additions & 19 deletions

@@ -659,11 +659,11 @@ async def process_function_tools(  # noqa: C901
     for call in calls_to_run:
         yield _messages.FunctionToolCallEvent(call)

-    user_parts: list[_messages.UserPromptPart] = []
+    user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list)

     if calls_to_run:
         # Run all tool tasks in parallel
-        parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {}
+        tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
         with ctx.deps.tracer.start_as_current_span(
             'running tools',
             attributes={

@@ -681,15 +681,16 @@ async def process_function_tools(  # noqa: C901
             done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
             for task in done:
                 index = tasks.index(task)
-                tool_result_part, extra_parts = task.result()
-                yield _messages.FunctionToolResultEvent(tool_result_part)
+                tool_part, tool_user_parts = task.result()
+                yield _messages.FunctionToolResultEvent(tool_part)

-                parts_by_index[index] = [tool_result_part, *extra_parts]
+                tool_parts_by_index[index] = tool_part
+                user_parts_by_index[index] = tool_user_parts

         # We append the results at the end, rather than as they are received, to retain a consistent ordering
         # This is mostly just to simplify testing
-        for k in sorted(parts_by_index):
-            output_parts.extend(parts_by_index[k])
+        for k in sorted(tool_parts_by_index):
+            output_parts.append(tool_parts_by_index[k])

     # Finally, we handle deferred tool calls
     for call in tool_calls_by_kind['deferred']:

@@ -704,7 +705,8 @@ async def process_function_tools(  # noqa: C901
         else:
             yield _messages.FunctionToolCallEvent(call)

-    output_parts.extend(user_parts)
+    for k in sorted(user_parts_by_index):
+        output_parts.extend(user_parts_by_index[k])

     if final_result:
         output_final_result.append(final_result)

@@ -713,18 +715,18 @@ async def process_function_tools(  # noqa: C901
 async def _call_function_tool(
     tool_manager: ToolManager[DepsT],
     tool_call: _messages.ToolCallPart,
-) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]:
+) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]:
     try:
         tool_result = await tool_manager.handle_call(tool_call)
     except ToolRetryError as e:
         return (e.tool_retry, [])

-    part = _messages.ToolReturnPart(
+    tool_part = _messages.ToolReturnPart(
         tool_name=tool_call.tool_name,
         content=tool_result,
         tool_call_id=tool_call.tool_call_id,
     )
-    extra_parts: list[_messages.ModelRequestPart] = []
+    user_parts: list[_messages.UserPromptPart] = []

     if isinstance(tool_result, _messages.ToolReturn):
         if (

@@ -740,12 +742,12 @@ async def _call_function_tool(
                 f'Please use `content` instead.'
             )

-        part.content = tool_result.return_value  # type: ignore
-        part.metadata = tool_result.metadata
+        tool_part.content = tool_result.return_value  # type: ignore
+        tool_part.metadata = tool_result.metadata
         if tool_result.content:
-            extra_parts.append(
+            user_parts.append(
                 _messages.UserPromptPart(
-                    content=list(tool_result.content),
+                    content=tool_result.content,
                     part_kind='user-prompt',
                 )
             )

@@ -763,7 +765,7 @@ def process_content(content: Any) -> Any:
                 else:
                     identifier = multi_modal_content_identifier(content.url)

-                extra_parts.append(
+                user_parts.append(
                     _messages.UserPromptPart(
                         content=[f'This is file {identifier}:', content],
                         part_kind='user-prompt',

@@ -775,11 +777,11 @@ def process_content(content: Any) -> Any:

         if isinstance(tool_result, list):
             contents = cast(list[Any], tool_result)
-            part.content = [process_content(content) for content in contents]
+            tool_part.content = [process_content(content) for content in contents]
         else:
-            part.content = process_content(tool_result)
+            tool_part.content = process_content(tool_result)

-    return (part, extra_parts)
+    return (tool_part, user_parts)


 @dataclasses.dataclass
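The core of the change is how results from parallel tool tasks are reassembled: each task's ToolReturnPart and any extra UserPromptParts (produced by ToolReturn.content or multi-modal content) are now collected in separate dicts keyed by the original call index, and all tool return parts are appended before any user-prompt parts, each group in call order. Below is a standalone sketch of that pattern, independent of pydantic_ai's internals; the call_tool coroutine and the string stand-ins for message parts are illustrative only.

```python
import asyncio
from collections import defaultdict


async def call_tool(name: str) -> tuple[str, list[str]]:
    """Illustrative stand-in for a tool call returning a result plus extra user content."""
    await asyncio.sleep(0)  # completion order is not guaranteed
    return f'{name} -> return value', [f'{name} -> extra user content']


async def run_tools_in_call_order(calls: list[str]) -> list[str]:
    tasks = [asyncio.create_task(call_tool(name)) for name in calls]

    tool_parts_by_index: dict[int, str] = {}
    user_parts_by_index: dict[int, list[str]] = defaultdict(list)

    pending = set(tasks)
    while pending:
        done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            # Key results by the original call position, not by completion order.
            index = tasks.index(task)
            tool_part, user_parts = task.result()
            tool_parts_by_index[index] = tool_part
            user_parts_by_index[index] = user_parts

    output: list[str] = []
    # Tool return parts first, in call order ...
    for k in sorted(tool_parts_by_index):
        output.append(tool_parts_by_index[k])
    # ... then the user-prompt parts, also in call order.
    for k in sorted(user_parts_by_index):
        output.extend(user_parts_by_index[k])
    return output


print(asyncio.run(run_tools_in_call_order(['apple', 'banana'])))
```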

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 2 additions & 2 deletions

@@ -412,8 +412,8 @@ class ToolReturn:
     return_value: Any
     """The return value to be used in the tool response."""

-    content: Sequence[UserContent] | None = None
-    """The content sequence to be sent to the model as a UserPromptPart."""
+    content: str | Sequence[UserContent] | None = None
+    """The content to be sent to the model as a UserPromptPart."""

     metadata: Any = None
     """Additional data that can be accessed programmatically by the application but is not sent to the LLM."""

tests/test_tools.py

Lines changed: 100 additions & 2 deletions

@@ -13,16 +13,26 @@

 from pydantic_ai import Agent, RunContext, Tool, UserError
 from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior
-from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    TextPart,
+    ToolCallPart,
+    ToolReturn,
+    ToolReturnPart,
+    UserPromptPart,
+)
 from pydantic_ai.models.function import AgentInfo, FunctionModel
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.output import DeferredToolCalls, ToolOutput
 from pydantic_ai.tools import ToolDefinition
 from pydantic_ai.toolsets.deferred import DeferredToolset
 from pydantic_ai.toolsets.function import FunctionToolset
 from pydantic_ai.toolsets.prefixed import PrefixedToolset
+from pydantic_ai.usage import Usage

-from .conftest import IsStr
+from .conftest import IsDatetime, IsStr


 def test_tool_no_ctx():

@@ -1321,3 +1331,91 @@ def test_output_type_deferred_tool_calls_by_itself():
 def test_output_type_empty():
     with pytest.raises(UserError, match='At least one output type must be provided.'):
         Agent(TestModel(), output_type=[])
+
+
+def test_parallel_tool_return():
+    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if len(messages) == 1:
+            return ModelResponse(
+                parts=[ToolCallPart('get_price', {'fruit': 'apple'}), ToolCallPart('get_price', {'fruit': 'banana'})]
+            )
+        else:
+            return ModelResponse(
+                parts=[
+                    TextPart('Done!'),
+                ]
+            )
+
+    agent = Agent(FunctionModel(llm))
+
+    @agent.tool_plain
+    def get_price(fruit: str) -> ToolReturn:
+        return ToolReturn(
+            return_value=10.0,
+            content=f'The price of {fruit} is 10.0',
+            metadata={'foo': 'bar'},
+        )
+
+    result = agent.run_sync('What do an apple and a banana cost?')
+
+    assert result.all_messages() == snapshot(
+        [
+            ModelRequest(
+                parts=[
+                    UserPromptPart(
+                        content='What do an apple and a banana cost?',
+                        timestamp=IsDatetime(),
+                    )
+                ]
+            ),
+            ModelResponse(
+                parts=[
+                    ToolCallPart(
+                        tool_name='get_price',
+                        args={'fruit': 'apple'},
+                        tool_call_id=IsStr(),
+                    ),
+                    ToolCallPart(
+                        tool_name='get_price',
+                        args={'fruit': 'banana'},
+                        tool_call_id=IsStr(),
+                    ),
+                ],
+                usage=Usage(requests=1, request_tokens=58, response_tokens=10, total_tokens=68),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+            ModelRequest(
+                parts=[
+                    ToolReturnPart(
+                        tool_name='get_price',
+                        content=10.0,
+                        tool_call_id=IsStr(),
+                        metadata={'foo': 'bar'},
+                        timestamp=IsDatetime(),
+                    ),
+                    ToolReturnPart(
+                        tool_name='get_price',
+                        content=10.0,
+                        tool_call_id=IsStr(),
+                        metadata={'foo': 'bar'},
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='The price of apple is 10.0',
+                        timestamp=IsDatetime(),
+                    ),
+                    UserPromptPart(
+                        content='The price of banana is 10.0',
+                        timestamp=IsDatetime(),
+                    ),
+                ]
+            ),
+            ModelResponse(
+                parts=[TextPart(content='Done!')],
+                usage=Usage(requests=1, request_tokens=76, response_tokens=11, total_tokens=87),
+                model_name='function:llm:',
+                timestamp=IsDatetime(),
+            ),
+        ]
+    )
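The snapshot above encodes the ordering guarantee the fix restores: in the follow-up ModelRequest, both ToolReturnParts come first (in call order), followed by the UserPromptParts generated from each ToolReturn.content. Below is a small sketch of a helper that checks that invariant on any run's messages; the helper is not part of the library and assumes only the public message classes.

```python
from pydantic_ai.messages import ModelRequest, ToolReturnPart, UserPromptPart


def tool_returns_precede_user_prompts(request: ModelRequest) -> bool:
    """Return True if no ToolReturnPart appears after a UserPromptPart in the request."""
    seen_user_prompt = False
    for part in request.parts:
        if isinstance(part, UserPromptPart):
            seen_user_prompt = True
        elif isinstance(part, ToolReturnPart) and seen_user_prompt:
            return False
    return True
```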
