diff --git a/src/mcp_agent/agents/base_agent.py b/src/mcp_agent/agents/base_agent.py index 387de413..a60e77f6 100644 --- a/src/mcp_agent/agents/base_agent.py +++ b/src/mcp_agent/agents/base_agent.py @@ -82,6 +82,8 @@ def __init__( server_names=self.config.servers, connection_persistence=connection_persistence, name=self.config.name, + include_tools=self.config.include_tools, + exclude_tools=self.config.exclude_tools, **kwargs, ) diff --git a/src/mcp_agent/core/agent_types.py b/src/mcp_agent/core/agent_types.py index 248d7f8e..2a266a9f 100644 --- a/src/mcp_agent/core/agent_types.py +++ b/src/mcp_agent/core/agent_types.py @@ -34,6 +34,8 @@ class AgentConfig: default_request_params: RequestParams | None = None human_input: bool = False agent_type: str = AgentType.BASIC.value + include_tools: List[str] = dataclasses.field(default_factory=list) + exclude_tools: List[str] = dataclasses.field(default_factory=list) def __post_init__(self) -> None: """Ensure default_request_params exists with proper history setting""" diff --git a/src/mcp_agent/core/direct_decorators.py b/src/mcp_agent/core/direct_decorators.py index a86ae98d..5e8e5e5e 100644 --- a/src/mcp_agent/core/direct_decorators.py +++ b/src/mcp_agent/core/direct_decorators.py @@ -90,6 +90,8 @@ def _decorator_impl( use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, + include_tools: List[str] = [], + exclude_tools: List[str] = [], **extra_kwargs, ) -> Callable[[AgentCallable[P, R]], DecoratedAgentProtocol[P, R]]: """ @@ -104,6 +106,8 @@ def _decorator_impl( use_history: Whether to maintain conversation history request_params: Additional request parameters for the LLM human_input: Whether to enable human input capabilities + include_tools: Optional list of tool names. When provided, only the listed tools will be registered. Tool names must be prefixed with the server name: "{server_name}-{tool_name}". Tools in exclude_tools will still be excluded. + exclude_tools: Optional list of tool names. When provided, the listed tools will not be registered. Tool names must be prefixed with the server name: "{server_name}-{tool_name}". Excludes any tools also listed in include_tools. **extra_kwargs: Additional agent/workflow-specific parameters """ @@ -132,6 +136,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: model=model, use_history=use_history, human_input=human_input, + include_tools=include_tools, + exclude_tools=exclude_tools, ) # Update request params if provided @@ -174,6 +180,8 @@ def agent( use_history: bool = True, request_params: RequestParams | None = None, human_input: bool = False, + include_tools: List[str] = [], + exclude_tools: List[str] = [], ) -> Callable[[AgentCallable[P, R]], DecoratedAgentProtocol[P, R]]: """ Decorator to create and register a standard agent with type-safe signature. @@ -187,6 +195,8 @@ def agent( use_history: Whether to maintain conversation history request_params: Additional request parameters for the LLM human_input: Whether to enable human input capabilities + include_tools: List of tool names (including server name) to include (if empty, include all tools) + exclude_tools: List of tool names (including server name) to exclude (if empty, exclude none) Returns: A decorator that registers the agent with proper type annotations @@ -203,6 +213,8 @@ def agent( use_history=use_history, request_params=request_params, human_input=human_input, + include_tools=include_tools, + exclude_tools=exclude_tools, ) diff --git a/src/mcp_agent/mcp/mcp_aggregator.py b/src/mcp_agent/mcp/mcp_aggregator.py index 2774996b..b74b7514 100644 --- a/src/mcp_agent/mcp/mcp_aggregator.py +++ b/src/mcp_agent/mcp/mcp_aggregator.py @@ -97,11 +97,17 @@ def __init__( connection_persistence: bool = True, # Default to True for better stability context: Optional["Context"] = None, name: str = None, + include_tools: List[str] = None, + exclude_tools: List[str] = None, **kwargs, ) -> None: """ :param server_names: A list of server names to connect to. :param connection_persistence: Whether to maintain persistent connections to servers (default: True). + :param context: Optional Context instance. + :param name: Optional name for this aggregator. + :param include_tools: Optional list of tool names. When provided, only the listed tools will be registered. Tool names must be prefixed with the server name: "{server_name}-{tool_name}". Tools in exclude_tools will still be excluded. + :param exclude_tools: Optional list of tool names. When provided, the listed tools will not be registered. Tool names must be prefixed with the server name: "{server_name}-{tool_name}". Excludes any tools also listed in include_tools. Note: The server names must be resolvable by the gen_client function, and specified in the server registry. """ super().__init__( @@ -112,6 +118,8 @@ def __init__( self.server_names = server_names self.connection_persistence = connection_persistence self.agent_name = name + self.include_tools = include_tools or [] + self.exclude_tools = exclude_tools or [] self._persistent_connection_manager: MCPConnectionManager = None # Set up logger with agent name in namespace if available @@ -276,16 +284,31 @@ async def load_server_data(server_name: str): # Process tools self._server_to_tool_map[server_name] = [] + + # Filter and register tools in a single loop + filtered_tools_count = 0 for tool in tools: namespaced_tool_name = f"{server_name}{SEP}{tool.name}" + + # Apply exclusion filter + if self.exclude_tools and namespaced_tool_name in self.exclude_tools: + logger.debug(f"Excluding namespaced tool '{namespaced_tool_name}'") + continue + + # Apply inclusion filter if specified + if self.include_tools and namespaced_tool_name not in self.include_tools: + logger.debug(f"Skipping tool '{namespaced_tool_name}' not in include list") + continue + + # Create and register the tool that passed the filters namespaced_tool = NamespacedTool( tool=tool, server_name=server_name, namespaced_tool_name=namespaced_tool_name, ) - self._namespaced_tool_map[namespaced_tool_name] = namespaced_tool self._server_to_tool_map[server_name].append(namespaced_tool) + filtered_tools_count += 1 # Process prompts async with self._prompt_cache_lock: @@ -297,7 +320,7 @@ async def load_server_data(server_name: str): "progress_action": ProgressAction.INITIALIZED, "server_name": server_name, "agent_name": self.agent_name, - "tool_count": len(tools), + "tool_count": filtered_tools_count, "prompt_count": len(prompts), }, )