Skip to content

Commit e7171f3

Browse files
authored
Ensure toolset spans (e.g. MCP sampling) are nested under agent run span (#3224)
1 parent a972cf6 commit e7171f3

File tree

4 files changed

+124
-70
lines changed

4 files changed

+124
-70
lines changed

docs/mcp/client.md

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,13 @@ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)!
5858
agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)!
5959

6060
async def main():
61-
async with agent: # (3)!
62-
result = await agent.run('What is 7 plus 5?')
61+
result = await agent.run('What is 7 plus 5?')
6362
print(result.output)
6463
#> The answer is 12.
6564
```
6665

6766
1. Define the MCP server with the URL used to connect.
6867
2. Create an agent with the MCP server attached.
69-
3. Create a client session to connect to the server.
7068

7169
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
7270

@@ -122,15 +120,13 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)!
122120

123121

124122
async def main():
125-
async with agent: # (3)!
126-
result = await agent.run('What is 7 plus 5?')
123+
result = await agent.run('What is 7 plus 5?')
127124
print(result.output)
128125
#> The answer is 12.
129126
```
130127

131128
1. Define the MCP server with the URL used to connect.
132129
2. Create an agent with the MCP server attached.
133-
3. Create a client session to connect to the server.
134130

135131
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
136132

@@ -151,8 +147,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server])
151147

152148

153149
async def main():
154-
async with agent:
155-
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
150+
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
156151
print(result.output)
157152
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
158153
```
@@ -205,8 +200,7 @@ servers = load_mcp_servers('mcp_config.json')
205200
agent = Agent('openai:gpt-5', toolsets=servers)
206201

207202
async def main():
208-
async with agent:
209-
result = await agent.run('What is 7 plus 5?')
203+
result = await agent.run('What is 7 plus 5?')
210204
print(result.output)
211205
```
212206

