@@ -659,11 +659,11 @@ async def process_function_tools( # noqa: C901
659
659
for call in calls_to_run :
660
660
yield _messages .FunctionToolCallEvent (call )
661
661
662
- user_parts : list [_messages .UserPromptPart ] = []
662
+ user_parts_by_index : dict [ int , list [_messages .UserPromptPart ]] = defaultdict ( list )
663
663
664
664
if calls_to_run :
665
665
# Run all tool tasks in parallel
666
- parts_by_index : dict [int , list [ _messages .ModelRequestPart ] ] = {}
666
+ tool_parts_by_index : dict [int , _messages .ModelRequestPart ] = {}
667
667
with ctx .deps .tracer .start_as_current_span (
668
668
'running tools' ,
669
669
attributes = {
@@ -681,15 +681,16 @@ async def process_function_tools( # noqa: C901
681
681
done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
682
682
for task in done :
683
683
index = tasks .index (task )
684
- tool_result_part , extra_parts = task .result ()
685
- yield _messages .FunctionToolResultEvent (tool_result_part )
684
+ tool_part , tool_user_parts = task .result ()
685
+ yield _messages .FunctionToolResultEvent (tool_part )
686
686
687
- parts_by_index [index ] = [tool_result_part , * extra_parts ]
687
+ tool_parts_by_index [index ] = tool_part
688
+ user_parts_by_index [index ] = tool_user_parts
688
689
689
690
# We append the results at the end, rather than as they are received, to retain a consistent ordering
690
691
# This is mostly just to simplify testing
691
- for k in sorted (parts_by_index ):
692
- output_parts .extend ( parts_by_index [k ])
692
+ for k in sorted (tool_parts_by_index ):
693
+ output_parts .append ( tool_parts_by_index [k ])
693
694
694
695
# Finally, we handle deferred tool calls
695
696
for call in tool_calls_by_kind ['deferred' ]:
@@ -704,7 +705,8 @@ async def process_function_tools( # noqa: C901
704
705
else :
705
706
yield _messages .FunctionToolCallEvent (call )
706
707
707
- output_parts .extend (user_parts )
708
+ for k in sorted (user_parts_by_index ):
709
+ output_parts .extend (user_parts_by_index [k ])
708
710
709
711
if final_result :
710
712
output_final_result .append (final_result )
@@ -713,18 +715,18 @@ async def process_function_tools( # noqa: C901
713
715
async def _call_function_tool (
714
716
tool_manager : ToolManager [DepsT ],
715
717
tool_call : _messages .ToolCallPart ,
716
- ) -> tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , list [_messages .ModelRequestPart ]]:
718
+ ) -> tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , list [_messages .UserPromptPart ]]:
717
719
try :
718
720
tool_result = await tool_manager .handle_call (tool_call )
719
721
except ToolRetryError as e :
720
722
return (e .tool_retry , [])
721
723
722
- part = _messages .ToolReturnPart (
724
+ tool_part = _messages .ToolReturnPart (
723
725
tool_name = tool_call .tool_name ,
724
726
content = tool_result ,
725
727
tool_call_id = tool_call .tool_call_id ,
726
728
)
727
- extra_parts : list [_messages .ModelRequestPart ] = []
729
+ user_parts : list [_messages .UserPromptPart ] = []
728
730
729
731
if isinstance (tool_result , _messages .ToolReturn ):
730
732
if (
@@ -740,12 +742,12 @@ async def _call_function_tool(
740
742
f'Please use `content` instead.'
741
743
)
742
744
743
- part .content = tool_result .return_value # type: ignore
744
- part .metadata = tool_result .metadata
745
+ tool_part .content = tool_result .return_value # type: ignore
746
+ tool_part .metadata = tool_result .metadata
745
747
if tool_result .content :
746
- extra_parts .append (
748
+ user_parts .append (
747
749
_messages .UserPromptPart (
748
- content = list ( tool_result .content ) ,
750
+ content = tool_result .content ,
749
751
part_kind = 'user-prompt' ,
750
752
)
751
753
)
@@ -763,7 +765,7 @@ def process_content(content: Any) -> Any:
763
765
else :
764
766
identifier = multi_modal_content_identifier (content .url )
765
767
766
- extra_parts .append (
768
+ user_parts .append (
767
769
_messages .UserPromptPart (
768
770
content = [f'This is file { identifier } :' , content ],
769
771
part_kind = 'user-prompt' ,
@@ -775,11 +777,11 @@ def process_content(content: Any) -> Any:
775
777
776
778
if isinstance (tool_result , list ):
777
779
contents = cast (list [Any ], tool_result )
778
- part .content = [process_content (content ) for content in contents ]
780
+ tool_part .content = [process_content (content ) for content in contents ]
779
781
else :
780
- part .content = process_content (tool_result )
782
+ tool_part .content = process_content (tool_result )
781
783
782
- return (part , extra_parts )
784
+ return (tool_part , user_parts )
783
785
784
786
785
787
@dataclasses .dataclass
0 commit comments