[Draft] Use explicit context managers for Agent and Toolset #2381

Draft: wants to merge 2 commits into main
42 changes: 21 additions & 21 deletions docs/mcp/client.md
@@ -26,7 +26,7 @@ Examples of all three are shown below; [mcp-run-python](run-python.md) is used a

Each MCP server instance is a [toolset](../toolsets.md) and can be registered with an [`Agent`][pydantic_ai.Agent] using the `toolsets` argument.

You can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager to open and close connections to all registered servers (and in the case of stdio servers, start and stop the subprocesses) around the context where they'll be used in agent runs. You can also use [`async with server`][pydantic_ai.mcp.MCPServer.__aenter__] to manage the connection or subprocess of a specific server, for example if you'd like to use it with multiple agents. If you don't explicitly enter one of these context managers to set up the server, this will be done automatically when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect the servers to be used.
You can use the [`async with agent.setup()`][pydantic_ai.Agent.__aenter__] context manager to open and close connections to all registered servers (and in the case of stdio servers, start and stop the subprocesses) around the context where they'll be used in agent runs. You can also use [`async with server`][pydantic_ai.mcp.MCPServer.__aenter__] to manage the connection or subprocess of a specific server, for example if you'd like to use it with multiple agents. If you don't explicitly enter one of these context managers to set up the server, this will be done automatically when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect the servers to be used.
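
The `agent.setup()` form appears throughout the examples below; as a minimal sketch of the per-server form, sharing one server between two agents (the command and args are assumptions borrowed from the `mcp-run-python` examples in these docs):

```python {title="explicit_setup_sketch.py" test="skip" py="3.10"}
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

# One stdio server shared by two agents; the command/args are an assumption
# based on the mcp-run-python examples elsewhere in these docs.
server = MCPServerStdio('deno', args=['run', 'jsr:@pydantic/mcp-run-python', 'stdio'])

agent_one = Agent('openai:gpt-4o', toolsets=[server])
agent_two = Agent('claude-3-5-haiku-latest', toolsets=[server])


async def main():
    async with server:  # start the subprocess once and share it across both agents
        result = await agent_one.run('How many days between 2000-01-01 and 2025-03-18?')
        print(result.output)
        #> There are 9,208 days between January 1, 2000, and March 18, 2025.
        result = await agent_two.run('What is 2 ** 16?')
        print(result.output)
```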

### Streamable HTTP Client

@@ -61,7 +61,7 @@ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)!
agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)!

async def main():
async with agent: # (3)!
async with agent.setup(): # (3)!
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
print(result.output)
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
@@ -71,18 +71,18 @@ async def main():
2. Create an agent with the MCP server attached.
3. Create a client session to connect to the server.

_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_
_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_

**What's happening here?**
**What's happening here?**

- The model is receiving the prompt "how many days between 2000-01-01 and 2025-03-18?"
- The model decides "Oh, I've got this `run_python_code` tool, that will be a good way to answer this question", and writes some python code to calculate the answer.
- The model is receiving the prompt "how many days between 2000-01-01 and 2025-03-18?"
- The model decides "Oh, I've got this `run_python_code` tool, that will be a good way to answer this question", and writes some python code to calculate the answer.
- The model returns a tool call
- Pydantic AI sends the tool call to the MCP server using the SSE transport
- The model is called again with the return value of running the code
- The model returns the final answer

You can visualise this clearly, and even see the code that's run by adding three lines of code to instrument the example with [logfire](https://logfire.pydantic.dev/docs):
You can visualise this clearly, and even see the code that's run by adding three lines of code to instrument the example with [logfire](https://logfire.pydantic.dev/docs):

```python {title="mcp_sse_client_logfire.py" test="skip"}
import logfire
@@ -102,7 +102,7 @@ Will display as follows:
!!! note
[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI.

The name "HTTP" is used since this implementation will be adapted in future to use the new
The name "HTTP" is used since this implementation will be adapted in future to use the new
[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development.

Before creating the SSE client, we need to run the server (docs [here](run-python.md)):
@@ -122,7 +122,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)!


async def main():
async with agent: # (3)!
async with agent.setup(): # (3)!
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
print(result.output)
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
@@ -132,11 +132,11 @@ async def main():
2. Create an agent with the MCP server attached.
3. Create a client session to connect to the server.

_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_
_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_

### MCP "stdio" Server
### MCP "stdio" Server

The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class.
The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class.

```python {title="mcp_stdio_client.py" py="3.10"}
from pydantic_ai import Agent
@@ -158,7 +158,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server])


async def main():
async with agent:
async with agent.setup():
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
print(result.output)
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
@@ -202,7 +202,7 @@ agent = Agent(


async def main():
async with agent:
async with agent.setup():
result = await agent.run('Echo with deps set to 42', deps=42)
print(result.output)
#> {"echo_deps":{"echo":"This is an echo message","deps":42}}
@@ -273,7 +273,7 @@ server = MCPServerSSE(
agent = Agent("openai:gpt-4o", toolsets=[server])

async def main():
async with agent:
async with agent.setup():
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
print(result.output)
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
@@ -285,7 +285,7 @@ async def main():

## MCP Sampling

!!! info "What is MCP Sampling?"
!!! info "What is MCP Sampling?"
In MCP, [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used.

Sampling is extremely useful when MCP servers need to use Gen AI but you don't want to provision them each with their own LLM credentials or when a public MCP server would like the connecting client to pay for LLM calls.
@@ -318,11 +318,11 @@ Pydantic AI supports sampling as both a client and server. See the [server](./se

Sampling is automatically supported by Pydantic AI agents when they act as a client.

To be able to use sampling, an MCP server instance needs to have a [`sampling_model`][pydantic_ai.mcp.MCPServerStdio.sampling_model] set. This can be done either directly on the server using the constructor keyword argument or the property, or by using [`agent.set_mcp_sampling_model()`][pydantic_ai.Agent.set_mcp_sampling_model] to set the agent's model or one specified as an argument as the sampling model on all MCP servers registered with that agent.
To be able to use sampling, an MCP server instance needs to have a [`sampling_model`][pydantic_ai.mcp.MCPServerStdio.sampling_model] set. This can be done either directly on the server using the constructor keyword argument or the property, or by using [`agent.set_mcp_sampling_model()`][pydantic_ai.Agent.set_mcp_sampling_model] to set the agent's model or one specified as an argument as the sampling model on all MCP servers registered with that agent.
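
As a hedged sketch of the server-side route (assuming `sampling_model` accepts the same model identifiers as `Agent`; the `agent.set_mcp_sampling_model()` route is shown in the full example below):

```python {title="sampling_model_sketch.py" test="skip" py="3.10"}
from pydantic_ai.mcp import MCPServerStdio

# 'generate_svg.py' is the sampling server from the example below;
# assumption: `sampling_model` takes the same model identifiers as `Agent`.
server = MCPServerStdio('python', args=['generate_svg.py'], sampling_model='openai:gpt-4o')

# Equivalently, set it after construction via the property:
server.sampling_model = 'openai:gpt-4o'
```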

Let's say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments).
Let's say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments).

??? example "Sampling MCP Server"
??? example "Sampling MCP Server"

```python {title="generate_svg.py" py="3.10"}
import re
@@ -371,14 +371,14 @@ agent = Agent('openai:gpt-4o', toolsets=[server])


async def main():
async with agent:
async with agent.setup():
agent.set_mcp_sampling_model()
result = await agent.run('Create an image of a robot in a punk style.')
print(result.output)
#> Image file written to robot_punk.svg.
```

_(This example is complete, it can be run "as is" with Python 3.10+)_
_(This example is complete, it can be run "as is" with Python 3.10+)_

You can disallow sampling by setting [`allow_sampling=False`][pydantic_ai.mcp.MCPServerStdio.allow_sampling] when creating the server reference, e.g.:
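
A sketch of what that can look like, reusing the stdio server from the sampling example above:

```python {title="sampling_disallowed_sketch.py" test="skip" py="3.10"}
from pydantic_ai.mcp import MCPServerStdio

server = MCPServerStdio(
    'python',
    args=['generate_svg.py'],
    allow_sampling=False,  # sampling requests from this server will be rejected
)
```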

4 changes: 2 additions & 2 deletions mcp-run-python/README.md
@@ -30,7 +30,7 @@ where:
- `warmup` will run a minimal Python script to download and cache the Python standard library. This is also useful to
check the server is running correctly.

Here's an example of using `@pydantic/mcp-run-python` with Pydantic AI:
Here's an example of using `@pydantic/mcp-run-python` with Pydantic AI:

```python
from pydantic_ai import Agent
@@ -56,7 +56,7 @@ agent = Agent('claude-3-5-haiku-latest', toolsets=[server])


async def main():
async with agent:
async with agent.setup():
result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
print(result.output)
#> There are 9,208 days between January 1, 2000, and March 18, 2025.
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/_a2a.py
@@ -64,7 +64,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT,

This ensures the worker is started and ready to process tasks as soon as the application starts.
"""
async with app.task_manager, agent:
async with app.task_manager, agent.setup():
async with worker.run():
yield

53 changes: 30 additions & 23 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -5,7 +5,7 @@
import json
import warnings
from asyncio import Lock
from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterator, Mapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
from contextvars import ContextVar
from copy import deepcopy
@@ -548,19 +548,20 @@ async def main():

_utils.validate_empty_kwargs(_deprecated_kwargs)

async with self.iter(
user_prompt=user_prompt,
output_type=output_type,
message_history=message_history,
model=model,
deps=deps,
model_settings=model_settings,
usage_limits=usage_limits,
usage=usage,
toolsets=toolsets,
) as agent_run:
async for _ in agent_run:
pass
async with self.setup():
async with self.iter(
user_prompt=user_prompt,
output_type=output_type,
message_history=message_history,
model=model,
deps=deps,
model_settings=model_settings,
usage_limits=usage_limits,
usage=usage,
toolsets=toolsets,
) as agent_run:
async for _ in agent_run:
pass

assert agent_run.result is not None, 'The graph run did not finish properly'
return agent_run.result
@@ -774,8 +775,8 @@ async def main():

toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
# This will raise errors for any name conflicts
async with toolset:
run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
async with self.setup():
run_toolset = await ToolManager[AgentDepsT].build(toolset, ctx=run_context)

# Merge model settings in order of precedence: run > agent > model
merged_settings = merge_model_settings(model_used.settings, self.model_settings)
@@ -1784,19 +1785,25 @@ def is_end_node(
"""
return isinstance(node, End)

async def __aenter__(self) -> Self:
@asynccontextmanager
async def setup(self) -> AsyncGenerator[Self, Any]:
"""Enter the agent context.

This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used.
"""
toolset = self._get_toolset()
async with toolset.setup():
yield self

async def __aenter__(self) -> Self:
"""Enter the agent context.

This is a no-op if the agent has already been entered.
A backwards-compatible way to enter the Agent context.
"""
async with self._enter_lock:
if self._entered_count == 0:
async with AsyncExitStack() as exit_stack:
toolset = self._get_toolset()
await exit_stack.enter_async_context(toolset)

await exit_stack.enter_async_context(self.setup())
self._exit_stack = exit_stack.pop_all()
self._entered_count += 1
return self
@@ -1828,7 +1835,7 @@ def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None:

@asynccontextmanager
@deprecated(
'`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.'
'`run_mcp_servers` is deprecated, use `async with agent.setup():` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.'
)
async def run_mcp_servers(
self, model: models.Model | models.KnownModelName | str | None = None
@@ -1846,7 +1853,7 @@ async def run_mcp_servers(
if model is not None:
raise

async with self:
async with self.setup():
yield

def to_ag_ui(
33 changes: 22 additions & 11 deletions pydantic_ai_slim/pydantic_ai/mcp.py
@@ -5,7 +5,7 @@
import warnings
from abc import ABC, abstractmethod
from asyncio import Lock
from collections.abc import AsyncIterator, Awaitable, Sequence
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field, replace
from datetime import timedelta
@@ -72,7 +72,7 @@ class MCPServer(AbstractToolset[Any], ABC):
_running_count: int
_exit_stack: AsyncExitStack | None

_client: ClientSession
_client: ClientSession | None = None
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
_write_stream: MemoryObjectSendStream[SessionMessage]

@@ -99,6 +99,12 @@ async def client_streams(
def name(self) -> str:
return repr(self)

@property
def client(self) -> ClientSession:
if self._client is None:
raise RuntimeError('MCP server is not running')
return self._client

@property
def tool_name_conflict_hint(self) -> str:
return 'Consider setting `tool_prefix` to avoid name conflicts.'
@@ -110,8 +116,8 @@ async def list_tools(self) -> list[mcp_types.Tool]:
- We don't cache tools as they might change.
- We also don't subscribe to the server to avoid complexity.
"""
async with self: # Ensure server is running
result = await self._client.list_tools()
async with self.setup(): # Ensure server is running
result = await self.client.list_tools()
return result.tools

async def direct_call_tool(
Expand All @@ -133,9 +139,9 @@ async def direct_call_tool(
Raises:
ModelRetry: If the tool call fails.
"""
async with self: # Ensure server is running
async with self.setup(): # Ensure server is running
try:
result = await self._client.send_request(
result = await self.client.send_request(
mcp_types.ClientRequest(
mcp_types.CallToolRequest(
method='tools/call',
@@ -191,6 +197,11 @@ async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name)
}

@asynccontextmanager
async def setup(self) -> AsyncGenerator[Self, Any]:
async with self:
yield self

async def __aenter__(self) -> Self:
"""Enter the MCP server context.

@@ -286,7 +297,7 @@ async def _map_tool_result_part(
resource = part.resource
return self._get_content(resource)
elif isinstance(part, mcp_types.ResourceLink):
resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri)
resource_result: mcp_types.ReadResourceResult = await self.client.read_resource(part.uri)
return (
self._get_content(resource_result.contents[0])
if len(resource_result.contents) == 1
@@ -339,7 +350,7 @@ class MCPServerStdio(MCPServer):
agent = Agent('openai:gpt-4o', toolsets=[server])

async def main():
async with agent: # (2)!
async with agent.setup(): # (2)!
...
```

@@ -629,7 +640,7 @@ class MCPServerSSE(_MCPServerHTTP):
agent = Agent('openai:gpt-4o', toolsets=[server])

async def main():
async with agent: # (2)!
async with agent.setup(): # (2)!
...
```

@@ -663,7 +674,7 @@ class MCPServerHTTP(MCPServerSSE):
agent = Agent('openai:gpt-4o', toolsets=[server])

async def main():
async with agent: # (2)!
async with agent.setup(): # (2)!
...
```

@@ -692,7 +703,7 @@ class MCPServerStreamableHTTP(_MCPServerHTTP):
agent = Agent('openai:gpt-4o', toolsets=[server])

async def main():
async with agent: # (2)!
async with agent.setup(): # (2)!
...
```
"""