diff --git a/src/mcp_agent/agents/agent.py b/src/mcp_agent/agents/agent.py index 210fd47d6..415c05ac6 100644 --- a/src/mcp_agent/agents/agent.py +++ b/src/mcp_agent/agents/agent.py @@ -490,6 +490,42 @@ def _should_include_non_namespaced_tool( # No non_namespaced_tools key and no wildcard - include by default (no filter for non-namespaced) return True, None + async def _sync_with_aggregator_state(self) -> None: + if self._agent_tasks is None: + return + + executor = self.context.executor if self.context else None + if executor is None: + return + + response = await executor.execute( + self._agent_tasks.get_aggregator_state_task, + GetAggregatorStateRequest(agent_name=self.name), + ) + + if isinstance(response, BaseException): # pragma: no cover - defensive + raise response + + self.initialized = response.initialized + + self._namespaced_tool_map.clear() + self._namespaced_tool_map.update(response.namespaced_tool_map) + + self._server_to_tool_map.clear() + self._server_to_tool_map.update(response.server_to_tool_map) + + self._namespaced_prompt_map.clear() + self._namespaced_prompt_map.update(response.namespaced_prompt_map) + + self._server_to_prompt_map.clear() + self._server_to_prompt_map.update(response.server_to_prompt_map) + + self._namespaced_resource_map.clear() + self._namespaced_resource_map.update(response.namespaced_resource_map) + + self._server_to_resource_map.clear() + self._server_to_resource_map.update(response.server_to_resource_map) + async def list_tools( self, server_name: str | None = None, @@ -508,6 +544,8 @@ async def list_tools( if not self.initialized: await self.initialize() + await self._sync_with_aggregator_state() + tracer = get_tracer(self.context) with tracer.start_as_current_span( f"{self.__class__.__name__}.{self.name}.list_tools" @@ -731,6 +769,8 @@ async def list_resources( if not self.initialized: await self.initialize() + await self._sync_with_aggregator_state() + tracer = get_tracer(self.context) with tracer.start_as_current_span( f"{self.__class__.__name__}.{self.name}.list_resources" @@ -754,6 +794,8 @@ async def read_resource(self, uri: str, server_name: str | None = None): if not self.initialized: await self.initialize() + await self._sync_with_aggregator_state() + tracer = get_tracer(self.context) with tracer.start_as_current_span( f"{self.__class__.__name__}.{self.name}.read_resource" @@ -871,6 +913,8 @@ async def list_prompts(self, server_name: str | None = None) -> ListPromptsResul if not self.initialized: await self.initialize() + await self._sync_with_aggregator_state() + tracer = get_tracer(self.context) with tracer.start_as_current_span( f"{self.__class__.__name__}.{self.name}.list_prompts" @@ -919,6 +963,8 @@ async def get_prompt( if not self.initialized: await self.initialize() + await self._sync_with_aggregator_state() + tracer = get_tracer(self.context) with tracer.start_as_current_span( f"{self.__class__.__name__}.{self.name}.get_prompt" @@ -1077,6 +1123,8 @@ async def call_tool( if not self.initialized: await self.initialize() + await self._sync_with_aggregator_state() + tracer = get_tracer(self.context) with tracer.start_as_current_span( f"{self.__class__.__name__}.{self.name}.call_tool" @@ -1194,6 +1242,14 @@ class InitAggregatorResponse(BaseModel): ) +class GetAggregatorStateRequest(BaseModel): + """ + Request to fetch the current cached state from an agent's aggregator. + """ + + agent_name: str + + class ListToolsRequest(BaseModel): """ Request to list tools for an agent. @@ -1435,6 +1491,41 @@ async def initialize_aggregator_task( server_to_resource_map=aggregator._server_to_resource_map, ) + async def get_aggregator_state_task( + self, request: GetAggregatorStateRequest + ) -> InitAggregatorResponse: + async with self._with_aggregator(request.agent_name) as aggregator: + async with aggregator._tool_map_lock: + namespaced_tool_map = dict(aggregator._namespaced_tool_map) + server_to_tool_map = { + server: list(tools) + for server, tools in aggregator._server_to_tool_map.items() + } + + async with aggregator._prompt_map_lock: + namespaced_prompt_map = dict(aggregator._namespaced_prompt_map) + server_to_prompt_map = { + server: list(prompts) + for server, prompts in aggregator._server_to_prompt_map.items() + } + + async with aggregator._resource_map_lock: + namespaced_resource_map = dict(aggregator._namespaced_resource_map) + server_to_resource_map = { + server: list(resources) + for server, resources in aggregator._server_to_resource_map.items() + } + + return InitAggregatorResponse( + initialized=aggregator.initialized, + namespaced_tool_map=namespaced_tool_map, + server_to_tool_map=server_to_tool_map, + namespaced_prompt_map=namespaced_prompt_map, + server_to_prompt_map=server_to_prompt_map, + namespaced_resource_map=namespaced_resource_map, + server_to_resource_map=server_to_resource_map, + ) + async def shutdown_aggregator_task(self, agent_name: str) -> bool: """ Shutdown the agent's servers. diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py index 80a3e1f12..02269789c 100644 --- a/src/mcp_agent/mcp/mcp_aggregator.py +++ b/src/mcp_agent/mcp/mcp_aggregator.py @@ -1,9 +1,12 @@ import asyncio from typing import List, Literal, Dict, Optional, TypeVar, TYPE_CHECKING +import anyio.lowlevel + from opentelemetry import trace from pydantic import BaseModel from mcp.client.session import ClientSession +from mcp.shared.session import RequestResponder from mcp.server.lowlevel.server import Server from mcp.server.stdio import stdio_server from mcp.types import ( @@ -18,6 +21,10 @@ Tool, TextContent, Resource, + ServerNotification, + ToolListChangedNotification, + PromptListChangedNotification, + ResourceListChangedNotification, ) from mcp_agent.logging.event_progress import ProgressAction @@ -147,6 +154,16 @@ def __init__( self._server_to_resource_map: Dict[str, List[NamespacedResource]] = {} self._resource_map_lock = asyncio.Lock() + # Track message handler registration and refresh state per server + self._notification_handler_sessions: Dict[str, ClientSession] = {} + self._server_refresh_tasks: Dict[str, asyncio.Task] = {} + self._server_refresh_pending: Dict[str, bool] = {} + self._capability_list_changed_supported: Dict[str, Dict[str, bool]] = { + "tools": {}, + "prompts": {}, + "resources": {}, + } + async def initialize(self, force: bool = False): """Initialize the application.""" tracer = get_tracer(self.context) @@ -272,6 +289,16 @@ async def close(self): # Always mark as uninitialized regardless of errors self.initialized = False + if self._server_refresh_tasks: + tasks = list(self._server_refresh_tasks.values()) + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + self._server_refresh_tasks.clear() + + self._server_refresh_pending.clear() + self._notification_handler_sessions.clear() + @classmethod async def create( cls, @@ -509,14 +536,28 @@ async def load_servers(self, force: bool = False): self.initialized = True + async def _get_persistent_server_connection(self, server_name: str): + if not self._persistent_connection_manager: + raise RuntimeError( + "Persistent connection manager is not available for this aggregator" + ) + + server_conn = await self._persistent_connection_manager.get_server( + server_name, client_session_factory=MCPAgentClientSession + ) + + session = server_conn.session + if session is not None: + self._ensure_notification_handler(server_name, session) + + return server_conn + async def get_server(self, server_name: str) -> Optional[ClientSession]: """Get a server connection if available.""" if self.connection_persistence: try: - server_conn = await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) + server_conn = await self._get_persistent_server_connection(server_name) return server_conn.session except Exception as e: logger.warning( @@ -537,6 +578,151 @@ async def get_server(self, server_name: str) -> Optional[ClientSession]: ) as client: return client + def _ensure_notification_handler( + self, server_name: str, session: ClientSession + ) -> None: + if session is None: + return + + existing = self._notification_handler_sessions.get(server_name) + if existing is session: + return + + original_handler = getattr(session, "_message_handler", None) + + async def downstream_handler(message): + if original_handler is not None: + try: + await original_handler(message) + except Exception as exc: # pragma: no cover - defensive + logger.warning( + f"Error in original message handler for '{server_name}': {exc}", + exc_info=True, + ) + else: + await anyio.lowlevel.checkpoint() + + async def message_handler(message): + try: + await self._handle_incoming_server_message(server_name, message) + except Exception as exc: # pragma: no cover - defensive + logger.warning( + f"Error handling notification from server '{server_name}': {exc}", + exc_info=True, + ) + + await downstream_handler(message) + + # Replace the session's message handler so we can observe notifications + setattr(session, "_message_handler", message_handler) + self._notification_handler_sessions[server_name] = session + + async def _handle_incoming_server_message( + self, + server_name: str, + message: RequestResponder | ServerNotification | Exception, + ) -> None: + if isinstance(message, RequestResponder): + return + + if isinstance(message, Exception): + logger.debug( + f"Server '{server_name}' raised exception in message stream: {message}" + ) + return + + if not isinstance(message, ServerNotification): + return + + root = message.root + method = getattr(root, "method", None) + + capability = None + if method == "notifications/tools/list_changed" or isinstance( + root, ToolListChangedNotification + ): + capability = "tools" + elif method == "notifications/prompts/list_changed" or isinstance( + root, PromptListChangedNotification + ): + capability = "prompts" + elif method == "notifications/resources/list_changed" or isinstance( + root, ResourceListChangedNotification + ): + capability = "resources" + + if capability is None: + return + + if not self.connection_persistence: + logger.debug( + f"Received {capability} list_changed without persistent connections; ignoring", + data={"server_name": server_name}, + ) + return + + supports = self._capability_list_changed_supported.get(capability, {}).get( + server_name + ) + if supports is False: + logger.debug( + f"Server reported {capability} list_changed but capability not advertised; skipping", + data={"server_name": server_name}, + ) + return + + self._schedule_server_refresh(server_name) + + def _schedule_server_refresh(self, server_name: str) -> None: + existing_task = self._server_refresh_tasks.get(server_name) + if existing_task and not existing_task.done(): + self._server_refresh_pending[server_name] = True + return + + self._server_refresh_pending.pop(server_name, None) + self._server_refresh_tasks[server_name] = asyncio.create_task( + self._refresh_server(server_name) + ) + + async def _refresh_server(self, server_name: str) -> None: + try: + await self.load_server(server_name) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning( + f"Failed to refresh server '{server_name}' after list_changed notification: {exc}", + exc_info=True, + ) + finally: + if self._server_refresh_pending.pop(server_name, False): + self._server_refresh_tasks[server_name] = asyncio.create_task( + self._refresh_server(server_name) + ) + else: + self._server_refresh_tasks.pop(server_name, None) + + def _update_capability_support( + self, capabilities: ServerCapabilities | dict | None, server_name: str + ) -> None: + if capabilities is None: + for mapping in self._capability_list_changed_supported.values(): + mapping.pop(server_name, None) + return + + def get_capability(cap_name: str): + if isinstance(capabilities, dict): + return capabilities.get(cap_name) + return getattr(capabilities, cap_name, None) + + for cap_name, mapping in self._capability_list_changed_supported.items(): + cap_obj = get_capability(cap_name) + if isinstance(cap_obj, dict): + supported = bool(cap_obj.get("listChanged")) + else: + supported = bool(getattr(cap_obj, "listChanged", False)) + mapping[server_name] = supported + async def get_capabilities(self, server_name: str): """Get server capabilities if available.""" tracer = get_tracer(self.context) @@ -548,8 +734,10 @@ async def get_capabilities(self, server_name: str): span.set_attribute("connection_persistence", self.connection_persistence) span.set_attribute("server_name", server_name) - def _annotate_span_for_capabilities(capabilities: ServerCapabilities): - if not self.context.tracing_enabled: + def _annotate_span_for_capabilities( + capabilities: ServerCapabilities | dict | None, + ): + if not self.context.tracing_enabled or capabilities is None: return for attr in [ @@ -559,19 +747,22 @@ def _annotate_span_for_capabilities(capabilities: ServerCapabilities): "resources", "tools", ]: - value = getattr(capabilities, attr, None) + if isinstance(capabilities, dict): + value = capabilities.get(attr) + else: + value = getattr(capabilities, attr, None) span.set_attribute( f"{server_name}.capabilities.{attr}", value is not None ) if self.connection_persistence: try: - server_conn = await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession + server_conn = await self._get_persistent_server_connection( + server_name ) - # TODO: saqadri (FA1) - verify # server_capabilities is a property, not a coroutine res = server_conn.server_capabilities + self._update_capability_support(res, server_name) _annotate_span_for_capabilities(res) return res except Exception as e: @@ -596,6 +787,7 @@ def _annotate_span_for_capabilities(capabilities: ServerCapabilities): try: initialize_result = await session.initialize() res = initialize_result.capabilities + self._update_capability_support(res, server_name) _annotate_span_for_capabilities(res) return res except Exception as e: @@ -782,9 +974,7 @@ async def try_read_resource(client: ClientSession): return ReadResourceResult(contents=[]) if self.connection_persistence: - server_conn = await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) + server_conn = await self._get_persistent_server_connection(server_name) res = await try_read_resource(server_conn.session) # TODO: jerron - annotate span for result return res @@ -912,10 +1102,8 @@ async def try_call_tool(client: ClientSession): ) if self.connection_persistence: - server_connection = ( - await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) + server_connection = await self._get_persistent_server_connection( + server_name ) res = await try_call_tool(server_connection.session) _annotate_span_for_result(res) @@ -1101,10 +1289,8 @@ async def try_get_prompt(client: ClientSession): result: GetPromptResult = GetPromptResult(messages=[]) if self.connection_persistence: - server_connection = ( - await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) + server_connection = await self._get_persistent_server_connection( + server_name ) result = await try_get_prompt(server_connection.session) else: @@ -1234,9 +1420,7 @@ async def _start_server(self, server_name: str): }, ) - server_conn = await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) + server_conn = await self._get_persistent_server_connection(server_name) logger.info( f"MCP Server initialized for agent '{self.agent_name}'", @@ -1353,14 +1537,13 @@ async def _fetch_capabilities(self, server_name: str): resources: List[Resource] = [] if self.connection_persistence: - server_connection = await self._persistent_connection_manager.get_server( - server_name, client_session_factory=MCPAgentClientSession - ) - tools = await self._fetch_tools(server_connection.session, server_name) - prompts = await self._fetch_prompts(server_connection.session, server_name) - resources = await self._fetch_resources( - server_connection.session, server_name + server_connection = await self._get_persistent_server_connection( + server_name ) + session = server_connection.session + tools = await self._fetch_tools(session, server_name) + prompts = await self._fetch_prompts(session, server_name) + resources = await self._fetch_resources(session, server_name) else: async with gen_client( server_name, server_registry=self.context.server_registry diff --git a/tests/agents/test_agent.py b/tests/agents/test_agent.py index ed9688e57..f39053e2e 100644 --- a/tests/agents/test_agent.py +++ b/tests/agents/test_agent.py @@ -531,12 +531,18 @@ async def test_call_tool_parent(self, basic_agent): server_to_prompt_map={}, ) - # Patch executor.execute to return InitAggregatorResponse for initialization, - # and CallToolResult for the tool call - def execute_side_effect(*args, **kwargs): - if not basic_agent.initialized: + # Patch executor.execute to return InitAggregatorResponse for initialization + # and state sync, then CallToolResult for the tool invocation + + async def execute_side_effect(task, *args, **kwargs): + task_name = getattr(task, "__name__", "") + if task_name == "initialize_aggregator_task": + return init_response + if task_name == "get_aggregator_state_task": return init_response - return mock_result + if task_name == "call_tool_task": + return mock_result + return init_response with patch.object( basic_agent.context.executor, diff --git a/tests/agents/test_agent_tasks_concurrency.py b/tests/agents/test_agent_tasks_concurrency.py index c5d650f9c..685af8ffa 100644 --- a/tests/agents/test_agent_tasks_concurrency.py +++ b/tests/agents/test_agent_tasks_concurrency.py @@ -1,3 +1,5 @@ +import asyncio + import anyio import pytest @@ -31,6 +33,9 @@ def __init__(self, server_names, connection_persistence, context, name): self._server_to_prompt_map = {} self._namespaced_resource_map = {} self._server_to_resource_map = {} + self._tool_map_lock = asyncio.Lock() + self._prompt_map_lock = asyncio.Lock() + self._resource_map_lock = asyncio.Lock() def set_block(self, block: bool): self._block = block diff --git a/tests/mcp/test_mcp_aggregator.py b/tests/mcp/test_mcp_aggregator.py index a592743fc..10b27b57b 100644 --- a/tests/mcp/test_mcp_aggregator.py +++ b/tests/mcp/test_mcp_aggregator.py @@ -5,7 +5,13 @@ from types import SimpleNamespace from unittest.mock import AsyncMock, patch -from mcp.types import Tool +from mcp.types import ( + Tool, + ServerNotification, + ToolListChangedNotification, + PromptListChangedNotification, + ResourceListChangedNotification, +) import src.mcp_agent.mcp.mcp_aggregator as mcp_aggregator_mod @@ -578,6 +584,8 @@ async def test_mcp_aggregator_get_capabilities(monkeypatch): aggregator.initialized = True class DummyServerConn: + session = None + @property def server_capabilities(self): return {"foo": "bar"} @@ -1274,3 +1282,110 @@ async def mock_fetch_capabilities(server_name): assert "tool_exact" in tool_names assert "tool_similar" not in tool_names assert "my_tool" not in tool_names + + +@pytest.mark.asyncio +async def test_tools_list_changed_notification_triggers_reload(): + context = SimpleNamespace( + tracer=None, + tracing_enabled=False, + server_registry=None, + _mcp_connection_manager_lock=asyncio.Lock(), + _mcp_connection_manager_ref_count=0, + ) + + aggregator = mcp_aggregator_mod.MCPAggregator( + server_names=["test_server"], + connection_persistence=True, + context=context, + name="agent", + ) + aggregator.initialized = True + aggregator._capability_list_changed_supported["tools"]["test_server"] = True + + aggregator.load_server = AsyncMock() + + notification = ServerNotification( + root=ToolListChangedNotification(method="notifications/tools/list_changed") + ) + + await aggregator._handle_incoming_server_message("test_server", notification) + + task = aggregator._server_refresh_tasks.get("test_server") + assert task is not None + + await task + + aggregator.load_server.assert_awaited_once_with("test_server") + + +@pytest.mark.asyncio +async def test_prompts_list_changed_notification_triggers_reload(): + context = SimpleNamespace( + tracer=None, + tracing_enabled=False, + server_registry=None, + _mcp_connection_manager_lock=asyncio.Lock(), + _mcp_connection_manager_ref_count=0, + ) + + aggregator = mcp_aggregator_mod.MCPAggregator( + server_names=["test_server"], + connection_persistence=True, + context=context, + name="agent", + ) + aggregator.initialized = True + aggregator._capability_list_changed_supported["prompts"]["test_server"] = True + + aggregator.load_server = AsyncMock() + + notification = ServerNotification( + root=PromptListChangedNotification(method="notifications/prompts/list_changed") + ) + + await aggregator._handle_incoming_server_message("test_server", notification) + + task = aggregator._server_refresh_tasks.get("test_server") + assert task is not None + + await task + + aggregator.load_server.assert_awaited_once_with("test_server") + + +@pytest.mark.asyncio +async def test_resources_list_changed_notification_triggers_reload(): + context = SimpleNamespace( + tracer=None, + tracing_enabled=False, + server_registry=None, + _mcp_connection_manager_lock=asyncio.Lock(), + _mcp_connection_manager_ref_count=0, + ) + + aggregator = mcp_aggregator_mod.MCPAggregator( + server_names=["test_server"], + connection_persistence=True, + context=context, + name="agent", + ) + aggregator.initialized = True + aggregator._capability_list_changed_supported["resources"]["test_server"] = True + + aggregator.load_server = AsyncMock() + + notification = ServerNotification( + root=ResourceListChangedNotification( + method="notifications/resources/list_changed" + ) + ) + + await aggregator._handle_incoming_server_message("test_server", notification) + + task = aggregator._server_refresh_tasks.get("test_server") + assert task is not None + + await task + + aggregator.load_server.assert_awaited_once_with("test_server")