diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index fbe843a510..b919b76613 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -38,7 +38,6 @@ from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client except ImportError as e: - if sys.version_info < (3, 10): raise ImportError( 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' @@ -49,7 +48,6 @@ logger = logging.getLogger('google_adk.' + __name__) - class StdioConnectionParams(BaseModel): """Parameters for the MCP Stdio connection. @@ -58,16 +56,17 @@ class StdioConnectionParams(BaseModel): timeout: Timeout in seconds for establishing the connection to the MCP stdio server. """ - server_params: StdioServerParameters timeout: float = 5.0 + class Config: + arbitrary_types_allowed = True class SseConnectionParams(BaseModel): """Parameters for the MCP SSE connection. See MCP SSE Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py + [https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py) Attributes: url: URL for the MCP SSE server. @@ -77,18 +76,16 @@ class SseConnectionParams(BaseModel): sse_read_timeout: Timeout in seconds for reading data from the MCP SSE server. """ - url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 - class StreamableHTTPConnectionParams(BaseModel): """Parameters for the MCP Streamable HTTP connection. See MCP Streamable HTTP Client documentation for more details. - https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py + [https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py](https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py) Attributes: url: URL for the MCP Streamable HTTP server. @@ -100,41 +97,45 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: Whether to terminate the MCP Streamable HTTP server when the connection is closed. """ - url: str headers: dict[str, Any] | None = None timeout: float = 5.0 sse_read_timeout: float = 60 * 5.0 terminate_on_close: bool = True - def retry_on_closed_resource(func): """Decorator to automatically retry action when MCP session is closed. + CRITICAL WARNING: This decorator is UNSAFE for non-idempotent operations. + Do NOT use with tool calls that create, update, or delete resources as + retrying can cause duplicate operations or data corruption. + + Only use with read-only, idempotent operations like list_tools, + list_resources, or read_resource. + + Do NOT apply to generic tool execution methods like _run_async_impl. + When MCP session was closed, the decorator will automatically retry the action once. The create_session method will handle creating a new session if the old one was disconnected. Args: - func: The function to decorate. + func: The function to decorate. Must be idempotent and safe to retry. Returns: The decorated function. """ - @functools.wraps(func) # Preserves original function metadata async def wrapper(self, *args, **kwargs): try: return await func(self, *args, **kwargs) - except anyio.ClosedResourceError: + except (anyio.ClosedResourceError, anyio.BrokenResourceError): # Simply retry the function - create_session will handle # detecting and replacing disconnected sessions - logger.info('Retrying %s due to closed resource', func.__name__) + logger.info('Retrying %s due to closed/broken resource', func.__name__) return await func(self, *args, **kwargs) - return wrapper - class MCPSessionManager: """Manages MCP client sessions. @@ -176,11 +177,10 @@ def __init__( ) else: self._connection_params = connection_params - self._errlog = errlog + self._errlog = errlog # Session pool: maps session keys to (session, exit_stack) tuples self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {} - # Lock to prevent race conditions in session creation self._session_lock = asyncio.Lock() @@ -292,6 +292,7 @@ def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): ' StdioServerParameters or SseServerParams, but got' f' {self._connection_params}' ) + return client async def create_session( @@ -313,7 +314,6 @@ async def create_session( """ # Merge headers once at the beginning merged_headers = self._merge_headers(headers) - # Generate session key using merged headers session_key = self._generate_session_key(merged_headers) @@ -322,7 +322,6 @@ async def create_session( # Check if we have an existing session if session_key in self._sessions: session, exit_stack = self._sessions[session_key] - # Check if the existing session is still connected if not self._is_session_disconnected(session): # Session is still good, return it @@ -339,11 +338,10 @@ async def create_session( # Create a new session (either first time or replacing disconnected one) exit_stack = AsyncExitStack() - try: client = self._create_client(merged_headers) - transports = await exit_stack.enter_async_context(client) + # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients. if isinstance(self._connection_params, StdioConnectionParams): @@ -359,8 +357,8 @@ async def create_session( session = await exit_stack.enter_async_context( ClientSession(*transports[:2]) ) - await session.initialize() + await session.initialize() # Store session and exit stack in the pool self._sessions[session_key] = (session, exit_stack) logger.debug('Created new session: %s', session_key) @@ -369,7 +367,10 @@ async def create_session( except Exception: # If session creation fails, clean up the exit stack if exit_stack: - await exit_stack.aclose() + try: + await exit_stack.aclose() + except (anyio.BrokenResourceError, anyio.ClosedResourceError) as e: + logger.warning('Error during exit stack cleanup: %s', e) raise async def close(self): @@ -389,7 +390,5 @@ async def close(self): finally: del self._sessions[session_key] - SseServerParams = SseConnectionParams - StreamableHTTPServerParams = StreamableHTTPConnectionParams