@@ -1117,7 +1117,6 @@ async def process_tool_calls( # noqa: C901
11171117 tool_calls = calls_to_run ,
11181118 tool_call_results = calls_to_run_results ,
11191119 validated_calls = validated_calls ,
1120- tracer = ctx .deps .tracer ,
11211120 output_parts = output_parts ,
11221121 output_deferred_calls = deferred_calls ,
11231122 output_deferred_metadata = deferred_metadata ,
@@ -1185,7 +1184,6 @@ async def _call_tools( # noqa: C901
11851184 tool_calls : list [_messages .ToolCallPart ],
11861185 tool_call_results : dict [str , DeferredToolResult ],
11871186 validated_calls : dict [str , ValidatedToolCall [DepsT ]],
1188- tracer : Tracer ,
11891187 output_parts : list [_messages .ModelRequestPart ],
11901188 output_deferred_calls : dict [Literal ['external' , 'unapproved' ], list [_messages .ToolCallPart ]],
11911189 output_deferred_metadata : dict [str , dict [str , Any ]],
@@ -1195,101 +1193,89 @@ async def _call_tools( # noqa: C901
11951193 deferred_calls_by_index : dict [int , Literal ['external' , 'unapproved' ]] = {}
11961194 deferred_metadata_by_index : dict [int , dict [str , Any ] | None ] = {}
11971195
1198- with tracer .start_as_current_span (
1199- 'running tools' ,
1200- attributes = {
1201- 'tools' : [call .tool_name for call in tool_calls ],
1202- 'logfire.msg' : f'running { len (tool_calls )} tool{ "" if len (tool_calls ) == 1 else "s" } ' ,
1203- },
1204- ):
1196+ async def handle_call_or_result (
1197+ coro_or_task : Awaitable [
1198+ tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , str | Sequence [_messages .UserContent ] | None ]
1199+ ]
1200+ | Task [
1201+ tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , str | Sequence [_messages .UserContent ] | None ]
1202+ ],
1203+ index : int ,
1204+ ) -> _messages .HandleResponseEvent | None :
1205+ try :
1206+ tool_part , tool_user_content = (
1207+ (await coro_or_task ) if inspect .isawaitable (coro_or_task ) else coro_or_task .result ()
1208+ )
1209+ except exceptions .CallDeferred as e :
1210+ deferred_calls_by_index [index ] = 'external'
1211+ deferred_metadata_by_index [index ] = e .metadata
1212+ except exceptions .ApprovalRequired as e :
1213+ deferred_calls_by_index [index ] = 'unapproved'
1214+ deferred_metadata_by_index [index ] = e .metadata
1215+ else :
1216+ tool_parts_by_index [index ] = tool_part
1217+ if tool_user_content :
1218+ user_parts_by_index [index ] = _messages .UserPromptPart (content = tool_user_content )
1219+
1220+ return _messages .FunctionToolResultEvent (tool_part , content = tool_user_content )
1221+
1222+ parallel_execution_mode = tool_manager .get_parallel_execution_mode (tool_calls )
1223+ if parallel_execution_mode == 'sequential' :
1224+ for index , call in enumerate (tool_calls ):
1225+ if event := await handle_call_or_result (
1226+ _call_tool (
1227+ tool_manager ,
1228+ validated_calls .get (call .tool_call_id , call ),
1229+ tool_call_results .get (call .tool_call_id ),
1230+ ),
1231+ index ,
1232+ ):
1233+ yield event
12051234
1206- async def handle_call_or_result (
1207- coro_or_task : Awaitable [
1208- tuple [
1209- _messages .ToolReturnPart | _messages .RetryPromptPart , str | Sequence [_messages .UserContent ] | None
1210- ]
1211- ]
1212- | Task [
1213- tuple [
1214- _messages .ToolReturnPart | _messages .RetryPromptPart , str | Sequence [_messages .UserContent ] | None
1215- ]
1216- ],
1217- index : int ,
1218- ) -> _messages .HandleResponseEvent | None :
1219- try :
1220- tool_part , tool_user_content = (
1221- (await coro_or_task ) if inspect .isawaitable (coro_or_task ) else coro_or_task .result ()
1222- )
1223- except exceptions .CallDeferred as e :
1224- deferred_calls_by_index [index ] = 'external'
1225- deferred_metadata_by_index [index ] = e .metadata
1226- except exceptions .ApprovalRequired as e :
1227- deferred_calls_by_index [index ] = 'unapproved'
1228- deferred_metadata_by_index [index ] = e .metadata
1235+ else :
1236+ tasks = [
1237+ asyncio .create_task (
1238+ _call_tool (
1239+ tool_manager ,
1240+ validated_calls .get (call .tool_call_id , call ),
1241+ tool_call_results .get (call .tool_call_id ),
1242+ ),
1243+ name = call .tool_name ,
1244+ )
1245+ for call in tool_calls
1246+ ]
1247+ try :
1248+ if parallel_execution_mode == 'parallel_ordered_events' :
1249+ # Wait for all tasks to complete before yielding any events
1250+ await asyncio .wait (tasks , return_when = asyncio .ALL_COMPLETED )
1251+ for index , task in enumerate (tasks ):
1252+ if event := await handle_call_or_result (coro_or_task = task , index = index ):
1253+ yield event
12291254 else :
1230- tool_parts_by_index [index ] = tool_part
1231- if tool_user_content :
1232- user_parts_by_index [index ] = _messages .UserPromptPart (content = tool_user_content )
1233-
1234- return _messages .FunctionToolResultEvent (tool_part , content = tool_user_content )
1235-
1236- parallel_execution_mode = tool_manager .get_parallel_execution_mode (tool_calls )
1237- if parallel_execution_mode == 'sequential' :
1238- for index , call in enumerate (tool_calls ):
1239- if event := await handle_call_or_result (
1240- _call_tool (
1241- tool_manager ,
1242- validated_calls .get (call .tool_call_id , call ),
1243- tool_call_results .get (call .tool_call_id ),
1244- ),
1245- index ,
1246- ):
1247- yield event
1248-
1249- else :
1250- tasks = [
1251- asyncio .create_task (
1252- _call_tool (
1253- tool_manager ,
1254- validated_calls .get (call .tool_call_id , call ),
1255- tool_call_results .get (call .tool_call_id ),
1256- ),
1257- name = call .tool_name ,
1258- )
1259- for call in tool_calls
1260- ]
1261- try :
1262- if parallel_execution_mode == 'parallel_ordered_events' :
1263- # Wait for all tasks to complete before yielding any events
1264- await asyncio .wait (tasks , return_when = asyncio .ALL_COMPLETED )
1265- for index , task in enumerate (tasks ):
1266- if event := await handle_call_or_result (coro_or_task = task , index = index ):
1255+ pending : set [
1256+ asyncio .Task [
1257+ tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , _messages .UserPromptPart | None ]
1258+ ]
1259+ ] = set (tasks ) # pyright: ignore[reportAssignmentType]
1260+ while pending :
1261+ done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
1262+ for task in done :
1263+ index = tasks .index (task ) # pyright: ignore[reportArgumentType]
1264+ if event := await handle_call_or_result (coro_or_task = task , index = index ): # pyright: ignore[reportArgumentType]
12671265 yield event
1268- else :
1269- pending : set [
1270- asyncio .Task [
1271- tuple [_messages .ToolReturnPart | _messages .RetryPromptPart , _messages .UserPromptPart | None ]
1272- ]
1273- ] = set (tasks ) # pyright: ignore[reportAssignmentType]
1274- while pending :
1275- done , pending = await asyncio .wait (pending , return_when = asyncio .FIRST_COMPLETED )
1276- for task in done :
1277- index = tasks .index (task ) # pyright: ignore[reportArgumentType]
1278- if event := await handle_call_or_result (coro_or_task = task , index = index ): # pyright: ignore[reportArgumentType]
1279- yield event
1280-
1281- except asyncio .CancelledError as e :
1282- for task in tasks :
1283- task .cancel (msg = e .args [0 ] if len (e .args ) != 0 else None )
1284- raise
1285- except BaseException :
1286- # Cancel any still-running sibling tasks so they don't become
1287- # orphaned asyncio tasks when a non-CancelledError exception
1288- # (e.g. RuntimeError, ConnectionError) propagates out of
1289- # handle_call_or_result().
1290- for task in tasks :
1291- task .cancel ()
1292- raise
1266+
1267+ except asyncio .CancelledError as e :
1268+ for task in tasks :
1269+ task .cancel (msg = e .args [0 ] if len (e .args ) != 0 else None )
1270+ raise
1271+ except BaseException :
1272+ # Cancel any still-running sibling tasks so they don't become
1273+ # orphaned asyncio tasks when a non-CancelledError exception
1274+ # (e.g. RuntimeError, ConnectionError) propagates out of
1275+ # handle_call_or_result().
1276+ for task in tasks :
1277+ task .cancel ()
1278+ raise
12931279
12941280 # We append the results at the end, rather than as they are received, to retain a consistent ordering
12951281 # This is mostly just to simplify testing
0 commit comments