Skip to content

Commit 849b481

Browse files
Kludexclaude
andauthored
Remove redundant running tools parent span (#4560)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6df3288 commit 849b481

File tree

5 files changed

+209
-313
lines changed

5 files changed

+209
-313
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 80 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,7 +1117,6 @@ async def process_tool_calls( # noqa: C901
11171117
tool_calls=calls_to_run,
11181118
tool_call_results=calls_to_run_results,
11191119
validated_calls=validated_calls,
1120-
tracer=ctx.deps.tracer,
11211120
output_parts=output_parts,
11221121
output_deferred_calls=deferred_calls,
11231122
output_deferred_metadata=deferred_metadata,
@@ -1185,7 +1184,6 @@ async def _call_tools( # noqa: C901
11851184
tool_calls: list[_messages.ToolCallPart],
11861185
tool_call_results: dict[str, DeferredToolResult],
11871186
validated_calls: dict[str, ValidatedToolCall[DepsT]],
1188-
tracer: Tracer,
11891187
output_parts: list[_messages.ModelRequestPart],
11901188
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
11911189
output_deferred_metadata: dict[str, dict[str, Any]],
@@ -1195,101 +1193,89 @@ async def _call_tools( # noqa: C901
11951193
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
11961194
deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
11971195

1198-
with tracer.start_as_current_span(
1199-
'running tools',
1200-
attributes={
1201-
'tools': [call.tool_name for call in tool_calls],
1202-
'logfire.msg': f'running {len(tool_calls)} tool{"" if len(tool_calls) == 1 else "s"}',
1203-
},
1204-
):
1196+
async def handle_call_or_result(
1197+
coro_or_task: Awaitable[
1198+
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None]
1199+
]
1200+
| Task[
1201+
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None]
1202+
],
1203+
index: int,
1204+
) -> _messages.HandleResponseEvent | None:
1205+
try:
1206+
tool_part, tool_user_content = (
1207+
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
1208+
)
1209+
except exceptions.CallDeferred as e:
1210+
deferred_calls_by_index[index] = 'external'
1211+
deferred_metadata_by_index[index] = e.metadata
1212+
except exceptions.ApprovalRequired as e:
1213+
deferred_calls_by_index[index] = 'unapproved'
1214+
deferred_metadata_by_index[index] = e.metadata
1215+
else:
1216+
tool_parts_by_index[index] = tool_part
1217+
if tool_user_content:
1218+
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)
1219+
1220+
return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)
1221+
1222+
parallel_execution_mode = tool_manager.get_parallel_execution_mode(tool_calls)
1223+
if parallel_execution_mode == 'sequential':
1224+
for index, call in enumerate(tool_calls):
1225+
if event := await handle_call_or_result(
1226+
_call_tool(
1227+
tool_manager,
1228+
validated_calls.get(call.tool_call_id, call),
1229+
tool_call_results.get(call.tool_call_id),
1230+
),
1231+
index,
1232+
):
1233+
yield event
12051234

1206-
async def handle_call_or_result(
1207-
coro_or_task: Awaitable[
1208-
tuple[
1209-
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
1210-
]
1211-
]
1212-
| Task[
1213-
tuple[
1214-
_messages.ToolReturnPart | _messages.RetryPromptPart, str | Sequence[_messages.UserContent] | None
1215-
]
1216-
],
1217-
index: int,
1218-
) -> _messages.HandleResponseEvent | None:
1219-
try:
1220-
tool_part, tool_user_content = (
1221-
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
1222-
)
1223-
except exceptions.CallDeferred as e:
1224-
deferred_calls_by_index[index] = 'external'
1225-
deferred_metadata_by_index[index] = e.metadata
1226-
except exceptions.ApprovalRequired as e:
1227-
deferred_calls_by_index[index] = 'unapproved'
1228-
deferred_metadata_by_index[index] = e.metadata
1235+
else:
1236+
tasks = [
1237+
asyncio.create_task(
1238+
_call_tool(
1239+
tool_manager,
1240+
validated_calls.get(call.tool_call_id, call),
1241+
tool_call_results.get(call.tool_call_id),
1242+
),
1243+
name=call.tool_name,
1244+
)
1245+
for call in tool_calls
1246+
]
1247+
try:
1248+
if parallel_execution_mode == 'parallel_ordered_events':
1249+
# Wait for all tasks to complete before yielding any events
1250+
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
1251+
for index, task in enumerate(tasks):
1252+
if event := await handle_call_or_result(coro_or_task=task, index=index):
1253+
yield event
12291254
else:
1230-
tool_parts_by_index[index] = tool_part
1231-
if tool_user_content:
1232-
user_parts_by_index[index] = _messages.UserPromptPart(content=tool_user_content)
1233-
1234-
return _messages.FunctionToolResultEvent(tool_part, content=tool_user_content)
1235-
1236-
parallel_execution_mode = tool_manager.get_parallel_execution_mode(tool_calls)
1237-
if parallel_execution_mode == 'sequential':
1238-
for index, call in enumerate(tool_calls):
1239-
if event := await handle_call_or_result(
1240-
_call_tool(
1241-
tool_manager,
1242-
validated_calls.get(call.tool_call_id, call),
1243-
tool_call_results.get(call.tool_call_id),
1244-
),
1245-
index,
1246-
):
1247-
yield event
1248-
1249-
else:
1250-
tasks = [
1251-
asyncio.create_task(
1252-
_call_tool(
1253-
tool_manager,
1254-
validated_calls.get(call.tool_call_id, call),
1255-
tool_call_results.get(call.tool_call_id),
1256-
),
1257-
name=call.tool_name,
1258-
)
1259-
for call in tool_calls
1260-
]
1261-
try:
1262-
if parallel_execution_mode == 'parallel_ordered_events':
1263-
# Wait for all tasks to complete before yielding any events
1264-
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
1265-
for index, task in enumerate(tasks):
1266-
if event := await handle_call_or_result(coro_or_task=task, index=index):
1255+
pending: set[
1256+
asyncio.Task[
1257+
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]
1258+
]
1259+
] = set(tasks) # pyright: ignore[reportAssignmentType]
1260+
while pending:
1261+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
1262+
for task in done:
1263+
index = tasks.index(task) # pyright: ignore[reportArgumentType]
1264+
if event := await handle_call_or_result(coro_or_task=task, index=index): # pyright: ignore[reportArgumentType]
12671265
yield event
1268-
else:
1269-
pending: set[
1270-
asyncio.Task[
1271-
tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, _messages.UserPromptPart | None]
1272-
]
1273-
] = set(tasks) # pyright: ignore[reportAssignmentType]
1274-
while pending:
1275-
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
1276-
for task in done:
1277-
index = tasks.index(task) # pyright: ignore[reportArgumentType]
1278-
if event := await handle_call_or_result(coro_or_task=task, index=index): # pyright: ignore[reportArgumentType]
1279-
yield event
1280-
1281-
except asyncio.CancelledError as e:
1282-
for task in tasks:
1283-
task.cancel(msg=e.args[0] if len(e.args) != 0 else None)
1284-
raise
1285-
except BaseException:
1286-
# Cancel any still-running sibling tasks so they don't become
1287-
# orphaned asyncio tasks when a non-CancelledError exception
1288-
# (e.g. RuntimeError, ConnectionError) propagates out of
1289-
# handle_call_or_result().
1290-
for task in tasks:
1291-
task.cancel()
1292-
raise
1266+
1267+
except asyncio.CancelledError as e:
1268+
for task in tasks:
1269+
task.cancel(msg=e.args[0] if len(e.args) != 0 else None)
1270+
raise
1271+
except BaseException:
1272+
# Cancel any still-running sibling tasks so they don't become
1273+
# orphaned asyncio tasks when a non-CancelledError exception
1274+
# (e.g. RuntimeError, ConnectionError) propagates out of
1275+
# handle_call_or_result().
1276+
for task in tasks:
1277+
task.cancel()
1278+
raise
12931279

