Skip to content

Commit f28c6fc

Browse files
authored
Don't include ToolResultPart for external tool call when streaming (#3112)
1 parent cf0fa2a commit f28c6fc

File tree

4 files changed

+201
-24
lines changed

4 files changed

+201
-24
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -795,16 +795,14 @@ async def process_tool_calls( # noqa: C901
795795
# Then, we handle function tool calls
796796
calls_to_run: list[_messages.ToolCallPart] = []
797797
if final_result and ctx.deps.end_strategy == 'early':
798-
output_parts.extend(
799-
[
798+
for call in tool_calls_by_kind['function']:
799+
output_parts.append(
800800
_messages.ToolReturnPart(
801801
tool_name=call.tool_name,
802802
content='Tool not executed - a final result was already processed.',
803803
tool_call_id=call.tool_call_id,
804804
)
805-
for call in tool_calls_by_kind['function']
806-
]
807-
)
805+
)
808806
else:
809807
calls_to_run.extend(tool_calls_by_kind['function'])
810808

@@ -850,14 +848,17 @@ async def process_tool_calls( # noqa: C901
850848
if tool_call_results is None:
851849
calls = [*tool_calls_by_kind['external'], *tool_calls_by_kind['unapproved']]
852850
if final_result:
853-
for call in calls:
854-
output_parts.append(
855-
_messages.ToolReturnPart(
856-
tool_name=call.tool_name,
857-
content='Tool not executed - a final result was already processed.',
858-
tool_call_id=call.tool_call_id,
851+
# If the run was already determined to end on deferred tool calls,
852+
# we shouldn't insert return parts as the deferred tools will still get a real result.
853+
if not isinstance(final_result.output, _output.DeferredToolRequests):
854+
for call in calls:
855+
output_parts.append(
856+
_messages.ToolReturnPart(
857+
tool_name=call.tool_name,
858+
content='Tool not executed - a final result was already processed.',
859+
tool_call_id=call.tool_call_id,
860+
)
859861
)
860-
)
861862
elif calls:
862863
deferred_calls['external'].extend(tool_calls_by_kind['external'])
863864
deferred_calls['unapproved'].extend(tool_calls_by_kind['unapproved'])

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ async def stream_to_final(
489489

490490
if final_result_event is not None:
491491
final_result = FinalResult(
492-
stream, final_result_event.tool_name, final_result_event.tool_call_id
492+
None, final_result_event.tool_name, final_result_event.tool_call_id
493493
)
494494
if yielded:
495495
raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover
@@ -503,16 +503,15 @@ async def on_complete() -> None:
503503
The model response will have been added to messages by now
504504
by `StreamedRunResult._marked_completed`.
505505
"""
506-
last_message = messages[-1]
507-
assert isinstance(last_message, _messages.ModelResponse)
508-
tool_calls = [
509-
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
510-
]
506+
nonlocal final_result
507+
final_result = FinalResult(
508+
await stream.get_output(), final_result.tool_name, final_result.tool_call_id
509+
)
511510

512511
parts: list[_messages.ModelRequestPart] = []
513512
async for _event in _agent_graph.process_tool_calls(
514513
tool_manager=graph_ctx.deps.tool_manager,
515-
tool_calls=tool_calls,
514+
tool_calls=stream.response.tool_calls,
516515
tool_call_results=None,
517516
final_result=final_result,
518517
ctx=graph_ctx,

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,20 +67,20 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
6767
except ValidationError:
6868
pass
6969
if self._raw_stream_response.final_result_event is not None: # pragma: no branch
70-
yield await self.validate_response_output(self._raw_stream_response.get())
70+
yield await self.validate_response_output(self.response)
7171

7272
async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
7373
"""Asynchronously stream the (unvalidated) model responses for the agent."""
7474
# if the message currently has any parts with content, yield before streaming
75-
msg = self._raw_stream_response.get()
75+
msg = self.response
7676
for part in msg.parts:
7777
if part.has_content():
7878
yield msg
7979
break
8080

8181
async with _utils.group_by_temporal(self, debounce_by) as group_iter:
8282
async for _items in group_iter:
83-
yield self._raw_stream_response.get() # current state of the response
83+
yield self.response # current state of the response
8484

8585
async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]:
8686
"""Stream the text result as an async iterable.
@@ -136,7 +136,7 @@ async def get_output(self) -> OutputDataT:
136136
async for _ in self:
137137
pass
138138

139-
return await self.validate_response_output(self._raw_stream_response.get())
139+
return await self.validate_response_output(self.response)
140140

141141
async def validate_response_output(
142142
self, message: _messages.ModelResponse, *, allow_partial: bool = False
@@ -201,7 +201,7 @@ async def _stream_text_deltas_ungrouped() -> AsyncIterator[tuple[str, int]]:
201201
# yields tuples of (text_content, part_index)
202202
# we don't currently make use of the part_index, but in principle this may be useful
203203
# so we retain it here for now to make possible future refactors simpler
204-
msg = self._raw_stream_response.get()
204+
msg = self.response
205205
for i, part in enumerate(msg.parts):
206206
if isinstance(part, _messages.TextPart) and part.content:
207207
yield part.content, i

tests/test_streaming.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
AgentRunResult,
1919
AgentRunResultEvent,
2020
AgentStreamEvent,
21+
ExternalToolset,
2122
FinalResultEvent,
2223
FunctionToolCallEvent,
2324
FunctionToolResultEvent,
@@ -819,6 +820,182 @@ def another_tool(y: int) -> int: # pragma: no cover
819820
)
820821

821822

823+
async def test_early_strategy_with_external_tool_call():
824+
tool_called: list[str] = []
825+
826+
async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]:
827+
assert info.output_tools is not None
828+
yield {1: DeltaToolCall('external_tool')}
829+
yield {2: DeltaToolCall('final_result', '{"value": "final"}')}
830+
yield {3: DeltaToolCall('regular_tool', '{"x": 1}')}
831+
832+
agent = Agent(
833+
FunctionModel(stream_function=sf),
834+
output_type=[OutputType, DeferredToolRequests],
835+
toolsets=[
836+
ExternalToolset(
837+
tool_defs=[
838+
ToolDefinition(
839+
name='external_tool',
840+
kind='external',
841+
)
842+
]
843+
)
844+
],
845+
end_strategy='early',
846+
)
847+
848+
@agent.tool_plain
849+
def regular_tool(x: int) -> int: # pragma: no cover
850+
"""A regular tool that should not be called."""
851+
tool_called.append('regular_tool')
852+
return x
853+
854+
async with agent.run_stream('test early strategy with external tool call') as result:
855+
response = await result.get_output()
856+
assert response == snapshot(
857+
DeferredToolRequests(
858+
calls=[
859+
ToolCallPart(
860+
tool_name='external_tool',
861+
tool_call_id=IsStr(),
862+
)
863+
]
864+
)
865+
)
866+
messages = result.all_messages()
867+
868+
# Verify no tools were called
869+
assert tool_called == []
870+
871+
# Verify we got appropriate tool returns
872+
assert messages == snapshot(
873+
[
874+
ModelRequest(
875+
parts=[
876+
UserPromptPart(
877+
content='test early strategy with external tool call',
878+
timestamp=IsNow(tz=datetime.timezone.utc),
879+
part_kind='user-prompt',
880+
)
881+
],
882+
kind='request',
883+
),
884+
ModelResponse(
885+
parts=[
886+
ToolCallPart(tool_name='external_tool', tool_call_id=IsStr()),
887+
ToolCallPart(
888+
tool_name='final_result',
889+
args='{"value": "final"}',
890+
tool_call_id=IsStr(),
891+
),
892+
ToolCallPart(
893+
tool_name='regular_tool',
894+
args='{"x": 1}',
895+
tool_call_id=IsStr(),
896+
),
897+
],
898+
usage=RequestUsage(input_tokens=50, output_tokens=7),
899+
model_name='function::sf',
900+
timestamp=IsNow(tz=datetime.timezone.utc),
901+
kind='response',
902+
),
903+
ModelRequest(
904+
parts=[
905+
ToolReturnPart(
906+
tool_name='final_result',
907+
content='Output tool not used - a final result was already processed.',
908+
tool_call_id=IsStr(),
909+
timestamp=IsNow(tz=datetime.timezone.utc),
910+
),
911+
ToolReturnPart(
912+
tool_name='regular_tool',
913+
content='Tool not executed - a final result was already processed.',
914+
tool_call_id=IsStr(),
915+
timestamp=IsNow(tz=datetime.timezone.utc),
916+
),
917+
],
918+
kind='request',
919+
),
920+
]
921+
)
922+
923+
924+
async def test_early_strategy_with_deferred_tool_call():
925+
tool_called: list[str] = []
926+
927+
async def sf(_: list[ModelMessage], info: AgentInfo) -> AsyncIterator[str | DeltaToolCalls]:
928+
assert info.output_tools is not None
929+
yield {1: DeltaToolCall('deferred_tool')}
930+
yield {2: DeltaToolCall('regular_tool', '{"x": 1}')}
931+
932+
agent = Agent(
933+
FunctionModel(stream_function=sf),
934+
output_type=[str, DeferredToolRequests],
935+
end_strategy='early',
936+
)
937+
938+
@agent.tool_plain
939+
def deferred_tool() -> int:
940+
raise CallDeferred
941+
942+
@agent.tool_plain
943+
def regular_tool(x: int) -> int:
944+
tool_called.append('regular_tool')
945+
return x
946+
947+
async with agent.run_stream('test early strategy with external tool call') as result:
948+
response = await result.get_output()
949+
assert response == snapshot(
950+
DeferredToolRequests(calls=[ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr())])
951+
)
952+
messages = result.all_messages()
953+
954+
# Verify no tools were called
955+
assert tool_called == ['regular_tool']
956+
957+
# Verify we got appropriate tool returns
958+
assert messages == snapshot(
959+
[
960+
ModelRequest(
961+
parts=[
962+
UserPromptPart(
963+
content='test early strategy with external tool call',
964+
timestamp=IsNow(tz=datetime.timezone.utc),
965+
part_kind='user-prompt',
966+
)
967+
],
968+
kind='request',
969+
),
970+
ModelResponse(
971+
parts=[
972+
ToolCallPart(tool_name='deferred_tool', tool_call_id=IsStr()),
973+
ToolCallPart(
974+
tool_name='regular_tool',
975+
args='{"x": 1}',
976+
tool_call_id=IsStr(),
977+
),
978+
],
979+
usage=RequestUsage(input_tokens=50, output_tokens=3),
980+
model_name='function::sf',
981+
timestamp=IsNow(tz=datetime.timezone.utc),
982+
kind='response',
983+
),
984+
ModelRequest(
985+
parts=[
986+
ToolReturnPart(
987+
tool_name='regular_tool',
988+
content=1,
989+
tool_call_id=IsStr(),
990+
timestamp=IsNow(tz=datetime.timezone.utc),
991+
)
992+
],
993+
kind='request',
994+
),
995+
]
996+
)
997+
998+
822999
async def test_early_strategy_does_not_apply_to_tool_calls_without_final_tool():
8231000
"""Test that 'early' strategy does not apply to tool calls without final tool."""
8241001
tool_called: list[str] = []

0 commit comments

Comments
 (0)