diff --git a/docs/mcp/client.md b/docs/mcp/client.md index ce2f30d33..51543960c 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -26,7 +26,7 @@ Examples of all three are shown below; [mcp-run-python](run-python.md) is used a Each MCP server instance is a [toolset](../toolsets.md) and can be registered with an [`Agent`][pydantic_ai.Agent] using the `toolsets` argument. -You can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager to open and close connections to all registered servers (and in the case of stdio servers, start and stop the subprocesses) around the context where they'll be used in agent runs. You can also use [`async with server`][pydantic_ai.mcp.MCPServer.__aenter__] to manage the connection or subprocess of a specific server, for example if you'd like to use it with multiple agents. If you don't explicitly enter one of these context managers to set up the server, this will be done automatically when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect the servers to be used. +You can use the [`async with agent.setup()`][pydantic_ai.Agent.__aenter__] context manager to open and close connections to all registered servers (and in the case of stdio servers, start and stop the subprocesses) around the context where they’ll be used in agent runs. You can also use [`async with server`][pydantic_ai.mcp.MCPServer.__aenter__] to manage the connection or subprocess of a specific server, for example if you’d like to use it with multiple agents. If you don’t explicitly enter one of these context managers to set up the server, this will be done automatically when it’s needed (e.g. to list the available tools or call a specific tool), but it’s more efficient to do so around the entire context where you expect the servers to be used. ### Streamable HTTP Client @@ -61,7 +61,7 @@ server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent: # (3)! + async with agent.setup(): # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -71,18 +71,18 @@ async def main(): 2. Create an agent with the MCP server attached. 3. Create a client session to connect to the server. -_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ +_(This example is complete, it can be run “as is” with Python 3.10+ — you’ll need to add `asyncio.run(main())` to run `main`)_ -**What's happening here?** +**What’s happening here?** -- The model is receiving the prompt "how many days between 2000-01-01 and 2025-03-18?" -- The model decides "Oh, I've got this `run_python_code` tool, that will be a good way to answer this question", and writes some python code to calculate the answer. +- The model is receiving the prompt “how many days between 2000-01-01 and 2025-03-18?” +- The model decides “Oh, I’ve got this `run_python_code` tool, that will be a good way to answer this question”, and writes some python code to calculate the answer. - The model returns a tool call - Pydantic AI sends the tool call to the MCP server using the SSE transport - The model is called again with the return value of running the code - The model returns the final answer -You can visualise this clearly, and even see the code that's run by adding three lines of code to instrument the example with [logfire](https://logfire.pydantic.dev/docs): +You can visualise this clearly, and even see the code that’s run by adding three lines of code to instrument the example with [logfire](https://logfire.pydantic.dev/docs): ```python {title="mcp_sse_client_logfire.py" test="skip"} import logfire @@ -102,7 +102,7 @@ Will display as follows: !!! note [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI. -The name "HTTP" is used since this implementation will be adapted in future to use the new +The name “HTTP” is used since this implementation will be adapted in future to use the new [Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. Before creating the SSE client, we need to run the server (docs [here](run-python.md)): @@ -122,7 +122,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent: # (3)! + async with agent.setup(): # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -132,11 +132,11 @@ async def main(): 2. Create an agent with the MCP server attached. 3. Create a client session to connect to the server. -_(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ +_(This example is complete, it can be run “as is” with Python 3.10+ — you’ll need to add `asyncio.run(main())` to run `main`)_ -### MCP "stdio" Server +### MCP “stdio” Server -The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. +The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you’d use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent @@ -158,7 +158,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent: + async with agent.setup(): result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -202,7 +202,7 @@ agent = Agent( async def main(): - async with agent: + async with agent.setup(): result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} @@ -273,7 +273,7 @@ server = MCPServerSSE( agent = Agent("openai:gpt-4o", toolsets=[server]) async def main(): - async with agent: + async with agent.setup(): result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -285,7 +285,7 @@ async def main(): ## MCP Sampling -!!! info "What is MCP Sampling?" +!!! info “What is MCP Sampling?” In MCP [sampling](https://modelcontextprotocol.io/docs/concepts/sampling) is a system by which an MCP server can make LLM calls via the MCP client - effectively proxying requests to an LLM via the client over whatever transport is being used. Sampling is extremely useful when MCP servers need to use Gen AI but you don't want to provision them each with their own LLM credentials or when a public MCP server would like the connecting client to pay for LLM calls. @@ -318,11 +318,11 @@ Pydantic AI supports sampling as both a client and server. See the [server](./se Sampling is automatically supported by Pydantic AI agents when they act as a client. -To be able to use sampling, an MCP server instance needs to have a [`sampling_model`][pydantic_ai.mcp.MCPServerStdio.sampling_model] set. This can be done either directly on the server using the constructor keyword argument or the property, or by using [`agent.set_mcp_sampling_model()`][pydantic_ai.Agent.set_mcp_sampling_model] to set the agent's model or one specified as an argument as the sampling model on all MCP servers registered with that agent. +To be able to use sampling, an MCP server instance needs to have a [`sampling_model`][pydantic_ai.mcp.MCPServerStdio.sampling_model] set. This can be done either directly on the server using the constructor keyword argument or the property, or by using [`agent.set_mcp_sampling_model()`][pydantic_ai.Agent.set_mcp_sampling_model] to set the agent’s model or one specified as an argument as the sampling model on all MCP servers registered with that agent. -Let's say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments). +Let’s say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments). -??? example "Sampling MCP Server" +??? example “Sampling MCP Server” ```python {title="generate_svg.py" py="3.10"} import re @@ -371,14 +371,14 @@ agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent: + async with agent.setup(): agent.set_mcp_sampling_model() result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. ``` -_(This example is complete, it can be run "as is" with Python 3.10+)_ +_(This example is complete, it can be run “as is” with Python 3.10+)_ You can disallow sampling by setting [`allow_sampling=False`][pydantic_ai.mcp.MCPServerStdio.allow_sampling] when creating the server reference, e.g.: diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 4492f1eb7..eecd81ca8 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -30,7 +30,7 @@ where: - `warmup` will run a minimal Python script to download and cache the Python standard library. This is also useful to check the server is running correctly. -Here's an example of using `@pydantic/mcp-run-python` with Pydantic AI: +Here’s an example of using `@pydantic/mcp-run-python` with Pydantic AI: ```python from pydantic_ai import Agent @@ -56,7 +56,7 @@ agent = Agent('claude-3-5-haiku-latest', toolsets=[server]) async def main(): - async with agent: + async with agent.setup(): result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025.w diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 8a916b5cc..12e39cfd8 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -64,7 +64,7 @@ async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, This ensures the worker is started and ready to process tasks as soon as the application starts. """ - async with app.task_manager, agent: + async with app.task_manager, agent.setup(): async with worker.run(): yield diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 008e17337..f3e86c03c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,7 +5,7 @@ import json import warnings from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy @@ -548,19 +548,20 @@ async def main(): _utils.validate_empty_kwargs(_deprecated_kwargs) - async with self.iter( - user_prompt=user_prompt, - output_type=output_type, - message_history=message_history, - model=model, - deps=deps, - model_settings=model_settings, - usage_limits=usage_limits, - usage=usage, - toolsets=toolsets, - ) as agent_run: - async for _ in agent_run: - pass + async with self.setup(): + async with self.iter( + user_prompt=user_prompt, + output_type=output_type, + message_history=message_history, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + toolsets=toolsets, + ) as agent_run: + async for _ in agent_run: + pass assert agent_run.result is not None, 'The graph run did not finish properly' return agent_run.result @@ -774,8 +775,8 @@ async def main(): toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) # This will raise errors for any name conflicts - async with toolset: - run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context) + async with self.setup(): + run_toolset = await ToolManager[AgentDepsT].build(toolset, ctx=run_context) # Merge model settings in order of precedence: run > agent > model merged_settings = merge_model_settings(model_used.settings, self.model_settings) @@ -1784,19 +1785,25 @@ def is_end_node( """ return isinstance(node, End) - async def __aenter__(self) -> Self: + @asynccontextmanager + async def setup(self) -> AsyncGenerator[Self, Any]: """Enter the agent context. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used. + """ + toolset = self._get_toolset() + async with toolset.setup(): + yield self + + async def __aenter__(self) -> Self: + """Enter the agent context. - This is a no-op if the agent has already been entered. + A backwards compatible way to enter the Agent context """ async with self._enter_lock: if self._entered_count == 0: async with AsyncExitStack() as exit_stack: - toolset = self._get_toolset() - await exit_stack.enter_async_context(toolset) - + await exit_stack.enter_async_context(self.setup()) self._exit_stack = exit_stack.pop_all() self._entered_count += 1 return self @@ -1828,7 +1835,7 @@ def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None: @asynccontextmanager @deprecated( - '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.' + '`run_mcp_servers` is deprecated, use `async with agent.setup():` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.' ) async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None @@ -1846,7 +1853,7 @@ async def run_mcp_servers( if model is not None: raise - async with self: + async with self.setup(): yield def to_ag_ui( diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index eb3e13ff5..0d4395be8 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -5,7 +5,7 @@ import warnings from abc import ABC, abstractmethod from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Sequence +from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager from dataclasses import dataclass, field, replace from datetime import timedelta @@ -72,7 +72,7 @@ class MCPServer(AbstractToolset[Any], ABC): _running_count: int _exit_stack: AsyncExitStack | None - _client: ClientSession + _client: ClientSession | None = None _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] @@ -99,6 +99,12 @@ async def client_streams( def name(self) -> str: return repr(self) + @property + def client(self) -> ClientSession: + if self._client is None: + raise RuntimeError('MCP server is not running') + return self._client + @property def tool_name_conflict_hint(self) -> str: return 'Consider setting `tool_prefix` to avoid name conflicts.' @@ -110,8 +116,8 @@ async def list_tools(self) -> list[mcp_types.Tool]: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - async with self: # Ensure server is running - result = await self._client.list_tools() + async with self.setup(): # Ensure server is running + result = await self.client.list_tools() return result.tools async def direct_call_tool( @@ -133,9 +139,9 @@ async def direct_call_tool( Raises: ModelRetry: If the tool call fails. """ - async with self: # Ensure server is running + async with self.setup(): # Ensure server is running try: - result = await self._client.send_request( + result = await self.client.send_request( mcp_types.ClientRequest( mcp_types.CallToolRequest( method='tools/call', @@ -191,6 +197,11 @@ async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name) } + @asynccontextmanager + async def setup(self) -> AsyncGenerator[Self, Any]: + async with self: + yield self + async def __aenter__(self) -> Self: """Enter the MCP server context. @@ -286,7 +297,7 @@ async def _map_tool_result_part( resource = part.resource return self._get_content(resource) elif isinstance(part, mcp_types.ResourceLink): - resource_result: mcp_types.ReadResourceResult = await self._client.read_resource(part.uri) + resource_result: mcp_types.ReadResourceResult = await self.client.read_resource(part.uri) return ( self._get_content(resource_result.contents[0]) if len(resource_result.contents) == 1 @@ -339,7 +350,7 @@ class MCPServerStdio(MCPServer): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent: # (2)! + async with agent.setup(): # (2)! ... ``` @@ -629,7 +640,7 @@ class MCPServerSSE(_MCPServerHTTP): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent: # (2)! + async with agent.setup(): # (2)! ... ``` @@ -663,7 +674,7 @@ class MCPServerHTTP(MCPServerSSE): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent: # (2)! + async with agent.setup(): # (2)! ... ``` @@ -692,7 +703,7 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent: # (2)! + async with agent.setup(): # (2)! ... ``` """ diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py index 455336418..550e16bee 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -1,6 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol @@ -79,19 +81,13 @@ def tool_name_conflict_hint(self) -> str: """A hint for how to avoid name conflicts with other toolsets for use in error messages.""" return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.' - async def __aenter__(self) -> Self: - """Enter the toolset context. + @asynccontextmanager + async def setup(self) -> AsyncGenerator[Self, Any]: + """Set up the toolset. This is where you can set up network connections in a concrete implementation. """ - return self - - async def __aexit__(self, *args: Any) -> bool | None: - """Exit the toolset context. - - This is where you can tear down network connections in a concrete implementation. - """ - return None + yield self @abstractmethod async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index d2ddaa125..fe3815685 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -1,15 +1,14 @@ from __future__ import annotations import asyncio -from collections.abc import Sequence -from contextlib import AsyncExitStack -from dataclasses import dataclass, field +from collections.abc import AsyncGenerator, Sequence +from contextlib import AsyncExitStack, asynccontextmanager +from dataclasses import dataclass from typing import Any, Callable from typing_extensions import Self from .._run_context import AgentDepsT, RunContext -from .._utils import get_async_lock from ..exceptions import UserError from .abstract import AbstractToolset, ToolsetTool @@ -31,31 +30,17 @@ class CombinedToolset(AbstractToolset[AgentDepsT]): toolsets: Sequence[AbstractToolset[AgentDepsT]] - _enter_lock: asyncio.Lock = field(compare=False, init=False) - _entered_count: int = field(init=False) - _exit_stack: AsyncExitStack | None = field(init=False) - - def __post_init__(self): - self._enter_lock = get_async_lock() - self._entered_count = 0 - self._exit_stack = None - - async def __aenter__(self) -> Self: - async with self._enter_lock: - if self._entered_count == 0: - async with AsyncExitStack() as exit_stack: - for toolset in self.toolsets: - await exit_stack.enter_async_context(toolset) - self._exit_stack = exit_stack.pop_all() - self._entered_count += 1 - return self - - async def __aexit__(self, *args: Any) -> bool | None: - async with self._enter_lock: - self._entered_count -= 1 - if self._entered_count == 0 and self._exit_stack is not None: - await self._exit_stack.aclose() - self._exit_stack = None + @asynccontextmanager + async def setup(self) -> AsyncGenerator[Self, Any]: + async with AsyncExitStack() as exit_stack: + try: + for toolset in self.toolsets: + await exit_stack.enter_async_context(toolset.setup()) + except Exception as e: + await exit_stack.aclose() + raise e + + yield self async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets)) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 8440f1c46..922fc3d7a 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from dataclasses import dataclass from typing import Any, Callable @@ -18,12 +20,10 @@ class WrapperToolset(AbstractToolset[AgentDepsT]): wrapped: AbstractToolset[AgentDepsT] - async def __aenter__(self) -> Self: - await self.wrapped.__aenter__() - return self - - async def __aexit__(self, *args: Any) -> bool | None: - return await self.wrapped.__aexit__(*args) + @asynccontextmanager + async def setup(self) -> AsyncGenerator[Self, Any]: + async with self.wrapped.setup(): + yield self async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: return await self.wrapped.get_tools(ctx) diff --git a/tests/test_agent.py b/tests/test_agent.py index 0b37d1041..22147d077 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3866,7 +3866,7 @@ async def only_if_plan_presented( ) -async def test_explicit_context_manager(): +async def test_legacy_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio except ImportError: # pragma: lax no cover @@ -3886,6 +3886,26 @@ async def test_explicit_context_manager(): assert server2.is_running +async def test_explicit_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + agent = Agent('test', toolsets=[toolset]) + + async with agent.setup(): + assert server1.is_running + assert server2.is_running + + async with agent.setup(): + assert server1.is_running + assert server2.is_running + + async def test_implicit_context_manager(): try: from pydantic_ai.mcp import MCPServerStdio diff --git a/tests/test_examples.py b/tests/test_examples.py index 20a7ece7b..5bb4ffc5e 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -269,12 +269,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: class MockMCPServer(AbstractToolset[Any]): - async def __aenter__(self) -> MockMCPServer: - return self - - async def __aexit__(self, *args: Any) -> None: - pass - async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: return {} diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 1021b3151..1bb5b1bea 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -72,7 +72,7 @@ def run_context(model: Model) -> RunContext[int]: async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - async with server: + async with server.setup(): tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] assert len(tools) == snapshot(16) assert tools[0].name == 'celsius_to_fahrenheit' @@ -86,8 +86,8 @@ async def test_stdio_server(run_context: RunContext[int]): async def test_reentrant_context_manager(): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - async with server: - async with server: + async with server.setup(): + async with server.setup(): pass @@ -96,18 +96,22 @@ async def test_context_manager_initialization_error() -> None: server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) from mcp.client.session import ClientSession + with pytest.raises(Exception): + _ = server.client + with patch.object(ClientSession, 'initialize', side_effect=Exception): with pytest.raises(Exception): - async with server: + async with server.setup(): pass - assert server._read_stream._closed # pyright: ignore[reportPrivateUsage] - assert server._write_stream._closed # pyright: ignore[reportPrivateUsage] + assert server._client is not None # pyright: ignore[reportPrivateUsage] + assert server._client._read_stream._closed # pyright: ignore[reportPrivateUsage] + assert server._client._write_stream._closed # pyright: ignore[reportPrivateUsage] async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') - async with server: + async with server.setup(): tools = await server.get_tools(run_context) assert all(name.startswith('foo_') for name in tools.keys()) @@ -120,7 +124,7 @@ async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): async def test_stdio_server_with_cwd(run_context: RunContext[int]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) - async with server: + async with server.setup(): tools = await server.get_tools(run_context) assert len(tools) == snapshot(16) @@ -140,7 +144,7 @@ async def process_tool_call( return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) - async with server: + async with server.setup(): agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), toolsets=[server]) result = await agent.run('Echo with deps set to 42', deps=42) assert result.output == snapshot('{"echo_deps":{"echo":"This is an echo message","deps":42}}') @@ -180,7 +184,7 @@ def test_sse_server_conflicting_timeout_params(): @pytest.mark.vcr() async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') assert result.output == snapshot('0 degrees Celsius is equal to 32 degrees Fahrenheit.') assert result.all_messages() == snapshot( @@ -257,7 +261,7 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent: + async with agent.setup(): with pytest.raises( UserError, match=re.escape( @@ -280,7 +284,7 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent: + async with agent.setup(): # This means that we passed the _prepare_request_parameters check and there is no conflict in the tool name with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): await agent.run('No conflict') @@ -295,7 +299,7 @@ async def test_agent_with_server_not_running(agent: Agent, allow_model_requests: async def test_log_level_unset(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None - async with server: + async with server.setup(): tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] assert len(tools) == snapshot(16) assert tools[13].name == 'get_log_level' @@ -307,14 +311,14 @@ async def test_log_level_unset(run_context: RunContext[int]): async def test_log_level_set(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') assert server.log_level == 'info' - async with server: + async with server.setup(): result = await server.direct_call_tool('get_log_level', {}) assert result == snapshot('info') @pytest.mark.vcr() async def test_tool_returning_str(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('What is the weather in Mexico City?') assert result.output == snapshot( 'The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' @@ -393,7 +397,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('Get me the product name') assert result.output == snapshot('The product name is "Pydantic AI".') assert result.all_messages() == snapshot( @@ -466,7 +470,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A @pytest.mark.vcr() async def test_tool_returning_text_resource_link(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('Get me the product name via get_product_name_link') assert result.output == snapshot('The product name is "Pydantic AI".') assert result.all_messages() == snapshot( @@ -539,7 +543,7 @@ async def test_tool_returning_text_resource_link(allow_model_requests: None, age @pytest.mark.vcr() async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent: + async with agent.setup(): result = await agent.run('Get me the image resource') assert result.output == snapshot( 'This is an image of a sliced kiwi with a vibrant green interior and black seeds.' @@ -621,7 +625,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: async def test_tool_returning_image_resource_link( allow_model_requests: None, agent: Agent, image_content: BinaryContent ): - async with agent: + async with agent.setup(): result = await agent.run('Get me the image resource via get_image_resource_link') assert result.output == snapshot( 'This is an image of a sliced kiwi fruit. It shows the green, seed-speckled interior with fuzzy brown skin around the edges.' @@ -704,7 +708,7 @@ async def test_tool_returning_audio_resource( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent: + async with agent.setup(): result = await agent.run("What's the content of the audio resource?", model=model) assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."') assert result.all_messages() == snapshot( @@ -758,7 +762,7 @@ async def test_tool_returning_audio_resource_link( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent: + async with agent.setup(): result = await agent.run("What's the content of the audio resource via get_audio_resource_link?", model=model) assert result.output == snapshot('00:05') assert result.all_messages() == snapshot( @@ -819,7 +823,7 @@ async def test_tool_returning_audio_resource_link( @pytest.mark.vcr() async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent: + async with agent.setup(): result = await agent.run('Get me an image') assert result.output == snapshot('Here is an image of a sliced kiwi on a white background.') assert result.all_messages() == snapshot( @@ -899,7 +903,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im @pytest.mark.vcr() async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('Get me a dict, respond on one line') assert result.output == snapshot('{"foo":"bar","baz":123}') assert result.all_messages() == snapshot( @@ -966,7 +970,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_error(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise') assert result.output == snapshot( 'I called the tool with the correct parameter, and it returned: "This is not an error."' @@ -1080,7 +1084,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_none(allow_model_requests: None, agent: Agent): - async with agent: + async with agent.setup(): result = await agent.run('Call the none tool and say Hello') assert result.output == snapshot('Hello! How can I assist you today?') assert result.all_messages() == snapshot( @@ -1147,7 +1151,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent: + async with agent.setup(): result = await agent.run('Get me multiple items and summarize in one sentence') assert result.output == snapshot( 'The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' @@ -1239,7 +1243,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: async def test_client_sampling(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server.sampling_model = TestModel(custom_output_text='sampling model response') - async with server: + async with server.setup(): result = await server.direct_call_tool('use_sampling', {'foo': 'bar'}) assert result == snapshot( { @@ -1255,7 +1259,7 @@ async def test_client_sampling(run_context: RunContext[int]): async def test_client_sampling_disabled(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], allow_sampling=False) server.sampling_model = TestModel(custom_output_text='sampling model response') - async with server: + async with server.setup(): with pytest.raises(ModelRetry, match='Error executing tool use_sampling: Sampling not supported'): await server.direct_call_tool('use_sampling', {'foo': 'bar'}) @@ -1265,7 +1269,7 @@ async def test_mcp_server_raises_mcp_error( ) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) - async with agent: + async with agent.setup(): with patch.object( mcp_server._client, # pyright: ignore[reportPrivateUsage] 'send_request', diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index f217b34f4..a735f6598 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -2,12 +2,14 @@ import re from collections import defaultdict +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager from dataclasses import dataclass, replace -from typing import TypeVar -from unittest.mock import AsyncMock +from typing import Any, TypeVar import pytest from inline_snapshot import snapshot +from typing_extensions import Self from pydantic_ai._run_context import RunContext from pydantic_ai._tool_manager import ToolManager @@ -15,6 +17,7 @@ from pydantic_ai.messages import ToolCallPart from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool from pydantic_ai.toolsets.combined import CombinedToolset from pydantic_ai.toolsets.filtered import FilteredToolset from pydantic_ai.toolsets.function import FunctionToolset @@ -464,11 +467,11 @@ async def test_context_manager(): server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) - async with toolset: + async with toolset.setup(): assert server1.is_running assert server2.is_running - async with toolset: + async with toolset.setup(): assert server1.is_running assert server2.is_running @@ -484,14 +487,26 @@ async def test_context_manager_failed_initialization(): except ImportError: # pragma: lax no cover pytest.skip('mcp is not installed') + class FailingToolset(AbstractToolset[Any]): + @asynccontextmanager + async def setup(self) -> AsyncGenerator[Self, Any]: + raise InitializationError + + yield self # pragma: no cover + + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: ... # pragma: no cover + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[Any], tool: ToolsetTool[Any] + ) -> Any: ... # pragma: no cover + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - server2 = AsyncMock() - server2.__aenter__.side_effect = InitializationError + server2 = FailingToolset() toolset = CombinedToolset([server1, server2]) with pytest.raises(InitializationError): - async with toolset: + async with toolset.setup(): pass assert server1.is_running is False