Skip to content

Commit 15c5ef2

Browse files
authored
Fix bug related to handling multiple result tools (#926)
1 parent 4d0f8ff commit 15c5ef2

File tree

8 files changed

+81
-25
lines changed

8 files changed

+81
-25
lines changed

docs/agents.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ async def main():
141141
kind='response',
142142
)
143143
),
144-
End(data=FinalResult(data='Paris', tool_name=None)),
144+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
145145
]
146146
"""
147147
print(agent_run.result.data)
@@ -202,7 +202,7 @@ async def main():
202202
kind='response',
203203
)
204204
),
205-
End(data=FinalResult(data='Paris', tool_name=None)),
205+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
206206
]
207207
"""
208208
```

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,7 @@ async def _handle_tool_calls(
454454
final_result: result.FinalResult[NodeRunEndT] | None = None
455455
parts: list[_messages.ModelRequestPart] = []
456456
if result_schema is not None:
457-
if match := result_schema.find_tool(tool_calls):
458-
call, result_tool = match
457+
for call, result_tool in result_schema.find_tool(tool_calls):
459458
try:
460459
result_data = result_tool.validate(call)
461460
result_data = await _validate_result(result_data, ctx, call)
@@ -465,12 +464,17 @@ async def _handle_tool_calls(
465464
ctx.state.increment_retries(ctx.deps.max_result_retries)
466465
parts.append(e.tool_retry)
467466
else:
468-
final_result = result.FinalResult(result_data, call.tool_name)
467+
final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id)
468+
break
469469

470470
# Then build the other request parts based on end strategy
471471
tool_responses: list[_messages.ModelRequestPart] = self._tool_responses
472472
async for event in process_function_tools(
473-
tool_calls, final_result and final_result.tool_name, ctx, tool_responses
473+
tool_calls,
474+
final_result and final_result.tool_name,
475+
final_result and final_result.tool_call_id,
476+
ctx,
477+
tool_responses,
474478
):
475479
yield event
476480

@@ -518,7 +522,7 @@ async def _handle_text_response(
518522
return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))
519523
else:
520524
# The following cast is safe because we know `str` is an allowed result type
521-
return self._handle_final_result(ctx, result.FinalResult(result_data, tool_name=None), [])
525+
return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), [])
522526
else:
523527
ctx.state.increment_retries(ctx.deps.max_result_retries)
524528
return ModelRequestNode[DepsT, NodeRunEndT](
@@ -547,6 +551,7 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
547551
async def process_function_tools(
548552
tool_calls: list[_messages.ToolCallPart],
549553
result_tool_name: str | None,
554+
result_tool_call_id: str | None,
550555
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
551556
output_parts: list[_messages.ModelRequestPart],
552557
) -> AsyncIterator[_messages.HandleResponseEvent]:
@@ -566,7 +571,11 @@ async def process_function_tools(
566571
calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = []
567572
call_index_to_event_id: dict[int, str] = {}
568573
for call in tool_calls:
569-
if call.tool_name == result_tool_name and not found_used_result_tool:
574+
if (
575+
call.tool_name == result_tool_name
576+
and call.tool_call_id == result_tool_call_id
577+
and not found_used_result_tool
578+
):
570579
found_used_result_tool = True
571580
output_parts.append(
572581
_messages.ToolReturnPart(
@@ -593,9 +602,14 @@ async def process_function_tools(
593602
# if tool_name is in _result_schema, it means we found a result tool but an error occurred in
594603
# validation, we don't add another part here
595604
if result_tool_name is not None:
605+
if found_used_result_tool:
606+
content = 'Result tool not used - a final result was already processed.'
607+
else:
608+
# TODO: Include information about the validation failure, and/or merge this with the ModelRetry part
609+
content = 'Result tool not used - result failed validation.'
596610
part = _messages.ToolReturnPart(
597611
tool_name=call.tool_name,
598-
content='Result tool not used - a final result was already processed.',
612+
content=content,
599613
tool_call_id=call.tool_call_id,
600614
)
601615
output_parts.append(part)

pydantic_ai_slim/pydantic_ai/_result.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
import sys
55
import types
6-
from collections.abc import Awaitable, Iterable
6+
from collections.abc import Awaitable, Iterable, Iterator
77
from dataclasses import dataclass, field
88
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin
99

@@ -127,12 +127,12 @@ def find_named_tool(
127127
def find_tool(
128128
self,
129129
parts: Iterable[_messages.ModelResponsePart],
130-
) -> tuple[_messages.ToolCallPart, ResultTool[ResultDataT]] | None:
130+
) -> Iterator[tuple[_messages.ToolCallPart, ResultTool[ResultDataT]]]:
131131
"""Find a tool that matches one of the calls."""
132132
for part in parts:
133133
if isinstance(part, _messages.ToolCallPart):
134134
if result := self.tools.get(part.tool_name):
135-
return part, result
135+
yield part, result
136136

137137
def tool_names(self) -> list[str]:
138138
"""Return the names of the tools."""

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ async def main():
370370
kind='response',
371371
)
372372
),
373-
End(data=FinalResult(data='Paris', tool_name=None)),
373+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
374374
]
375375
'''
376376
print(agent_run.result.data)
@@ -661,11 +661,10 @@ async def stream_to_final(
661661
new_part = maybe_part_event.part
662662
if isinstance(new_part, _messages.TextPart):
663663
if _agent_graph.allow_text_result(result_schema):
664-
return FinalResult(s, None)
665-
elif isinstance(new_part, _messages.ToolCallPart):
666-
if result_schema is not None and (match := result_schema.find_tool([new_part])):
667-
call, _ = match
668-
return FinalResult(s, call.tool_name)
664+
return FinalResult(s, None, None)
665+
elif isinstance(new_part, _messages.ToolCallPart) and result_schema:
666+
for call, _ in result_schema.find_tool([new_part]):
667+
return FinalResult(s, call.tool_name, call.tool_call_id)
669668
return None
670669

671670
final_result_details = await stream_to_final(streamed_response)
@@ -692,6 +691,7 @@ async def on_complete() -> None:
692691
async for _event in _agent_graph.process_function_tools(
693692
tool_calls,
694693
final_result_details.tool_name,
694+
final_result_details.tool_call_id,
695695
graph_ctx,
696696
parts,
697697
):
@@ -1258,7 +1258,7 @@ async def main():
12581258
kind='response',
12591259
)
12601260
),
1261-
End(data=FinalResult(data='Paris', tool_name=None)),
1261+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
12621262
]
12631263
'''
12641264
print(agent_run.result.data)
@@ -1382,7 +1382,7 @@ async def main():
13821382
kind='response',
13831383
)
13841384
),
1385-
End(data=FinalResult(data='Paris', tool_name=None)),
1385+
End(data=FinalResult(data='Paris', tool_name=None, tool_call_id=None)),
13861386
]
13871387
'''
13881388
print('Final result:', agent_run.result.data)

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,8 @@ class FinalResultEvent:
539539

540540
tool_name: str | None
541541
"""The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
542+
tool_call_id: str | None
543+
"""The tool call ID, if any, that this result is associated with."""
542544
event_kind: Literal['final_result'] = 'final_result'
543545
"""Event type identifier, used as a discriminator."""
544546

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,14 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.
145145
if isinstance(e, _messages.PartStartEvent):
146146
new_part = e.part
147147
if isinstance(new_part, _messages.ToolCallPart):
148-
if result_schema is not None and (match := result_schema.find_tool([new_part])):
149-
call, _ = match
150-
return _messages.FinalResultEvent(tool_name=call.tool_name)
148+
if result_schema:
149+
for call, _ in result_schema.find_tool([new_part]):
150+
return _messages.FinalResultEvent(
151+
tool_name=call.tool_name, tool_call_id=call.tool_call_id
152+
)
151153
elif allow_text_result:
152154
assert_type(e, _messages.PartStartEvent)
153-
return _messages.FinalResultEvent(tool_name=None)
155+
return _messages.FinalResultEvent(tool_name=None, tool_call_id=None)
154156

155157
usage_checking_stream = _get_usage_checking_stream_response(
156158
self._raw_stream_response, self._usage_limits, self.usage
@@ -472,6 +474,8 @@ class FinalResult(Generic[ResultDataT]):
472474
"""The final result data."""
473475
tool_name: str | None
474476
"""Name of the final result tool; `None` if the result came from unstructured text content."""
477+
tool_call_id: str | None
478+
"""ID of the tool call that produced the final result; `None` if the result came from unstructured text content."""
475479

476480

477481
def _get_usage_checking_stream_response(

tests/test_agent.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,42 @@ def regular_tool(x: int) -> int:
11831183
tool_returns = [m for m in result.all_messages() if isinstance(m, ToolReturnPart)]
11841184
assert tool_returns == snapshot([])
11851185

1186+
def test_multiple_final_result_are_validated_correctly(self):
1187+
"""Tests that if multiple final results are returned, but one fails validation, the other is used."""
1188+
1189+
def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1190+
assert info.result_tools is not None
1191+
return ModelResponse(
1192+
parts=[
1193+
ToolCallPart('final_result', {'bad_value': 'first'}, tool_call_id='first'),
1194+
ToolCallPart('final_result', {'value': 'second'}, tool_call_id='second'),
1195+
]
1196+
)
1197+
1198+
agent = Agent(FunctionModel(return_model), result_type=self.ResultType, end_strategy='early')
1199+
result = agent.run_sync('test multiple final results')
1200+
1201+
# Verify the result came from the second final tool
1202+
assert result.data.value == 'second'
1203+
1204+
# Verify we got appropriate tool returns
1205+
assert result.new_messages()[-1].parts == snapshot(
1206+
[
1207+
ToolReturnPart(
1208+
tool_name='final_result',
1209+
tool_call_id='first',
1210+
content='Result tool not used - result failed validation.',
1211+
timestamp=IsNow(tz=timezone.utc),
1212+
),
1213+
ToolReturnPart(
1214+
tool_name='final_result',
1215+
content='Final result processed.',
1216+
timestamp=IsNow(tz=timezone.utc),
1217+
tool_call_id='second',
1218+
),
1219+
]
1220+
)
1221+
11861222

11871223
async def test_model_settings_override() -> None:
11881224
def return_settings(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:

tests/test_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ def result_validator_simple(data: str) -> str:
768768
async for chunk in stream.stream_output(debounce_by=None):
769769
messages.append(chunk)
770770
stream_usage = deepcopy(stream.usage())
771-
assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None))
771+
assert run.next_node == End(data=FinalResult(data='The bat sat on the mat.', tool_name=None, tool_call_id=None))
772772
assert (
773773
run.usage()
774774
== stream_usage

0 commit comments

Comments
 (0)