
Commit 29c72a6

Always enter Toolset context when running agent (#2361)
1 parent 0584724 commit 29c72a6
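
The user-facing effect, as a minimal sketch mirroring the new test_implicit_context_manager test added in this commit (the 'test' model string and the tests.mcp_server fixture come from that test; the pydantic_ai imports are assumed): an agent whose toolsets include an MCP server no longer needs the caller to enter the server's context before a run, because the run itself enters the toolset context.

import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

# The MCP server is passed as a toolset; after this change the run starts and
# stops its stdio subprocess itself, instead of relying on the caller to do it.
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
agent = Agent('test', toolsets=[server])


async def main() -> None:
    async with agent.iter(user_prompt='Hello'):
        # The toolset context was entered implicitly when the run began.
        assert server.is_running


asyncio.run(main())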

2 files changed (+123 -81 lines)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 81 additions & 80 deletions
@@ -774,90 +774,91 @@ async def main():
 
         toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
         # This will raise errors for any name conflicts
-        run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
-
-        # Merge model settings in order of precedence: run > agent > model
-        merged_settings = merge_model_settings(model_used.settings, self.model_settings)
-        model_settings = merge_model_settings(merged_settings, model_settings)
-        usage_limits = usage_limits or _usage.UsageLimits()
-        agent_name = self.name or 'agent'
-        run_span = tracer.start_span(
-            'agent run',
-            attributes={
-                'model_name': model_used.model_name if model_used else 'no-model',
-                'agent_name': agent_name,
-                'logfire.msg': f'{agent_name} run',
-            },
-        )
-
-        async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
-            parts = [
-                self._instructions,
-                *[await func.run(run_context) for func in self._instructions_functions],
-            ]
-
-            model_profile = model_used.profile
-            if isinstance(output_schema, _output.PromptedOutputSchema):
-                instructions = output_schema.instructions(model_profile.prompted_output_template)
-                parts.append(instructions)
+        async with toolset:
+            run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
+
+            # Merge model settings in order of precedence: run > agent > model
+            merged_settings = merge_model_settings(model_used.settings, self.model_settings)
+            model_settings = merge_model_settings(merged_settings, model_settings)
+            usage_limits = usage_limits or _usage.UsageLimits()
+            agent_name = self.name or 'agent'
+            run_span = tracer.start_span(
+                'agent run',
+                attributes={
+                    'model_name': model_used.model_name if model_used else 'no-model',
+                    'agent_name': agent_name,
+                    'logfire.msg': f'{agent_name} run',
+                },
+            )
 
-            parts = [p for p in parts if p]
-            if not parts:
-                return None
-            return '\n\n'.join(parts).strip()
+            async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
+                parts = [
+                    self._instructions,
+                    *[await func.run(run_context) for func in self._instructions_functions],
+                ]
 
-        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
-            user_deps=deps,
-            prompt=user_prompt,
-            new_message_index=new_message_index,
-            model=model_used,
-            model_settings=model_settings,
-            usage_limits=usage_limits,
-            max_result_retries=self._max_result_retries,
-            end_strategy=self.end_strategy,
-            output_schema=output_schema,
-            output_validators=output_validators,
-            history_processors=self.history_processors,
-            tool_manager=run_toolset,
-            tracer=tracer,
-            get_instructions=get_instructions,
-            instrumentation_settings=instrumentation_settings,
-        )
-        start_node = _agent_graph.UserPromptNode[AgentDepsT](
-            user_prompt=user_prompt,
-            instructions=self._instructions,
-            instructions_functions=self._instructions_functions,
-            system_prompts=self._system_prompts,
-            system_prompt_functions=self._system_prompt_functions,
-            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
-        )
+                model_profile = model_used.profile
+                if isinstance(output_schema, _output.PromptedOutputSchema):
+                    instructions = output_schema.instructions(model_profile.prompted_output_template)
+                    parts.append(instructions)
+
+                parts = [p for p in parts if p]
+                if not parts:
+                    return None
+                return '\n\n'.join(parts).strip()
+
+            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
+                user_deps=deps,
+                prompt=user_prompt,
+                new_message_index=new_message_index,
+                model=model_used,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                max_result_retries=self._max_result_retries,
+                end_strategy=self.end_strategy,
+                output_schema=output_schema,
+                output_validators=output_validators,
+                history_processors=self.history_processors,
+                tool_manager=run_toolset,
+                tracer=tracer,
+                get_instructions=get_instructions,
+                instrumentation_settings=instrumentation_settings,
+            )
+            start_node = _agent_graph.UserPromptNode[AgentDepsT](
+                user_prompt=user_prompt,
+                instructions=self._instructions,
+                instructions_functions=self._instructions_functions,
+                system_prompts=self._system_prompts,
+                system_prompt_functions=self._system_prompt_functions,
+                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+            )
 
-        try:
-            async with graph.iter(
-                start_node,
-                state=state,
-                deps=graph_deps,
-                span=use_span(run_span) if run_span.is_recording() else None,
-                infer_name=False,
-            ) as graph_run:
-                agent_run = AgentRun(graph_run)
-                yield agent_run
-                if (final_result := agent_run.result) is not None and run_span.is_recording():
-                    if instrumentation_settings and instrumentation_settings.include_content:
-                        run_span.set_attribute(
-                            'final_result',
-                            (
-                                final_result.output
-                                if isinstance(final_result.output, str)
-                                else json.dumps(InstrumentedModel.serialize_any(final_result.output))
-                            ),
-                        )
-        finally:
             try:
-                if instrumentation_settings and run_span.is_recording():
-                    run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
+                async with graph.iter(
+                    start_node,
+                    state=state,
+                    deps=graph_deps,
+                    span=use_span(run_span) if run_span.is_recording() else None,
+                    infer_name=False,
+                ) as graph_run:
+                    agent_run = AgentRun(graph_run)
+                    yield agent_run
+                    if (final_result := agent_run.result) is not None and run_span.is_recording():
+                        if instrumentation_settings and instrumentation_settings.include_content:
+                            run_span.set_attribute(
+                                'final_result',
+                                (
+                                    final_result.output
+                                    if isinstance(final_result.output, str)
+                                    else json.dumps(InstrumentedModel.serialize_any(final_result.output))
+                                ),
+                            )
             finally:
-                run_span.end()
+                try:
+                    if instrumentation_settings and run_span.is_recording():
+                        run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
+                finally:
+                    run_span.end()
 
     def _run_span_end_attributes(
         self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
@@ -2173,7 +2174,7 @@ async def __anext__(
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         """Advance to the next node automatically based on the last returned node."""
         next_node = await self._graph_run.__anext__()
-        if _agent_graph.is_agent_node(next_node):
+        if _agent_graph.is_agent_node(node=next_node):
            return next_node
         assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
         return next_node
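
The structural change above, reduced to a generic sketch (this is not pydantic-ai code; FakeToolset and iter_run are illustrative stand-ins): the run body that used to execute directly in the method is now nested under async with toolset:, so the toolset's __aenter__/__aexit__ (for example, starting and stopping MCP servers) brackets the whole run, including graph iteration and span finalization.

# Generic sketch only; FakeToolset and iter_run are illustrative stand-ins,
# not pydantic-ai APIs.
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


class FakeToolset:
    """Stand-in for a toolset that manages external connections (e.g. MCP servers)."""

    def __init__(self) -> None:
        self.is_running = False

    async def __aenter__(self) -> 'FakeToolset':
        self.is_running = True  # e.g. start server subprocesses
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        self.is_running = False  # e.g. stop server subprocesses


@asynccontextmanager
async def iter_run(toolset: FakeToolset) -> AsyncIterator[str]:
    # The pattern the diff applies: enter the toolset before yielding, so it
    # stays open for as long as the caller holds the run.
    async with toolset:
        yield 'agent run'


async def main() -> None:
    toolset = FakeToolset()
    async with iter_run(toolset) as run:
        assert toolset.is_running  # entered implicitly by the run
        print(run)
    assert not toolset.is_running  # exited when the run ended


asyncio.run(main())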

tests/test_agent.py

Lines changed: 42 additions & 1 deletion
@@ -3866,7 +3866,7 @@ async def only_if_plan_presented(
     )
 
 
-async def test_context_manager():
+async def test_explicit_context_manager():
     try:
         from pydantic_ai.mcp import MCPServerStdio
     except ImportError:  # pragma: lax no cover
@@ -3886,6 +3886,47 @@ async def test_context_manager():
         assert server2.is_running
 
 
+async def test_implicit_context_manager():
+    try:
+        from pydantic_ai.mcp import MCPServerStdio
+    except ImportError:  # pragma: lax no cover
+        pytest.skip('mcp is not installed')
+
+    server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
+    server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
+    toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')])
+    agent = Agent('test', toolsets=[toolset])
+
+    async with agent.iter(
+        user_prompt='Hello',
+    ):
+        assert server1.is_running
+        assert server2.is_running
+
+
+def test_parallel_mcp_calls():
+    try:
+        from pydantic_ai.mcp import MCPServerStdio
+    except ImportError:  # pragma: lax no cover
+        pytest.skip('mcp is not installed')
+
+    async def call_tools_parallel(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+        if len(messages) == 1:
+            return ModelResponse(
+                parts=[
+                    ToolCallPart(tool_name='get_none'),
+                    ToolCallPart(tool_name='get_multiple_items'),
+                ]
+            )
+        else:
+            return ModelResponse(parts=[TextPart('finished')])
+
+    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
+    agent = Agent(FunctionModel(call_tools_parallel), toolsets=[server])
+    result = agent.run_sync()
+    assert result.output == snapshot('finished')
+
+
 def test_set_mcp_sampling_model():
     try:
         from pydantic_ai.mcp import MCPServerStdio