12941280
# We append the results at the end, rather than as they are received, to retain a consistent ordering
12951281
# This is mostly just to simplify testing

tests/test_dbos.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -343,35 +343,30 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
343343
),
344344
],
345345
),
346+
BasicSpan(content='running tool: get_country'),
346347
BasicSpan(
347-
content='running 2 tools',
348+
content='running tool: get_product_name',
349+
children=[BasicSpan(content='complex_agent__mcp_server__mcp.call_tool')],
350+
),
351+
BasicSpan(
352+
content='event_stream_handler',
348353
children=[
349-
BasicSpan(content='running tool: get_country'),
350-
BasicSpan(
351-
content='running tool: get_product_name',
352-
children=[BasicSpan(content='complex_agent__mcp_server__mcp.call_tool')],
353-
),
354+
BasicSpan(content='ctx.run_step=1'),
354355
BasicSpan(
355-
content='event_stream_handler',
356-
children=[
357-
BasicSpan(content='ctx.run_step=1'),
358-
BasicSpan(
359-
content=IsStr(
360-
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
361-
)
362-
),
363-
],
356+
content=IsStr(
357+
regex=r'{"result":{"tool_name":"get_country","content":"Mexico","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
358+
)
364359
),
360+
],
361+
),
362+
BasicSpan(
363+
content='event_stream_handler',
364+
children=[
365+
BasicSpan(content='ctx.run_step=1'),
365366
BasicSpan(
366-
content='event_stream_handler',
367-
children=[
368-
BasicSpan(content='ctx.run_step=1'),
369-
BasicSpan(
370-
content=IsStr(
371-
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
372-
)
373-
),
374-
],
367+
content=IsStr(
368+
regex=r'{"result":{"tool_name":"get_product_name","content":"Pydantic AI","tool_call_id":"call_Xw9XMKBJU48kAAd78WgIswDx","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
369+
)
375370
),
376371
],
377372
),
@@ -419,22 +414,15 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D
419414
),
420415
],
421416
),
417+
BasicSpan(content='running tool: get_weather', children=[BasicSpan(content='get_weather')]),
422418
BasicSpan(
423-
content='running 1 tool',
419+
content='event_stream_handler',
424420
children=[
421+
BasicSpan(content='ctx.run_step=2'),
425422
BasicSpan(
426-
content='running tool: get_weather', children=[BasicSpan(content='get_weather')]
427-
),
428-
BasicSpan(
429-
content='event_stream_handler',
430-
children=[
431-
BasicSpan(content='ctx.run_step=2'),
432-
BasicSpan(
433-
content=IsStr(
434-
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
435-
)
436-
),
437-
],
423+
content=IsStr(
424+
regex=r'{"result":{"tool_name":"get_weather","content":"sunny","tool_call_id":"call_Vz0Sie91Ap56nH0ThKGrZXT7","metadata":null,"timestamp":".+?","part_kind":"tool-return"},"content":null,"event_kind":"function_tool_result"}'
425+
)
438426
),
439427
],
440428
),

0 commit comments

Comments
 (0)