diff --git a/src/mcp_agent/agents/agent.py b/src/mcp_agent/agents/agent.py index e55289d82..ecbb734ff 100644 --- a/src/mcp_agent/agents/agent.py +++ b/src/mcp_agent/agents/agent.py @@ -1,7 +1,7 @@ import asyncio import json import uuid -from typing import Callable, Dict, List, Optional, TypeVar, TYPE_CHECKING, Any +from typing import Callable, Dict, List, Optional, Set, TypeVar, TYPE_CHECKING, Any from contextlib import asynccontextmanager from opentelemetry import trace @@ -459,7 +459,48 @@ async def get_server_session(self, server_name: str): return result - async def list_tools(self, server_name: str | None = None) -> ListToolsResult: + def _should_include_non_namespaced_tool( + self, tool_name: str, tool_filter: Dict[str, Set[str]] | None + ) -> tuple[bool, str | None]: + """ + Determine if a non-namespaced tool (function tool or human input) should be included. + + Uses the special reserved key "non_namespaced_tools" to filter function tools and human input. + + Returns: (should_include, filter_reason) + - filter_reason is None if tool should be included, otherwise explains why filtered + """ + if tool_filter is None: + return True, None + + # Priority 1: Check non_namespaced_tools key (explicitly for non-namespaced tools) + if "non_namespaced_tools" in tool_filter: + if tool_name in tool_filter["non_namespaced_tools"]: + return True, None + else: + return False, f"{tool_name} not in tool_filter[non_namespaced_tools]" + + # Priority 2: Check wildcard filter + elif "*" in tool_filter: + if tool_name in tool_filter["*"]: + return True, None + else: + return False, f"{tool_name} not in tool_filter[*]" + + # No non_namespaced_tools key and no wildcard - include by default (no filter for non-namespaced) + return True, None + + async def list_tools(self, server_name: str | None = None, tool_filter: Dict[str, Set[str]] | None = None) -> ListToolsResult: + """ + List available tools with optional filtering. + + Args: + server_name: Optional specific server to list tools from + tool_filter: Optional dict mapping server names to sets of allowed tool names. + Special reserved keys: + - "*": Wildcard filter for servers without explicit filters + - "non_namespaced_tools": Filter for non-namespaced tools (function tools, human input) + """ if not self.initialized: await self.initialize() @@ -473,38 +514,112 @@ async def list_tools(self, server_name: str | None = None) -> ListToolsResult: "human_input_callback", self.human_input_callback is not None ) + # Track filtered tools for debugging and telemetry + filtered_out_tools = [] # List of (tool_name, reason) tuples + if server_name: span.set_attribute("server_name", server_name) - result = ListToolsResult( - tools=[ - namespaced_tool.tool.model_copy( - update={"name": namespaced_tool.namespaced_tool_name} - ) - for namespaced_tool in self._server_to_tool_map.get( - server_name, [] - ) - ] - ) + # Get tools for specific server + server_tools = self._server_to_tool_map.get(server_name, []) + + # Check if we should apply filtering for this specific server + if tool_filter is not None and server_name in tool_filter: + # Server is explicitly in filter dict - apply its filter rules + # If tool_filter[server_name] is empty set, no tools will pass + # If tool_filter[server_name] has tools, only those will pass + allowed_tools = tool_filter[server_name] + result_tools = [] + for namespaced_tool in server_tools: + if namespaced_tool.tool.name in allowed_tools: + result_tools.append( + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool.namespaced_tool_name} + ) + ) + else: + filtered_out_tools.append( + (namespaced_tool.namespaced_tool_name, f"Not in tool_filter[{server_name}]") + ) + result = ListToolsResult(tools=result_tools) + else: + # Either no filter at all (tool_filter is None) or + # this server is not in the filter dict (no filtering for this server) + # Include all tools from this server + result = ListToolsResult( + tools=[ + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool.namespaced_tool_name} + ) + for namespaced_tool in server_tools + ] + ) else: - result = ListToolsResult( - tools=[ - namespaced_tool.tool.model_copy( - update={"name": namespaced_tool_name} - ) - for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items() - ] - ) + # No specific server requested - get tools from all servers + if tool_filter is not None: + # Filter is active - check each tool's server against filter rules + filtered_tools = [] + for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items(): + should_include = False + + # Priority 1: Check if tool's server has explicit filter rules + if namespaced_tool.server_name in tool_filter: + # Server has explicit filter - tool must be in the allowed set + if namespaced_tool.tool.name in tool_filter[namespaced_tool.server_name]: + should_include = True + else: + filtered_out_tools.append( + (namespaced_tool_name, f"Not in tool_filter[{namespaced_tool.server_name}]") + ) + # Priority 2: If no server-specific filter, check wildcard + elif "*" in tool_filter: + # Wildcard filter applies to servers without explicit filters + if namespaced_tool.tool.name in tool_filter["*"]: + should_include = True + else: + filtered_out_tools.append( + (namespaced_tool_name, "Not in tool_filter[*]") + ) + else: + # No explicit filter for this server and no wildcard + # Default behavior: include the tool (no filtering) + should_include = True + + if should_include: + filtered_tools.append( + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool_name} + ) + ) + result = ListToolsResult(tools=filtered_tools) + else: + # No filter at all - include everything + result = ListToolsResult( + tools=[ + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool_name} + ) + for namespaced_tool_name, namespaced_tool in self._namespaced_tool_map.items() + ] + ) - # Add function tools + # Add function tools (non-namespaced) with filtering + # These use the special "non_namespaced_tools" key in tool_filter for tool in self._function_tool_map.values(): - result.tools.append( - Tool( - name=tool.name, - description=tool.description, - inputSchema=tool.parameters, - ) + should_include, filter_reason = self._should_include_non_namespaced_tool( + tool.name, tool_filter ) + if should_include: + result.tools.append( + Tool( + name=tool.name, + description=tool.description, + inputSchema=tool.parameters, + ) + ) + elif filter_reason: + filtered_out_tools.append((tool.name, filter_reason)) + def _annotate_span_for_tools_result(result: ListToolsResult): if not self.context.tracing_enabled: return @@ -529,24 +644,59 @@ def _annotate_span_for_tools_result(result: ListToolsResult): f"tool.{tool.name}.annotations.{attr}", value ) - # Add a human_input_callback as a tool - if not self.human_input_callback: - logger.debug("Human input callback not set") - _annotate_span_for_tools_result(result) + # Add human_input_callback tool (non-namespaced) with filtering + # This uses the special "non_namespaced_tools" key in tool_filter + if self.human_input_callback: + should_include, filter_reason = self._should_include_non_namespaced_tool( + HUMAN_INPUT_TOOL_NAME, tool_filter + ) - return result + if should_include: + human_input_tool: FastTool = FastTool.from_function( + self.request_human_input + ) + result.tools.append( + Tool( + name=HUMAN_INPUT_TOOL_NAME, + description=human_input_tool.description, + inputSchema=human_input_tool.parameters, + ) + ) + elif filter_reason: + filtered_out_tools.append((HUMAN_INPUT_TOOL_NAME, filter_reason)) + else: + logger.debug("Human input callback not set") - # Add a human_input_callback as a tool - human_input_tool: FastTool = FastTool.from_function( - self.request_human_input - ) - result.tools.append( - Tool( - name=HUMAN_INPUT_TOOL_NAME, - description=human_input_tool.description, - inputSchema=human_input_tool.parameters, - ) - ) + # Log and track filtering metrics if filter was applied + if tool_filter is not None: + span.set_attribute("tool_filter_applied", True) + span.set_attribute("tools_included_count", len(result.tools)) + span.set_attribute("tools_filtered_out_count", len(filtered_out_tools)) + + # Add telemetry for filtered tools (limit to first 20 to avoid span bloat) + if self.context.tracing_enabled: + for i, (tool_name, reason) in enumerate(filtered_out_tools[:20]): + span.set_attribute(f"filtered_tool.{i}.name", tool_name) + span.set_attribute(f"filtered_tool.{i}.reason", reason) + if len(filtered_out_tools) > 20: + span.set_attribute("filtered_tools_truncated", True) + + # Log filtered tools for debugging + if filtered_out_tools: + logger.debug( + f"Tool filter applied: {len(filtered_out_tools)} tools filtered out, " + f"{len(result.tools)} tools remaining. " + f"Filtered tools: {[name for name, _ in filtered_out_tools[:10]]}" + + ("..." if len(filtered_out_tools) > 10 else "") + ) + # Log detailed reasons at trace level (if trace logging is available) + if logger.isEnabledFor(10): # TRACE level is usually 10 + for tool_name, reason in filtered_out_tools: + logger.log(10, f"Filtered out '{tool_name}': {reason}") + else: + logger.debug( + f"Tool filter applied: All {len(result.tools)} tools passed the filter" + ) _annotate_span_for_tools_result(result) diff --git a/src/mcp_agent/workflows/llm/augmented_llm.py b/src/mcp_agent/workflows/llm/augmented_llm.py index cbb179929..33112af5c 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm.py +++ b/src/mcp_agent/workflows/llm/augmented_llm.py @@ -2,10 +2,12 @@ from typing import ( Any, + Dict, Generic, List, Optional, Protocol, + Set, Type, TypeVar, Union, @@ -173,6 +175,26 @@ class RequestParams(CreateMessageRequestParams): Whether models that support strict mode should strictly enforce the response schema. """ + tool_filter: Dict[str, Set[str]] | None = None + """ + Mapping of server names to sets of allowed tool names for this request. + If specified, only these tools will be exposed to the LLM for each server. + This overrides the server-level allowed_tools configuration. + + Special reserved keys: + - "*": Wildcard filter for servers without explicit filters + - "non_namespaced_tools": Filter for non-namespaced tools (function tools, human input) + + Examples: + - {"server1": {"tool1", "tool2"}} - Allow specific tools from server1 + - {"*": {"tool1"}} - Allow tool1 from all servers without explicit filters + - {"non_namespaced_tools": {"human_input", "func1"}} - Allow specific non-namespaced tools + - {} - No tools allowed from any server + - None - No filtering applied (default behavior) + + Tool names should match exactly as they appear in the server's tool list. + """ + class AugmentedLLMProtocol(Protocol, Generic[MessageParamT, MessageT]): """Protocol defining the interface for augmented LLMs""" @@ -538,9 +560,9 @@ async def call_tool( ], ) - async def list_tools(self, server_name: str | None = None) -> ListToolsResult: + async def list_tools(self, server_name: str | None = None, tool_filter: Dict[str, Set[str]] | None = None) -> ListToolsResult: """Call the underlying agent's list_tools method for a given server.""" - return await self.agent.list_tools(server_name=server_name) + return await self.agent.list_tools(server_name=server_name, tool_filter=tool_filter) async def list_resources( self, server_name: str | None = None diff --git a/src/mcp_agent/workflows/llm/augmented_llm_anthropic.py b/src/mcp_agent/workflows/llm/augmented_llm_anthropic.py index e9a9f9a7d..14f582960 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm_anthropic.py +++ b/src/mcp_agent/workflows/llm/augmented_llm_anthropic.py @@ -180,7 +180,7 @@ async def generate( AnthropicConverter.convert_mixed_messages_to_anthropic(message) ) - list_tools_result = await self.agent.list_tools() + list_tools_result = await self.agent.list_tools(tool_filter=params.tool_filter) available_tools: List[ToolParam] = [ { "name": tool.name, diff --git a/src/mcp_agent/workflows/llm/augmented_llm_azure.py b/src/mcp_agent/workflows/llm/augmented_llm_azure.py index 98b6496d9..6c39ad168 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm_azure.py +++ b/src/mcp_agent/workflows/llm/augmented_llm_azure.py @@ -160,7 +160,7 @@ async def generate(self, message, request_params: RequestParams | None = None): messages.extend(AzureConverter.convert_mixed_messages_to_azure(message)) - response = await self.agent.list_tools() + response = await self.agent.list_tools(tool_filter=params.tool_filter) tools: list[ChatCompletionsToolDefinition] = [ ChatCompletionsToolDefinition( diff --git a/src/mcp_agent/workflows/llm/augmented_llm_bedrock.py b/src/mcp_agent/workflows/llm/augmented_llm_bedrock.py index 086e20406..a3b385b9e 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm_bedrock.py +++ b/src/mcp_agent/workflows/llm/augmented_llm_bedrock.py @@ -107,7 +107,7 @@ async def generate(self, message, request_params: RequestParams | None = None): messages.extend(BedrockConverter.convert_mixed_messages_to_bedrock(message)) - response = await self.agent.list_tools() + response = await self.agent.list_tools(tool_filter=params.tool_filter) tool_config: ToolConfigurationTypeDef = { "tools": [ diff --git a/src/mcp_agent/workflows/llm/augmented_llm_google.py b/src/mcp_agent/workflows/llm/augmented_llm_google.py index 1964dc9cd..10d516a29 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm_google.py +++ b/src/mcp_agent/workflows/llm/augmented_llm_google.py @@ -89,7 +89,7 @@ async def generate(self, message, request_params: RequestParams | None = None): messages.extend(GoogleConverter.convert_mixed_messages_to_google(message)) - response = await self.agent.list_tools() + response = await self.agent.list_tools(tool_filter=params.tool_filter) tools = [ types.Tool( diff --git a/src/mcp_agent/workflows/llm/augmented_llm_openai.py b/src/mcp_agent/workflows/llm/augmented_llm_openai.py index efdb86665..178c2da3d 100644 --- a/src/mcp_agent/workflows/llm/augmented_llm_openai.py +++ b/src/mcp_agent/workflows/llm/augmented_llm_openai.py @@ -189,7 +189,7 @@ async def generate( ) messages.extend((OpenAIConverter.convert_mixed_messages_to_openai(message))) - response: ListToolsResult = await self.agent.list_tools() + response: ListToolsResult = await self.agent.list_tools(tool_filter=params.tool_filter) available_tools: List[ChatCompletionToolParam] = [ ChatCompletionToolParam( type="function", diff --git a/tests/workflows/llm/test_request_params_tool_filter.py b/tests/workflows/llm/test_request_params_tool_filter.py new file mode 100644 index 000000000..776c0fb1a --- /dev/null +++ b/tests/workflows/llm/test_request_params_tool_filter.py @@ -0,0 +1,655 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock + +from mcp.types import Tool, ListToolsResult + +from mcp_agent.workflows.llm.augmented_llm import RequestParams +from mcp_agent.agents.agent import Agent +from mcp_agent.mcp.mcp_aggregator import NamespacedTool +from mcp_agent.core.context import Context + + +class TestRequestParamsToolFilter: + """Test cases for RequestParams tool_filter backward compatibility and functionality.""" + + def test_request_params_default_tool_filter_is_none(self): + """Test that RequestParams has tool_filter defaulting to None for backward compatibility.""" + # Create RequestParams without specifying tool_filter + params = RequestParams() + + # Should default to None + assert params.tool_filter is None + + def test_request_params_accepts_dict_tool_filter(self): + """Test that RequestParams accepts Dict[str, Set[str]] tool_filter.""" + tool_filter = {"server1": {"tool1", "tool2"}, "server2": {"tool3"}} + params = RequestParams(tool_filter=tool_filter) + + assert params.tool_filter == tool_filter + + def test_wildcard_filter(self): + """Test wildcard '*' key in tool_filter.""" + tool_filter = {"*": {"tool1", "tool2"}} + params = RequestParams(tool_filter=tool_filter) + assert params.tool_filter == tool_filter + + def test_non_namespaced_tools_key(self): + """Test non_namespaced_tools key for filtering non-namespaced tools.""" + tool_filter = {"non_namespaced_tools": {"human_input", "function_tool1"}} + params = RequestParams(tool_filter=tool_filter) + assert params.tool_filter == tool_filter + + def test_empty_set_filters_all_tools(self): + """Test that empty set filters out all tools for a server.""" + tool_filter = {"server1": set()} + params = RequestParams(tool_filter=tool_filter) + assert params.tool_filter["server1"] == set() + + def test_request_params_existing_fields_unchanged(self): + """Test that existing RequestParams fields work as before.""" + # Test existing parameters work unchanged + params = RequestParams( + maxTokens=1000, + model="test-model", + use_history=False, + max_iterations=5, + parallel_tool_calls=True, + temperature=0.5, + user="test-user", + strict=True + ) + + # All existing fields should work + assert params.maxTokens == 1000 + assert params.model == "test-model" + assert params.use_history is False + assert params.max_iterations == 5 + assert params.parallel_tool_calls is True + assert params.temperature == 0.5 + assert params.user == "test-user" + assert params.strict is True + # New field should default to None + assert params.tool_filter is None + + def test_request_params_with_mixed_parameters(self): + """Test RequestParams with both old and new parameters.""" + tool_filter = {"server1": {"tool1"}} + params = RequestParams( + maxTokens=2048, + tool_filter=tool_filter, + temperature=0.8 + ) + + assert params.maxTokens == 2048 + assert params.tool_filter == tool_filter + assert params.temperature == 0.8 + + def test_request_params_model_dump_includes_tool_filter(self): + """Test that model_dump includes tool_filter when set.""" + tool_filter = {"server1": {"tool1", "tool2"}} + params = RequestParams(tool_filter=tool_filter) + + dumped = params.model_dump() + assert "tool_filter" in dumped + assert dumped["tool_filter"] == tool_filter + + def test_request_params_model_dump_excludes_unset_tool_filter(self): + """Test that model_dump with exclude_unset=True handles tool_filter correctly.""" + # When tool_filter is not set + params1 = RequestParams(maxTokens=1000) + dumped1 = params1.model_dump(exclude_unset=True) + # tool_filter should not be in dumped output if not set + assert "tool_filter" not in dumped1 or dumped1.get("tool_filter") is None + + # When tool_filter is explicitly set + params2 = RequestParams(maxTokens=1000, tool_filter={"server1": {"tool1"}}) + dumped2 = params2.model_dump(exclude_unset=True) + assert "tool_filter" in dumped2 + assert dumped2["tool_filter"] == {"server1": {"tool1"}} + + +class TestAgentToolFilteringWithServer: + """Test cases when server_name is provided to list_tools.""" + + @pytest.fixture + def mock_agent_with_tools(self): + """Create a mock agent with test data.""" + agent = MagicMock(spec=Agent) + agent.initialized = True + agent.context = MagicMock(spec=Context) + agent.context.tracing_enabled = False + + # Setup server tools + agent._server_to_tool_map = { + "server1": [ + NamespacedTool( + tool=Tool(name="tool1", description="Tool 1", inputSchema={}), + server_name="server1", + namespaced_tool_name="server1:tool1" + ), + NamespacedTool( + tool=Tool(name="tool2", description="Tool 2", inputSchema={}), + server_name="server1", + namespaced_tool_name="server1:tool2" + ), + NamespacedTool( + tool=Tool(name="tool3", description="Tool 3", inputSchema={}), + server_name="server1", + namespaced_tool_name="server1:tool3" + ), + ], + "server2": [ + NamespacedTool( + tool=Tool(name="tool1", description="Tool 1", inputSchema={}), + server_name="server2", + namespaced_tool_name="server2:tool1" + ), + NamespacedTool( + tool=Tool(name="tool4", description="Tool 4", inputSchema={}), + server_name="server2", + namespaced_tool_name="server2:tool4" + ), + ], + } + + # Setup function tools + agent._function_tool_map = {} + agent.human_input_callback = None + + return agent + + @pytest.mark.asyncio + async def test_no_filter_includes_all_tools(self, mock_agent_with_tools): + """Test: tool_filter is None → No filtering, include all tools.""" + result = await self._apply_list_tools_logic( + mock_agent_with_tools, + server_name="server1", + tool_filter=None + ) + + assert len(result.tools) == 3 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server1:tool1", "server1:tool2", "server1:tool3"} + + @pytest.mark.asyncio + async def test_server_not_in_filter_includes_all_tools(self, mock_agent_with_tools): + """Test: server_name not in tool_filter → No filtering for this server.""" + result = await self._apply_list_tools_logic( + mock_agent_with_tools, + server_name="server2", + tool_filter={"server1": {"tool1"}} # server2 not in filter + ) + + assert len(result.tools) == 2 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server2:tool1", "server2:tool4"} + + @pytest.mark.asyncio + async def test_empty_set_filters_all_tools(self, mock_agent_with_tools): + """Test: tool_filter[server_name] = set() → Filter all tools out.""" + result = await self._apply_list_tools_logic( + mock_agent_with_tools, + server_name="server1", + tool_filter={"server1": set()} + ) + + assert len(result.tools) == 0 + + @pytest.mark.asyncio + async def test_specific_tools_filter(self, mock_agent_with_tools): + """Test: tool_filter[server_name] = {"tool1", "tool2"} → Only include those tools.""" + result = await self._apply_list_tools_logic( + mock_agent_with_tools, + server_name="server1", + tool_filter={"server1": {"tool1", "tool3"}} + ) + + assert len(result.tools) == 2 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server1:tool1", "server1:tool3"} + + async def _apply_list_tools_logic(self, agent, server_name, tool_filter): + """Apply the actual list_tools filtering logic.""" + filtered_out_tools = [] + + if server_name: + server_tools = agent._server_to_tool_map.get(server_name, []) + + if tool_filter is not None and server_name in tool_filter: + allowed_tools = tool_filter[server_name] + result_tools = [] + for namespaced_tool in server_tools: + if namespaced_tool.tool.name in allowed_tools: + result_tools.append( + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool.namespaced_tool_name} + ) + ) + else: + filtered_out_tools.append( + (namespaced_tool.namespaced_tool_name, f"Not in tool_filter[{server_name}]") + ) + result = ListToolsResult(tools=result_tools) + else: + result = ListToolsResult( + tools=[ + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool.namespaced_tool_name} + ) + for namespaced_tool in server_tools + ] + ) + + return result + + +class TestAgentToolFilteringAllServers: + """Test cases when server_name is NOT provided (listing all tools).""" + + @pytest.fixture + def mock_agent_all_servers(self): + """Create a mock agent with test data.""" + agent = MagicMock(spec=Agent) + agent.initialized = True + agent.context = MagicMock(spec=Context) + agent.context.tracing_enabled = False + + # Setup namespaced tool map + agent._namespaced_tool_map = { + "server1:tool1": NamespacedTool( + tool=Tool(name="tool1", description="Tool 1", inputSchema={}), + server_name="server1", + namespaced_tool_name="server1:tool1" + ), + "server1:tool2": NamespacedTool( + tool=Tool(name="tool2", description="Tool 2", inputSchema={}), + server_name="server1", + namespaced_tool_name="server1:tool2" + ), + "server2:tool1": NamespacedTool( + tool=Tool(name="tool1", description="Tool 1", inputSchema={}), + server_name="server2", + namespaced_tool_name="server2:tool1" + ), + "server2:tool3": NamespacedTool( + tool=Tool(name="tool3", description="Tool 3", inputSchema={}), + server_name="server2", + namespaced_tool_name="server2:tool3" + ), + "server3:tool4": NamespacedTool( + tool=Tool(name="tool4", description="Tool 4", inputSchema={}), + server_name="server3", + namespaced_tool_name="server3:tool4" + ), + } + + agent._function_tool_map = {} + agent.human_input_callback = None + + return agent + + @pytest.mark.asyncio + async def test_server_in_filter_applies_filter(self, mock_agent_all_servers): + """Test: X in tool_filter → Apply filter for server X.""" + result = await self._apply_list_tools_logic_all_servers( + mock_agent_all_servers, + tool_filter={"server1": {"tool1"}, "server2": {"tool3"}} + ) + + # server1: only tool1, server2: only tool3, server3: all tools (no filter) + assert len(result.tools) == 3 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server1:tool1", "server2:tool3", "server3:tool4"} + + @pytest.mark.asyncio + async def test_wildcard_applies_to_unfiltered_servers(self, mock_agent_all_servers): + """Test: X not in tool_filter and '*' in tool_filter → Apply wildcard filter.""" + result = await self._apply_list_tools_logic_all_servers( + mock_agent_all_servers, + tool_filter={ + "server1": {"tool1"}, # Explicit filter for server1 + "*": {"tool3", "tool4"} # Wildcard for others + } + ) + + # server1: only tool1 (explicit filter) + # server2: only tool3 (from wildcard) + # server3: only tool4 (from wildcard) + assert len(result.tools) == 3 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server1:tool1", "server2:tool3", "server3:tool4"} + + @pytest.mark.asyncio + async def test_no_filter_no_wildcard_includes_tool(self, mock_agent_all_servers): + """Test: X not in tool_filter and '*' not in tool_filter → Include tool (no filter).""" + result = await self._apply_list_tools_logic_all_servers( + mock_agent_all_servers, + tool_filter={"server1": {"tool1"}} # Only server1 has filter + ) + + # server1: only tool1 (explicit filter) + # server2: all tools (no filter) + # server3: all tools (no filter) + assert len(result.tools) == 4 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server1:tool1", "server2:tool1", "server2:tool3", "server3:tool4"} + + @pytest.mark.asyncio + async def test_empty_filter_dict_includes_all(self, mock_agent_all_servers): + """Test: tool_filter = {} → All tools included (no explicit filters defined).""" + result = await self._apply_list_tools_logic_all_servers( + mock_agent_all_servers, + tool_filter={} + ) + + # Empty dict means no explicit filters are defined + # Since no server is explicitly listed and there's no wildcard, + # the logic falls through to include all tools by default + assert len(result.tools) == 5 # All 5 tools from the fixture should be included + + @pytest.mark.asyncio + async def test_wildcard_only_filter(self, mock_agent_all_servers): + """Test: Only wildcard filter applies to all servers.""" + result = await self._apply_list_tools_logic_all_servers( + mock_agent_all_servers, + tool_filter={"*": {"tool1"}} + ) + + # All servers should only include tool1 + assert len(result.tools) == 2 + tool_names = {tool.name for tool in result.tools} + assert tool_names == {"server1:tool1", "server2:tool1"} + + @pytest.mark.asyncio + async def test_block_all_tools_with_wildcard_empty_set(self, mock_agent_all_servers): + """Test: Use wildcard with empty set to block all tools.""" + result = await self._apply_list_tools_logic_all_servers( + mock_agent_all_servers, + tool_filter={"*": set()} + ) + + # Wildcard with empty set blocks all tools from all servers + assert len(result.tools) == 0 + + async def _apply_list_tools_logic_all_servers(self, agent, tool_filter): + """Apply the actual list_tools filtering logic for all servers.""" + filtered_out_tools = [] + + if tool_filter is not None: + filtered_tools = [] + for namespaced_tool_name, namespaced_tool in agent._namespaced_tool_map.items(): + should_include = False + + if namespaced_tool.server_name in tool_filter: + if namespaced_tool.tool.name in tool_filter[namespaced_tool.server_name]: + should_include = True + else: + filtered_out_tools.append( + (namespaced_tool_name, f"Not in tool_filter[{namespaced_tool.server_name}]") + ) + elif "*" in tool_filter: + if namespaced_tool.tool.name in tool_filter["*"]: + should_include = True + else: + filtered_out_tools.append( + (namespaced_tool_name, "Not in tool_filter[*]") + ) + else: + should_include = True + + if should_include: + filtered_tools.append( + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool_name} + ) + ) + result = ListToolsResult(tools=filtered_tools) + else: + result = ListToolsResult( + tools=[ + namespaced_tool.tool.model_copy( + update={"name": namespaced_tool_name} + ) + for namespaced_tool_name, namespaced_tool in agent._namespaced_tool_map.items() + ] + ) + + return result + + +class TestNonNamespacedToolFiltering: + """Test filtering of function tools and human input tools.""" + + def test_non_namespaced_tools_key_filters(self): + """Test: non_namespaced_tools key filters function tools and human input.""" + from mcp_agent.agents.agent import Agent + + agent = MagicMock(spec=Agent) + agent._should_include_non_namespaced_tool = Agent._should_include_non_namespaced_tool.__get__(agent) + + # Test inclusion with non_namespaced_tools key + should_include, reason = agent._should_include_non_namespaced_tool( + "func1", {"non_namespaced_tools": {"func1", "human_input"}} + ) + assert should_include is True + assert reason is None + + # Test exclusion with non_namespaced_tools key + should_include, reason = agent._should_include_non_namespaced_tool( + "func2", {"non_namespaced_tools": {"func1", "human_input"}} + ) + assert should_include is False + assert "not in tool_filter[non_namespaced_tools]" in reason + + def test_wildcard_filters_non_namespaced(self): + """Test: Wildcard filters non-namespaced tools when no non_namespaced_tools key.""" + from mcp_agent.agents.agent import Agent + + agent = MagicMock(spec=Agent) + agent._should_include_non_namespaced_tool = Agent._should_include_non_namespaced_tool.__get__(agent) + + should_include, reason = agent._should_include_non_namespaced_tool( + "func1", {"*": {"func1", "human_input"}} + ) + assert should_include is True + + should_include, reason = agent._should_include_non_namespaced_tool( + "func2", {"*": {"func1", "human_input"}} + ) + assert should_include is False + assert "not in tool_filter[*]" in reason + + def test_no_filter_includes_non_namespaced(self): + """Test: No non_namespaced_tools key and no wildcard includes non-namespaced tools.""" + from mcp_agent.agents.agent import Agent + + agent = MagicMock(spec=Agent) + agent._should_include_non_namespaced_tool = Agent._should_include_non_namespaced_tool.__get__(agent) + + should_include, reason = agent._should_include_non_namespaced_tool( + "func1", {"server1": {"tool1"}} # No non_namespaced_tools key or wildcard + ) + assert should_include is True + assert reason is None + + +class TestBackwardCompatibilityIntegration: + """Integration tests to ensure existing code patterns still work.""" + + @pytest.fixture + def mock_context(self): + """Create a Context with mocked components for testing.""" + from mcp_agent.core.context import Context + + context = Context() + context.executor = AsyncMock() + context.server_registry = MagicMock() + context.tracing_enabled = False + return context + + @pytest.fixture + def mock_agent(self): + """Create a mock agent for testing.""" + agent = MagicMock() + agent.list_tools = AsyncMock(return_value=ListToolsResult(tools=[ + Tool(name="tool1", description="Tool 1", inputSchema={}), + Tool(name="tool2", description="Tool 2", inputSchema={}) + ])) + return agent + + @pytest.mark.asyncio + async def test_existing_code_without_tool_filter_still_works(self, mock_agent): + """Test that existing code calling agent.list_tools() without parameters still works.""" + # This simulates existing code that doesn't use tool_filter + result = await mock_agent.list_tools() + + assert len(result.tools) == 2 + assert result.tools[0].name == "tool1" + assert result.tools[1].name == "tool2" + + # Verify the call was made without tool_filter parameter + mock_agent.list_tools.assert_called_with() + + @pytest.mark.asyncio + async def test_existing_code_with_server_name_still_works(self, mock_agent): + """Test that existing code calling agent.list_tools(server_name) still works.""" + # This simulates existing code that uses server_name parameter + result = await mock_agent.list_tools(server_name="test_server") + + assert len(result.tools) == 2 + + # Verify the call was made with server_name but without tool_filter + mock_agent.list_tools.assert_called_with(server_name="test_server") + + def test_augmented_llm_get_request_params_backward_compatible(self, mock_context): + """Test that AugmentedLLM.get_request_params handles tool_filter correctly.""" + from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM + + # Create a mock AugmentedLLM instance + llm = MagicMock(spec=AugmentedLLM) + llm.context = mock_context + llm.default_request_params = RequestParams(maxTokens=1000) + + # Simulate the get_request_params method behavior + def mock_get_request_params(request_params=None, default=None): + default_params = default or llm.default_request_params + params = default_params.model_dump() if default_params else {} + if request_params: + params.update(request_params.model_dump(exclude_unset=True)) + return RequestParams(**params) + + llm.get_request_params = mock_get_request_params + + # Test 1: No overrides (existing behavior) + result1 = llm.get_request_params() + assert result1.maxTokens == 1000 + assert result1.tool_filter is None + + # Test 2: Override with new tool_filter + override_params = RequestParams(tool_filter={"server1": {"tool1"}}) + result2 = llm.get_request_params(request_params=override_params) + assert result2.maxTokens == 1000 # From default + assert result2.tool_filter == {"server1": {"tool1"}} # From override + + # Test 3: Override with non_namespaced_tools key + override_params3 = RequestParams(tool_filter={"non_namespaced_tools": {"human_input"}}) + result3 = llm.get_request_params(request_params=override_params3) + assert result3.tool_filter == {"non_namespaced_tools": {"human_input"}} + + # Test 3: Override with existing params only + override_params2 = RequestParams(temperature=0.9) + result4 = llm.get_request_params(request_params=override_params2) + assert result4.maxTokens == 1000 # From default + assert result4.temperature == 0.9 # From override + assert result4.tool_filter is None # Default + + @pytest.mark.asyncio + async def test_augmented_llm_list_tools_method_signature_compatible(self): + """Test that AugmentedLLM.list_tools method signature is backward compatible.""" + from mcp_agent.workflows.llm.augmented_llm import AugmentedLLM + import inspect + + # Get the method signature + sig = inspect.signature(AugmentedLLM.list_tools) + params = list(sig.parameters.keys()) + + # Should have both old and new parameters + assert "self" in params + assert "server_name" in params # Existing parameter + assert "tool_filter" in params # New parameter + + # Both should be optional (have defaults) + server_name_param = sig.parameters["server_name"] + tool_filter_param = sig.parameters["tool_filter"] + + assert server_name_param.default is None + assert tool_filter_param.default is None + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_same_tool_name_different_servers(self): + """Test that tools with same name from different servers are handled correctly.""" + agent = MagicMock(spec=Agent) + agent._namespaced_tool_map = { + "server1:tool1": NamespacedTool( + tool=Tool(name="tool1", description="Tool 1 from server1", inputSchema={}), + server_name="server1", + namespaced_tool_name="server1:tool1" + ), + "server2:tool1": NamespacedTool( + tool=Tool(name="tool1", description="Tool 1 from server2", inputSchema={}), + server_name="server2", + namespaced_tool_name="server2:tool1" + ), + } + + # Filter should work independently for each server + tool_filter = {"server1": {"tool1"}, "server2": set()} + + # server1:tool1 should be included, server2:tool1 should not + assert "server1" in tool_filter + assert "tool1" in tool_filter["server1"] + assert "server2" in tool_filter + assert len(tool_filter["server2"]) == 0 + + def test_server_not_in_map(self): + """Test requesting tools from a server that doesn't exist.""" + agent = MagicMock(spec=Agent) + agent._server_to_tool_map = {} + + # Should return empty list, not error + server_tools = agent._server_to_tool_map.get("nonexistent", []) + assert server_tools == [] + + def test_request_params_with_invalid_tool_filter_type(self): + """Test that RequestParams handles invalid tool_filter types gracefully.""" + # Test with string (should cause type error) + try: + params = RequestParams(tool_filter="invalid_string") + # If no exception, it's being converted somehow + assert isinstance(params.tool_filter, dict) or params.tool_filter is None + except (ValueError, TypeError): + pass # This is expected behavior + + # Test with dict having non-set values (should convert or error) + try: + params_with_list = RequestParams(tool_filter={"server1": ["tool1", "tool2"]}) + # Pydantic should convert list to set + if params_with_list.tool_filter: + assert isinstance(params_with_list.tool_filter["server1"], set) + assert params_with_list.tool_filter["server1"] == {"tool1", "tool2"} + except (ValueError, TypeError): + pass # This is also acceptable behavior + + def test_request_params_with_empty_dict_tool_filter(self): + """Test that RequestParams accepts empty dict for tool_filter.""" + # Empty dict should be valid (means no tools allowed from any server) + params = RequestParams(tool_filter={}) + assert params.tool_filter == {} + + def test_request_params_with_none_tool_filter_explicit(self): + """Test that RequestParams accepts explicit None for tool_filter.""" + params = RequestParams(tool_filter=None) + assert params.tool_filter is None \ No newline at end of file