Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions src/mcp_agent/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,42 @@ def _should_include_non_namespaced_tool(
# No non_namespaced_tools key and no wildcard - include by default (no filter for non-namespaced)
return True, None

async def _sync_with_aggregator_state(self) -> None:
if self._agent_tasks is None:
return

executor = self.context.executor if self.context else None
if executor is None:
return

response = await executor.execute(
self._agent_tasks.get_aggregator_state_task,
GetAggregatorStateRequest(agent_name=self.name),
)

if isinstance(response, BaseException): # pragma: no cover - defensive
raise response

self.initialized = response.initialized

self._namespaced_tool_map.clear()
self._namespaced_tool_map.update(response.namespaced_tool_map)

self._server_to_tool_map.clear()
self._server_to_tool_map.update(response.server_to_tool_map)

self._namespaced_prompt_map.clear()
self._namespaced_prompt_map.update(response.namespaced_prompt_map)

self._server_to_prompt_map.clear()
self._server_to_prompt_map.update(response.server_to_prompt_map)

self._namespaced_resource_map.clear()
self._namespaced_resource_map.update(response.namespaced_resource_map)

self._server_to_resource_map.clear()
self._server_to_resource_map.update(response.server_to_resource_map)

async def list_tools(
self,
server_name: str | None = None,
Expand All @@ -508,6 +544,8 @@ async def list_tools(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.list_tools"
Expand Down Expand Up @@ -731,6 +769,8 @@ async def list_resources(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.list_resources"
Expand All @@ -754,6 +794,8 @@ async def read_resource(self, uri: str, server_name: str | None = None):
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.read_resource"
Expand Down Expand Up @@ -871,6 +913,8 @@ async def list_prompts(self, server_name: str | None = None) -> ListPromptsResul
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.list_prompts"
Expand Down Expand Up @@ -919,6 +963,8 @@ async def get_prompt(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.get_prompt"
Expand Down Expand Up @@ -1077,6 +1123,8 @@ async def call_tool(
if not self.initialized:
await self.initialize()

await self._sync_with_aggregator_state()

tracer = get_tracer(self.context)
with tracer.start_as_current_span(
f"{self.__class__.__name__}.{self.name}.call_tool"
Expand Down Expand Up @@ -1194,6 +1242,14 @@ class InitAggregatorResponse(BaseModel):
)


class GetAggregatorStateRequest(BaseModel):
"""
Request to fetch the current cached state from an agent's aggregator.
"""

agent_name: str


class ListToolsRequest(BaseModel):
"""
Request to list tools for an agent.
Expand Down Expand Up @@ -1435,6 +1491,41 @@ async def initialize_aggregator_task(
server_to_resource_map=aggregator._server_to_resource_map,
)

async def get_aggregator_state_task(
self, request: GetAggregatorStateRequest
) -> InitAggregatorResponse:
async with self._with_aggregator(request.agent_name) as aggregator:
async with aggregator._tool_map_lock:
namespaced_tool_map = dict(aggregator._namespaced_tool_map)
server_to_tool_map = {
server: list(tools)
for server, tools in aggregator._server_to_tool_map.items()
}

async with aggregator._prompt_map_lock:
namespaced_prompt_map = dict(aggregator._namespaced_prompt_map)
server_to_prompt_map = {
server: list(prompts)
for server, prompts in aggregator._server_to_prompt_map.items()
}

async with aggregator._resource_map_lock:
namespaced_resource_map = dict(aggregator._namespaced_resource_map)
server_to_resource_map = {
server: list(resources)
for server, resources in aggregator._server_to_resource_map.items()
}

return InitAggregatorResponse(
initialized=aggregator.initialized,
namespaced_tool_map=namespaced_tool_map,
server_to_tool_map=server_to_tool_map,
namespaced_prompt_map=namespaced_prompt_map,
server_to_prompt_map=server_to_prompt_map,
namespaced_resource_map=namespaced_resource_map,
server_to_resource_map=server_to_resource_map,
)

async def shutdown_aggregator_task(self, agent_name: str) -> bool:
"""
Shutdown the agent's servers.
Expand Down
Loading
Loading