@@ -247,8 +241,7 @@ agent = Agent(
247241

248242

249243
async def main():
250-
async with agent:
251-
result = await agent.run('Echo with deps set to 42', deps=42)
244+
result = await agent.run('Echo with deps set to 42', deps=42)
252245
print(result.output)
253246
#> {"echo_deps":{"echo":"This is an echo message","deps":42}}
254247
```
@@ -356,8 +349,7 @@ server = MCPServerSSE(
356349
agent = Agent('openai:gpt-4o', toolsets=[server])
357350

358351
async def main():
359-
async with agent:
360-
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
352+
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
361353
print(result.output)
362354
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
363355
```
@@ -454,9 +446,8 @@ agent = Agent('openai:gpt-4o', toolsets=[server])
454446

455447

456448
async def main():
457-
async with agent:
458-
agent.set_mcp_sampling_model()
459-
result = await agent.run('Create an image of a robot in a punk style.')
449+
agent.set_mcp_sampling_model()
450+
result = await agent.run('Create an image of a robot in a punk style.')
460451
print(result.output)
461452
#> Image file written to robot_punk.svg.
462453
```
@@ -598,9 +589,8 @@ agent = Agent('openai:gpt-4o', toolsets=[restaurant_server])
598589

599590
async def main():
600591
"""Run the agent to book a restaurant table."""
601-
async with agent:
602-
result = await agent.run('Book me a table')
603-
print(f'\nResult: {result.output}')
592+
result = await agent.run('Book me a table')
593+
print(f'\nResult: {result.output}')
604594

605595

606596
if __name__ == '__main__':

pydantic_ai_slim/pydantic_ai/agent/__init__.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -662,14 +662,14 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
662662
)
663663

664664
try:
665-
async with toolset:
666-
async with graph.iter(
667-
start_node,
668-
state=state,
669-
deps=graph_deps,
670-
span=use_span(run_span) if run_span.is_recording() else None,
671-
infer_name=False,
672-
) as graph_run:
665+
async with graph.iter(
666+
start_node,
667+
state=state,
668+
deps=graph_deps,
669+
span=use_span(run_span) if run_span.is_recording() else None,
670+
infer_name=False,
671+
) as graph_run:
672+
async with toolset:
673673
agent_run = AgentRun(graph_run)
674674
yield agent_run
675675
if (final_result := agent_run.result) is not None and run_span.is_recording():

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -441,14 +441,9 @@ class MCPServerStdio(MCPServer):
441441
'uv', args=['run', 'mcp-run-python', 'stdio'], timeout=10
442442
)
443443
agent = Agent('openai:gpt-4o', toolsets=[server])
444-
445-
async def main():
446-
async with agent: # (2)!
447-
...
448444
```
449445
450446
1. See [MCP Run Python](https://github.com/pydantic/mcp-run-python) for more information.
451-
2. This will start the server as a subprocess and connect to it.
452447
"""
453448

454449
command: str
@@ -788,13 +783,7 @@ class MCPServerSSE(_MCPServerHTTP):
788783
789784
server = MCPServerSSE('http://localhost:3001/sse')
790785
agent = Agent('openai:gpt-4o', toolsets=[server])
791-
792-
async def main():
793-
async with agent: # (1)!
794-
...
795786
```
796-
797-
1. This will connect to a server running on `localhost:3001`.
798787
"""
799788

800789
@classmethod
@@ -837,13 +826,7 @@ class MCPServerHTTP(MCPServerSSE):
837826
838827
server = MCPServerHTTP('http://localhost:3001/sse')
839828
agent = Agent('openai:gpt-4o', toolsets=[server])
840-
841-
async def main():
842-
async with agent: # (2)!
843-
...
844829
```
845-
846-
1. This will connect to a server running on `localhost:3001`.
847830
"""
848831

849832

@@ -862,12 +845,8 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
862845
from pydantic_ai import Agent
863846
from pydantic_ai.mcp import MCPServerStreamableHTTP
864847
865-
server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)!
848+
server = MCPServerStreamableHTTP('http://localhost:8000/mcp')
866849
agent = Agent('openai:gpt-4o', toolsets=[server])
867-
868-
async def main():
869-
async with agent: # (2)!
870-
...
871850
```
872851
"""
873852

tests/test_logfire.py

Lines changed: 105 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dirty_equals import IsInt, IsJson, IsList
99
from inline_snapshot import snapshot
1010
from pydantic import BaseModel
11-
from typing_extensions import NotRequired, TypedDict
11+
from typing_extensions import NotRequired, Self, TypedDict
1212

1313
from pydantic_ai import Agent, ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, UserPromptPart
1414
from pydantic_ai._utils import get_traceparent
@@ -18,10 +18,14 @@
1818
from pydantic_ai.models.test import TestModel
1919
from pydantic_ai.output import PromptedOutput, TextOutput
2020
from pydantic_ai.tools import RunContext
21+
from pydantic_ai.toolsets.abstract import ToolsetTool
22+
from pydantic_ai.toolsets.function import FunctionToolset
23+
from pydantic_ai.toolsets.wrapper import WrapperToolset
2124

2225
from .conftest import IsStr
2326

2427
try:
28+
import logfire
2529
from logfire.testing import CaptureLogfire
2630
except ImportError: # pragma: lax no cover
2731
logfire_installed = False
@@ -87,12 +91,37 @@ def test_logfire(
8791
instrument: InstrumentationSettings | bool,
8892
capfire: CaptureLogfire,
8993
) -> None:
90-
my_agent = Agent(model=TestModel(), instrument=instrument)
94+
class InstrumentedToolset(WrapperToolset):
95+
async def __aenter__(self) -> Self:
96+
with logfire.span('toolset_enter'): # pyright: ignore[reportPossiblyUnboundVariable]
97+
await super().__aenter__()
98+
return self
9199

92-
@my_agent.tool_plain
100+
async def __aexit__(self, *args: Any) -> bool | None:
101+
with logfire.span('toolset_exit'): # pyright: ignore[reportPossiblyUnboundVariable]
102+
return await super().__aexit__(*args)
103+
104+
async def call_tool(
105+
self, name: str, tool_args: dict[str, Any], ctx: RunContext[Any], tool: ToolsetTool[Any]
106+
) -> Any:
107+
with logfire.span('toolset_call_tool {name}', name=name): # pyright: ignore[reportPossiblyUnboundVariable]
108+
return await super().call_tool(name, tool_args, ctx, tool)
109+
110+
toolset = FunctionToolset()
111+
112+
@toolset.tool
93113
async def my_ret(x: int) -> str:
94114
return str(x + 1)
95115

116+
if instrument:
117+
toolset = InstrumentedToolset(toolset)
118+
119+
my_agent = Agent(
120+
model=TestModel(),
121+
toolsets=[toolset],
122+
instrument=instrument,
123+
)
124+
96125
result = my_agent.run_sync('Hello')
97126
assert result.output == snapshot('{"my_ret":"1"}')
98127

@@ -109,16 +138,29 @@ async def my_ret(x: int) -> str:
109138
'name': 'invoke_agent my_agent',
110139
'message': 'my_agent run',
111140
'children': [
112-
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
141+
{'id': 1, 'name': 'toolset_enter', 'message': 'toolset_enter'},
142+
{'id': 2, 'name': 'chat test', 'message': 'chat test'},
113143
{
114-
'id': 2,
144+
'id': 3,
115145
'name': 'running tools',
116146
'message': 'running 1 tool',
117147
'children': [
118-
{'id': 3, 'name': 'execute_tool my_ret', 'message': 'running tool: my_ret'},
148+
{
149+
'id': 4,
150+
'name': 'execute_tool my_ret',
151+
'message': 'running tool: my_ret',
152+
'children': [
153+
{
154+
'id': 5,
155+
'name': 'toolset_call_tool {name}',
156+
'message': 'toolset_call_tool my_ret',
157+
}
158+
],
159+
}
119160
],
120161
},
121-
{'id': 4, 'name': 'chat test', 'message': 'chat test'},
162+
{'id': 6, 'name': 'chat test', 'message': 'chat test'},
163+
{'id': 7, 'name': 'toolset_exit', 'message': 'toolset_exit'},
122164
],
123165
}
124166
]
@@ -131,16 +173,29 @@ async def my_ret(x: int) -> str:
131173
'name': 'agent run',
132174
'message': 'my_agent run',
133175
'children': [
134-
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
176+
{'id': 1, 'name': 'toolset_enter', 'message': 'toolset_enter'},
177+
{'id': 2, 'name': 'chat test', 'message': 'chat test'},
135178
{
136-
'id': 2,
179+
'id': 3,
137180
'name': 'running tools',
138181
'message': 'running 1 tool',
139182
'children': [
140-
{'id': 3, 'name': 'running tool', 'message': 'running tool: my_ret'},
183+
{
184+
'id': 4,
185+
'name': 'running tool',
186+
'message': 'running tool: my_ret',
187+
'children': [
188+
{
189+
'id': 5,
190+
'name': 'toolset_call_tool {name}',
191+
'message': 'toolset_call_tool my_ret',
192+
}
193+
],
194+
}
141195
],
142196
},
143-
{'id': 4, 'name': 'chat test', 'message': 'chat test'},
197+
{'id': 6, 'name': 'chat test', 'message': 'chat test'},
198+
{'id': 7, 'name': 'toolset_exit', 'message': 'toolset_exit'},
144199
],
145200
}
146201
]
@@ -156,14 +211,29 @@ async def my_ret(x: int) -> str:
156211
'name': 'agent run',
157212
'message': 'my_agent run',
158213
'children': [
159-
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
214+
{'id': 1, 'name': 'toolset_enter', 'message': 'toolset_enter'},
215+
{'id': 2, 'name': 'chat test', 'message': 'chat test'},
160216
{
161-
'id': 2,
217+
'id': 3,
162218
'name': 'running tools',
163219
'message': 'running 1 tool',
164-
'children': [{'id': 3, 'name': 'running tool', 'message': 'running tool: my_ret'}],
220+
'children': [
221+
{
222+
'id': 4,
223+
'name': 'running tool',
224+
'message': 'running tool: my_ret',
225+
'children': [
226+
{
227+
'id': 5,
228+
'name': 'toolset_call_tool {name}',
229+
'message': 'toolset_call_tool my_ret',
230+
}
231+
],
232+
}
233+
],
165234
},
166-
{'id': 4, 'name': 'chat test', 'message': 'chat test'},
235+
{'id': 6, 'name': 'chat test', 'message': 'chat test'},
236+
{'id': 7, 'name': 'toolset_exit', 'message': 'toolset_exit'},
167237
],
168238
}
169239
]
@@ -176,16 +246,29 @@ async def my_ret(x: int) -> str:
176246
'name': 'invoke_agent my_agent',
177247
'message': 'my_agent run',
178248
'children': [
179-
{'id': 1, 'name': 'chat test', 'message': 'chat test'},
249+
{'id': 1, 'name': 'toolset_enter', 'message': 'toolset_enter'},
250+
{'id': 2, 'name': 'chat test', 'message': 'chat test'},
180251
{
181-
'id': 2,
252+
'id': 3,
182253
'name': 'running tools',
183254
'message': 'running 1 tool',
184255
'children': [
185-
{'id': 3, 'name': 'execute_tool my_ret', 'message': 'running tool: my_ret'}
256+
{
257+
'id': 4,
258+
'name': 'execute_tool my_ret',
259+
'message': 'running tool: my_ret',
260+
'children': [
261+
{
262+
'id': 5,
263+
'name': 'toolset_call_tool {name}',
264+
'message': 'toolset_call_tool my_ret',
265+
}
266+
],
267+
}
186268
],
187269
},
188-
{'id': 4, 'name': 'chat test', 'message': 'chat test'},
270+
{'id': 6, 'name': 'chat test', 'message': 'chat test'},
271+
{'id': 7, 'name': 'toolset_exit', 'message': 'toolset_exit'},
189272
],
190273
}
191274
]
@@ -309,7 +392,9 @@ async def my_ret(x: int) -> str:
309392
),
310393
}
311394
)
312-
chat_span_attributes = summary.attributes[1]
395+
chat_span_attributes = next(
396+
attrs for attrs in summary.attributes.values() if attrs.get('gen_ai.operation.name', None) == 'chat'
397+
)
313398
if instrument is True or instrument.event_mode == 'attributes':
314399
if hasattr(capfire, 'get_collected_metrics'): # pragma: no branch
315400
assert capfire.get_collected_metrics() == snapshot(

0 commit comments

Comments
 (0)