
Commit a6761cb

Stop calling MCP server get_tools ahead of agent run span (#2545)
1 parent: 3839c6a
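In broad strokes: `ToolManager.build` used to call `toolset.get_tools()` — for MCP toolsets, a round-trip to the MCP server — while the run's dependencies were being assembled, before the `agent run` span was started. After this commit the manager is constructed empty and fetches tools via `for_run_step` inside the span. A minimal sketch of the new ordering, with simplified stand-ins (`FakeToolset` and `LazyToolManager` are illustrative, not the real pydantic-ai classes):

```python
import asyncio


class FakeToolset:
    async def get_tools(self) -> dict[str, str]:
        # stands in for the MCP server round-trip the commit defers
        return {'get_country': 'tool-def'}


class LazyToolManager:
    def __init__(self, toolset: FakeToolset) -> None:
        self.toolset = toolset
        self.tools: dict[str, str] | None = None  # nothing fetched at construction

    async def for_run_step(self) -> 'LazyToolManager':
        # tools are fetched only once a run step actually starts
        self.tools = await self.toolset.get_tools()
        return self


async def main() -> None:
    manager = LazyToolManager(FakeToolset())  # no I/O here
    print('span: agent run started')  # stands in for tracer.start_span('agent run')
    manager = await manager.for_run_step()  # get_tools now happens inside the span
    print(manager.tools)


asyncio.run(main())
```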

8 files changed: +133 -138 lines changed


pydantic_ai_slim/pydantic_ai/_agent_graph.py

2 additions & 0 deletions

@@ -251,6 +251,8 @@ async def _prepare_request_parameters(
 ) -> models.ModelRequestParameters:
     """Build tools and create an agent model."""
     run_context = build_run_context(ctx)
+
+    # This will raise errors for any tool name conflicts
     ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context)
 
     output_schema = ctx.deps.output_schema
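The added comment notes that tool-name conflicts now surface here, when the tool manager is prepared for the run step, rather than ahead of the run. A hypothetical sketch of that kind of check (`check_tool_name_conflicts` is invented for illustration; the real validation lives in the toolset layer):

```python
def check_tool_name_conflicts(tool_names: list[str]) -> None:
    # raise if two toolsets try to register the same tool name
    seen: set[str] = set()
    for name in tool_names:
        if name in seen:
            raise ValueError(f'Tool name conflict: {name!r} is already registered')
        seen.add(name)


check_tool_name_conflicts(['get_country', 'get_weather'])  # fine
# check_tool_name_conflicts(['get_country', 'get_country'])  # would raise ValueError
```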

pydantic_ai_slim/pydantic_ai/_tool_manager.py

50 additions & 29 deletions

@@ -5,6 +5,7 @@
 from dataclasses import dataclass, field, replace
 from typing import Any, Generic
 
+from opentelemetry.trace import Tracer
 from pydantic import ValidationError
 from typing_extensions import assert_never
 
@@ -21,41 +22,46 @@
 class ToolManager(Generic[AgentDepsT]):
     """Manages tools for an agent run step. It caches the agent run's toolset's tool definitions and handles calling tools and retries."""
 
-    ctx: RunContext[AgentDepsT]
-    """The agent run context for a specific run step."""
     toolset: AbstractToolset[AgentDepsT]
     """The toolset that provides the tools for this run step."""
-    tools: dict[str, ToolsetTool[AgentDepsT]]
+    ctx: RunContext[AgentDepsT] | None = None
+    """The agent run context for a specific run step."""
+    tools: dict[str, ToolsetTool[AgentDepsT]] | None = None
     """The cached tools for this run step."""
     failed_tools: set[str] = field(default_factory=set)
     """Names of tools that failed in this run step."""
 
-    @classmethod
-    async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
-        """Build a new tool manager for a specific run step."""
-        return cls(
-            ctx=ctx,
-            toolset=toolset,
-            tools=await toolset.get_tools(ctx),
-        )
-
     async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
         """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
-        if ctx.run_step == self.ctx.run_step:
-            return self
-
-        retries = {
-            failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools
-        }
-        return await self.__class__.build(self.toolset, replace(ctx, retries=retries))
+        if self.ctx is not None:
+            if ctx.run_step == self.ctx.run_step:
+                return self
+
+            retries = {
+                failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1
+                for failed_tool_name in self.failed_tools
+            }
+            ctx = replace(ctx, retries=retries)
+
+        return self.__class__(
+            toolset=self.toolset,
+            ctx=ctx,
+            tools=await self.toolset.get_tools(ctx),
+        )
 
     @property
     def tool_defs(self) -> list[ToolDefinition]:
         """The tool definitions for the tools in this tool manager."""
+        if self.tools is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
+
         return [tool.tool_def for tool in self.tools.values()]
 
     def get_tool_def(self, name: str) -> ToolDefinition | None:
         """Get the tool definition for a given tool name, or `None` if the tool is unknown."""
+        if self.tools is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
+
         try:
             return self.tools[name].tool_def
         except KeyError:
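The carry-over logic in the reworked `for_run_step` is easiest to see in isolation. A simplified sketch with plain dict/set types (the real method also rebuilds the tool cache via `get_tools`):

```python
def carry_over_retries(previous_retries: dict[str, int], failed_tools: set[str]) -> dict[str, int]:
    # each tool that failed in the previous step gets its count incremented;
    # tools that succeeded drop out of the retry map entirely
    return {name: previous_retries.get(name, 0) + 1 for name in failed_tools}


assert carry_over_retries({'search': 2}, {'search'}) == {'search': 3}
assert carry_over_retries({'search': 2}, set()) == {}
assert carry_over_retries({}, {'fetch'}) == {'fetch': 1}
```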
@@ -71,15 +77,25 @@ async def handle_call(
             allow_partial: Whether to allow partial validation of the tool arguments.
             wrap_validation_errors: Whether to wrap validation errors in a retry prompt part.
         """
+        if self.tools is None or self.ctx is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
+
         if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
             # Output tool calls are not traced
             return await self._call_tool(call, allow_partial, wrap_validation_errors)
         else:
-            return await self._call_tool_traced(call, allow_partial, wrap_validation_errors)
+            return await self._call_tool_traced(
+                call,
+                allow_partial,
+                wrap_validation_errors,
+                self.ctx.tracer,
+                self.ctx.trace_include_content,
+            )
+
+    async def _call_tool(self, call: ToolCallPart, allow_partial: bool, wrap_validation_errors: bool) -> Any:
+        if self.tools is None or self.ctx is None:
+            raise ValueError('ToolManager has not been prepared for a run step yet')  # pragma: no cover
 
-    async def _call_tool(
-        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
-    ) -> Any:
         name = call.tool_name
         tool = self.tools.get(name)
         try:
@@ -137,14 +153,19 @@ async def _call_tool(
             raise e
 
     async def _call_tool_traced(
-        self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True
+        self,
+        call: ToolCallPart,
+        allow_partial: bool,
+        wrap_validation_errors: bool,
+        tracer: Tracer,
+        include_content: bool = False,
     ) -> Any:
         """See <https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-spans/#execute-tool-span>."""
         span_attributes = {
             'gen_ai.tool.name': call.tool_name,
             # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
             'gen_ai.tool.call.id': call.tool_call_id,
-            **({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}),
+            **({'tool_arguments': call.args_as_json_str()} if include_content else {}),
             'logfire.msg': f'running tool: {call.tool_name}',
             # add the JSON schema so these attributes are formatted nicely in Logfire
             'logfire.json_schema': json.dumps(
@@ -156,7 +177,7 @@ async def _call_tool_traced(
                                'tool_arguments': {'type': 'object'},
                                'tool_response': {'type': 'object'},
                            }
-                            if self.ctx.trace_include_content
+                            if include_content
                            else {}
                        ),
                        'gen_ai.tool.name': {},
@@ -165,16 +186,16 @@
                }
            ),
        }
-        with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
+        with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
             try:
                 tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors)
             except ToolRetryError as e:
                 part = e.tool_retry
-                if self.ctx.trace_include_content and span.is_recording():
+                if include_content and span.is_recording():
                     span.set_attribute('tool_response', part.model_response())
                 raise e
 
-            if self.ctx.trace_include_content and span.is_recording():
+            if include_content and span.is_recording():
                 span.set_attribute(
                     'tool_response',
                     tool_result
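After the refactor, `_call_tool_traced` receives the tracer and the `include_content` flag as plain parameters instead of reaching into `self.ctx`, which keeps the tracing path independent of whether a run context has been attached yet. A stripped-down sketch of the span it opens, following the attribute names in the hunk above (the tracer setup and the hard-coded response are illustrative):

```python
from opentelemetry import trace

tracer = trace.get_tracer('example')


def run_tool_traced(tool_name: str, call_id: str, args_json: str, include_content: bool) -> None:
    attributes = {
        'gen_ai.tool.name': tool_name,
        'gen_ai.tool.call.id': call_id,
        # tool arguments are only recorded when content tracing is enabled
        **({'tool_arguments': args_json} if include_content else {}),
    }
    with tracer.start_as_current_span('running tool', attributes=attributes) as span:
        tool_result = '42'  # stands in for the actual tool call
        if include_content and span.is_recording():
            span.set_attribute('tool_response', tool_result)


run_tool_traced('get_country', 'call_123', '{"city": "Mexico City"}', include_content=True)
```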

pydantic_ai_slim/pydantic_ai/agent/__init__.py

62 additions & 75 deletions

@@ -566,6 +566,8 @@ async def main():
         if output_toolset:
             output_toolset.max_retries = self._max_result_retries
             output_toolset.output_validators = output_validators
+        toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
+        tool_manager = ToolManager[AgentDepsT](toolset)
 
         # Build the graph
         graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (
@@ -581,88 +583,73 @@ async def main():
             run_step=0,
         )
 
+        # Merge model settings in order of precedence: run > agent > model
+        merged_settings = merge_model_settings(model_used.settings, self.model_settings)
+        model_settings = merge_model_settings(merged_settings, model_settings)
+        usage_limits = usage_limits or _usage.UsageLimits()
+
+        async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
+            parts = [
+                self._instructions,
+                *[await func.run(run_context) for func in self._instructions_functions],
+            ]
+
+            model_profile = model_used.profile
+            if isinstance(output_schema, _output.PromptedOutputSchema):
+                instructions = output_schema.instructions(model_profile.prompted_output_template)
+                parts.append(instructions)
+
+            parts = [p for p in parts if p]
+            if not parts:
+                return None
+            return '\n\n'.join(parts).strip()
+
         if isinstance(model_used, InstrumentedModel):
             instrumentation_settings = model_used.instrumentation_settings
             tracer = model_used.instrumentation_settings.tracer
         else:
             instrumentation_settings = None
             tracer = NoOpTracer()
 
-        run_context = RunContext[AgentDepsT](
-            deps=deps,
-            model=model_used,
-            usage=usage,
+        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
+            user_deps=deps,
             prompt=user_prompt,
-            messages=state.message_history,
+            new_message_index=new_message_index,
+            model=model_used,
+            model_settings=model_settings,
+            usage_limits=usage_limits,
+            max_result_retries=self._max_result_retries,
+            end_strategy=self.end_strategy,
+            output_schema=output_schema,
+            output_validators=output_validators,
+            history_processors=self.history_processors,
+            builtin_tools=list(self._builtin_tools),
+            tool_manager=tool_manager,
             tracer=tracer,
-            trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content,
-            run_step=state.run_step,
+            get_instructions=get_instructions,
+            instrumentation_settings=instrumentation_settings,
+        )
+        start_node = _agent_graph.UserPromptNode[AgentDepsT](
+            user_prompt=user_prompt,
+            instructions=self._instructions,
+            instructions_functions=self._instructions_functions,
+            system_prompts=self._system_prompts,
+            system_prompt_functions=self._system_prompt_functions,
+            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
         )
 
-        toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
-
-        async with toolset:
-            # This will raise errors for any name conflicts
-            tool_manager = await ToolManager[AgentDepsT].build(toolset, run_context)
-
-            # Merge model settings in order of precedence: run > agent > model
-            merged_settings = merge_model_settings(model_used.settings, self.model_settings)
-            model_settings = merge_model_settings(merged_settings, model_settings)
-            usage_limits = usage_limits or _usage.UsageLimits()
-            agent_name = self.name or 'agent'
-            run_span = tracer.start_span(
-                'agent run',
-                attributes={
-                    'model_name': model_used.model_name if model_used else 'no-model',
-                    'agent_name': agent_name,
-                    'logfire.msg': f'{agent_name} run',
-                },
-            )
-
-            async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
-                parts = [
-                    self._instructions,
-                    *[await func.run(run_context) for func in self._instructions_functions],
-                ]
-
-                model_profile = model_used.profile
-                if isinstance(output_schema, _output.PromptedOutputSchema):
-                    instructions = output_schema.instructions(model_profile.prompted_output_template)
-                    parts.append(instructions)
-
-                parts = [p for p in parts if p]
-                if not parts:
-                    return None
-                return '\n\n'.join(parts).strip()
-
-            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
-                user_deps=deps,
-                prompt=user_prompt,
-                new_message_index=new_message_index,
-                model=model_used,
-                model_settings=model_settings,
-                usage_limits=usage_limits,
-                max_result_retries=self._max_result_retries,
-                end_strategy=self.end_strategy,
-                output_schema=output_schema,
-                output_validators=output_validators,
-                history_processors=self.history_processors,
-                builtin_tools=list(self._builtin_tools),
-                tool_manager=tool_manager,
-                tracer=tracer,
-                get_instructions=get_instructions,
-                instrumentation_settings=instrumentation_settings,
-            )
-            start_node = _agent_graph.UserPromptNode[AgentDepsT](
-                user_prompt=user_prompt,
-                instructions=self._instructions,
-                instructions_functions=self._instructions_functions,
-                system_prompts=self._system_prompts,
-                system_prompt_functions=self._system_prompt_functions,
-                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
-            )
+        agent_name = self.name or 'agent'
+        run_span = tracer.start_span(
+            'agent run',
+            attributes={
+                'model_name': model_used.model_name if model_used else 'no-model',
+                'agent_name': agent_name,
+                'logfire.msg': f'{agent_name} run',
+            },
+        )
 
-            try:
+        try:
+            async with toolset:
                 async with graph.iter(
                     start_node,
                     state=state,
@@ -682,12 +669,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
                                 else json.dumps(InstrumentedModel.serialize_any(final_result.output))
                             ),
                         )
+        finally:
+            try:
+                if instrumentation_settings and run_span.is_recording():
+                    run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
             finally:
-                try:
-                    if instrumentation_settings and run_span.is_recording():
-                        run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
-                finally:
-                    run_span.end()
+                run_span.end()
 
     def _run_span_end_attributes(
         self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
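The teardown keeps the same guarantee in its new, flatter position: `run_span.end()` runs even if computing the end-of-run attributes raises. A self-contained sketch of that shape (`FakeSpan` and the attribute values are stand-ins for the OpenTelemetry span and the real attributes):

```python
class FakeSpan:
    def is_recording(self) -> bool:
        return True

    def set_attributes(self, attributes: dict[str, str]) -> None:
        print('end-of-run attributes:', attributes)

    def end(self) -> None:
        print('span ended')


def run_with_span(run_span: FakeSpan, instrumented: bool) -> None:
    try:
        pass  # the real code runs `async with toolset:` and the graph here
    finally:
        try:
            if instrumented and run_span.is_recording():
                run_span.set_attributes({'final_result': 'Mexico City'})
        finally:
            run_span.end()  # still runs even if set_attributes raised


run_with_span(FakeSpan(), instrumented=True)
```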

tests/test_ag_ui.py

1 addition & 12 deletions

@@ -1104,18 +1104,7 @@ async def store_state(
         events.append(json.loads(event.removeprefix('data: ')))
 
     assert events == simple_result()
-    assert seen_states == snapshot(
-        [
-            41,  # run msg_1, prepare_tools call 1
-            42,  # run msg_1, prepare_tools call 2
-            0,  # run msg_2, prepare_tools call 1
-            1,  # run msg_2, prepare_tools call 2
-            0,  # run msg_3, prepare_tools call 1
-            1,  # run msg_3, prepare_tools call 2
-            42,  # run msg_4, prepare_tools call 1
-            43,  # run msg_4, prepare_tools call 2
-        ]
-    )
+    assert seen_states == snapshot([41, 0, 0, 42])
 
 
 async def test_request_with_state_without_handler() -> None:
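Consistent with the deleted comments, the snapshot shrinks because `prepare_tools` no longer runs an extra time while the `ToolManager` was built ahead of the graph run; each of the four requests now triggers it once, at the start of its run step. The arithmetic, as a sketch:

```python
requests = 4  # msg_1 .. msg_4

calls_before = requests * 2  # once at ToolManager.build, once inside the run step
calls_after = requests * 1   # only at the start of each run step

assert (calls_before, calls_after) == (8, 4)
```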

tests/test_agent.py

1 addition & 5 deletions

@@ -3768,11 +3768,7 @@ async def via_toolset_decorator_for_entire_run(ctx: RunContext[None]) -> Abstrac
     assert run_result._state.run_step == 3  # pyright: ignore[reportPrivateUsage]
     assert len(available_tools) == 3
     assert toolset_creation_counts == snapshot(
-        {
-            'via_toolsets_arg': 4,
-            'via_toolset_decorator': 4,
-            'via_toolset_decorator_for_entire_run': 1,
-        }
+        defaultdict(int, {'via_toolsets_arg': 3, 'via_toolset_decorator': 3, 'via_toolset_decorator_for_entire_run': 1})
     )
 
 
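Same cause here: with `run_step == 3`, each dynamically created toolset used to be instantiated four times (once ahead of the run, once per step) and is now instantiated three times. A sketch of the counting, using the test's `defaultdict(int)` convention:

```python
from collections import defaultdict

toolset_creation_counts: defaultdict[str, int] = defaultdict(int)
run_steps = 3
for _ in range(run_steps):  # previously one extra pre-run build, i.e. run_steps + 1
    toolset_creation_counts['via_toolsets_arg'] += 1

assert toolset_creation_counts == {'via_toolsets_arg': 3}
```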
tests/test_temporal.py

0 additions & 2 deletions

@@ -322,8 +322,6 @@ async def test_complex_agent_run_in_workflow(
             'RunWorkflow:ComplexAgentWorkflow',
             'StartActivity:agent__complex_agent__mcp_server__mcp__get_tools',
             'RunActivity:agent__complex_agent__mcp_server__mcp__get_tools',
-            'StartActivity:agent__complex_agent__mcp_server__mcp__get_tools',
-            'RunActivity:agent__complex_agent__mcp_server__mcp__get_tools',
             'StartActivity:agent__complex_agent__model_request_stream',
             'ctx.run_step=1',
             '{"index":0,"part":{"tool_name":"get_country","args":"","tool_call_id":"call_3rqTYrA6H21AYUaRGP4F66oq","part_kind":"tool-call"},"event_kind":"part_start"}',

tests/test_tools.py

3 additions & 4 deletions

@@ -1226,10 +1226,9 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int:
     with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"):
         agent.run_sync('Begin infinite retry loop!')
 
-    # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in.
-    assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5]
-    assert prepare_retries == [0, 0, 1, 2, 3, 4, 5]
-    assert call_retries == [0, 1, 2, 3, 4, 5]
+    assert prepare_tools_retries == snapshot([0, 1, 2, 3, 4, 5])
+    assert prepare_retries == snapshot([0, 1, 2, 3, 4, 5])
+    assert call_retries == snapshot([0, 1, 2, 3, 4, 5])
 
 
 def test_deferred_tool():
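The leading duplicate `0` disappears for the same reason as above (the deleted comment explained it: the toolset was previously prepared once ahead of the graph run), and the assertions switch to inline-snapshot's `snapshot()` wrapper, which the suite already uses elsewhere. Minimal usage, assuming the `inline-snapshot` dev dependency:

```python
from inline_snapshot import snapshot


def test_call_retries() -> None:
    call_retries = [0, 1, 2, 3, 4, 5]
    # snapshot(x) compares equal to x; `pytest --inline-snapshot=fix` rewrites it in place
    assert call_retries == snapshot([0, 1, 2, 3, 4, 5])
```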
