Skip to content

Commit 5cf372a

Browse files
fix: close initialized MCP server if any MCP server fails to initalize (#2312)
1 parent 41dd069 commit 5cf372a

File tree

5 files changed

+66
-25
lines changed

5 files changed

+66
-25
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1792,9 +1792,11 @@ async def __aenter__(self) -> Self:
17921792
"""
17931793
async with self._enter_lock:
17941794
if self._entered_count == 0:
1795-
self._exit_stack = AsyncExitStack()
1796-
toolset = self._get_toolset()
1797-
await self._exit_stack.enter_async_context(toolset)
1795+
async with AsyncExitStack() as exit_stack:
1796+
toolset = self._get_toolset()
1797+
await exit_stack.enter_async_context(toolset)
1798+
1799+
self._exit_stack = exit_stack.pop_all()
17981800
self._entered_count += 1
17991801
return self
18001802

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -201,25 +201,24 @@ async def __aenter__(self) -> Self:
201201
"""
202202
async with self._enter_lock:
203203
if self._running_count == 0:
204-
self._exit_stack = AsyncExitStack()
205-
206-
self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(
207-
self.client_streams()
208-
)
209-
client = ClientSession(
210-
read_stream=self._read_stream,
211-
write_stream=self._write_stream,
212-
sampling_callback=self._sampling_callback if self.allow_sampling else None,
213-
logging_callback=self.log_handler,
214-
read_timeout_seconds=timedelta(seconds=self.read_timeout),
215-
)
216-
self._client = await self._exit_stack.enter_async_context(client)
217-
218-
with anyio.fail_after(self.timeout):
219-
await self._client.initialize()
220-
221-
if log_level := self.log_level:
222-
await self._client.set_logging_level(log_level)
204+
async with AsyncExitStack() as exit_stack:
205+
self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams())
206+
client = ClientSession(
207+
read_stream=self._read_stream,
208+
write_stream=self._write_stream,
209+
sampling_callback=self._sampling_callback if self.allow_sampling else None,
210+
logging_callback=self.log_handler,
211+
read_timeout_seconds=timedelta(seconds=self.read_timeout),
212+
)
213+
self._client = await exit_stack.enter_async_context(client)
214+
215+
with anyio.fail_after(self.timeout):
216+
await self._client.initialize()
217+
218+
if log_level := self.log_level:
219+
await self._client.set_logging_level(log_level)
220+
221+
self._exit_stack = exit_stack.pop_all()
223222
self._running_count += 1
224223
return self
225224

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ def __post_init__(self):
4343
async def __aenter__(self) -> Self:
4444
async with self._enter_lock:
4545
if self._entered_count == 0:
46-
self._exit_stack = AsyncExitStack()
47-
for toolset in self.toolsets:
48-
await self._exit_stack.enter_async_context(toolset)
46+
async with AsyncExitStack() as exit_stack:
47+
for toolset in self.toolsets:
48+
await exit_stack.enter_async_context(toolset)
49+
self._exit_stack = exit_stack.pop_all()
4950
self._entered_count += 1
5051
return self
5152

tests/test_mcp.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,20 @@ async def test_reentrant_context_manager():
9191
pass
9292

9393

94+
async def test_context_manager_initialization_error() -> None:
95+
"""Test if streams are closed if client fails to initialize."""
96+
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
97+
from mcp.client.session import ClientSession
98+
99+
with patch.object(ClientSession, 'initialize', side_effect=Exception):
100+
with pytest.raises(Exception):
101+
async with server:
102+
pass
103+
104+
assert server._read_stream._closed # pyright: ignore[reportPrivateUsage]
105+
assert server._write_stream._closed # pyright: ignore[reportPrivateUsage]
106+
107+
94108
async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]):
95109
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo')
96110
async with server:

tests/test_toolsets.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
from dataclasses import dataclass, replace
55
from typing import TypeVar
6+
from unittest.mock import AsyncMock
67

78
import pytest
89
from inline_snapshot import snapshot
@@ -469,3 +470,27 @@ async def test_context_manager():
469470
async with toolset:
470471
assert server1.is_running
471472
assert server2.is_running
473+
474+
475+
class InitializationError(Exception):
476+
pass
477+
478+
479+
async def test_context_manager_failed_initialization():
480+
"""Test if MCP servers stop if any MCP server fails to initialize."""
481+
try:
482+
from pydantic_ai.mcp import MCPServerStdio
483+
except ImportError: # pragma: lax no cover
484+
pytest.skip('mcp is not installed')
485+
486+
server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
487+
server2 = AsyncMock()
488+
server2.__aenter__.side_effect = InitializationError
489+
490+
toolset = CombinedToolset([server1, server2])
491+
492+
with pytest.raises(InitializationError):
493+
async with toolset:
494+
pass
495+
496+
assert server1.is_running is False

0 commit comments

Comments
 (0)