diff --git a/.gitignore b/.gitignore index 705fdcc38b..f989a75311 100644 --- a/.gitignore +++ b/.gitignore @@ -139,4 +139,7 @@ dmypy.json # Spell checker config cspell.json -tmp* \ No newline at end of file +tmp* + +# Claude Code +CLAUDE.md \ No newline at end of file diff --git a/src/huggingface_hub/inference/_mcp/mcp_client.py b/src/huggingface_hub/inference/_mcp/mcp_client.py index 4acbdbfb31..a6defd45aa 100644 --- a/src/huggingface_hub/inference/_mcp/mcp_client.py +++ b/src/huggingface_hub/inference/_mcp/mcp_client.py @@ -139,21 +139,27 @@ async def add_mcp_server(self, type: ServerType, **params: Any): - args (List[str], optional): Arguments for the command - env (Dict[str, str], optional): Environment variables for the command - cwd (Union[str, Path, None], optional): Working directory for the command + - allowed_tools (List[str], optional): List of tool names to allow from this server - For SSE servers: - url (str): The URL of the SSE server - headers (Dict[str, Any], optional): Headers for the SSE connection - timeout (float, optional): Connection timeout - sse_read_timeout (float, optional): SSE read timeout + - allowed_tools (List[str], optional): List of tool names to allow from this server - For StreamableHTTP servers: - url (str): The URL of the StreamableHTTP server - headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection - timeout (timedelta, optional): Connection timeout - sse_read_timeout (timedelta, optional): SSE read timeout - terminate_on_close (bool, optional): Whether to terminate on close + - allowed_tools (List[str], optional): List of tool names to allow from this server """ from mcp import ClientSession, StdioServerParameters from mcp import types as mcp_types + # Extract allowed_tools configuration if provided + allowed_tools = params.pop("allowed_tools", None) + # Determine server type and create appropriate parameters if type == "stdio": # Handle stdio server @@ -209,9 +215,18 @@ async def add_mcp_server(self, 
type: ServerType, **params: Any): # List available tools response = await session.list_tools() - logger.debug("Connected to server with tools:", [tool.name for tool in response.tools]) + all_tool_names = [tool.name for tool in response.tools] + logger.debug("Connected to server with tools:", all_tool_names) + + # Filter tools based on allowed_tools configuration + filtered_tools = self._filter_tools(response.tools, allowed_tools) + + if allowed_tools: + logger.info( + f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}" + ) - for tool in response.tools: + for tool in filtered_tools: if tool.name in self.sessions: logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.") continue @@ -233,6 +248,28 @@ async def add_mcp_server(self, type: ServerType, **params: Any): ) ) + def _filter_tools(self, tools: List[Any], allowed_tools: Optional[List[str]]) -> List[Any]: + """Filter tools based on allowed_tools list. 
+ + Args: + tools: List of MCP tool objects + allowed_tools: Optional list of tool names to allow + + Returns: + Filtered list of tools + """ + if allowed_tools is None: + return tools + + # Validate that specified tools exist + all_tool_names = [tool.name for tool in tools] + missing_tools = set(allowed_tools) - set(all_tool_names) + if missing_tools: + logger.warning(f"Tools specified in 'allowed_tools' not found on server: {list(missing_tools)}") + + # Filter tools using list comprehension + return [tool for tool in tools if tool.name in allowed_tools] + async def process_single_turn_with_tools( self, messages: List[Union[Dict, ChatCompletionInputMessage]], diff --git a/src/huggingface_hub/inference/_mcp/types.py b/src/huggingface_hub/inference/_mcp/types.py index cfb5e0eac9..100f67832e 100644 --- a/src/huggingface_hub/inference/_mcp/types.py +++ b/src/huggingface_hub/inference/_mcp/types.py @@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict): args: List[str] env: Dict[str, str] cwd: str + allowed_tools: NotRequired[List[str]] class HTTPServerConfig(TypedDict): type: Literal["http"] url: str headers: Dict[str, str] + allowed_tools: NotRequired[List[str]] class SSEServerConfig(TypedDict): type: Literal["sse"] url: str headers: Dict[str, str] + allowed_tools: NotRequired[List[str]] ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig] diff --git a/tests/test_mcp_client.py b/tests/test_mcp_client.py new file mode 100644 index 0000000000..7dd9798e0e --- /dev/null +++ b/tests/test_mcp_client.py @@ -0,0 +1,104 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock

from huggingface_hub.inference._mcp.mcp_client import MCPClient


class TestMCPClient(unittest.TestCase):
    """Unit tests for ``MCPClient._filter_tools``."""

    def setUp(self):
        self.client = MCPClient(model="test-model", provider="test-provider")

    @staticmethod
    def _make_tool(name):
        """Return a mock tool whose ``name`` attribute is the given string.

        ``MagicMock(name=...)`` only sets the label used in the mock's repr;
        the ``name`` attribute has to be assigned after construction.
        """
        tool = MagicMock()
        tool.name = name
        return tool

    def test_filter_tools_no_allowed_tools(self):
        """Test that _filter_tools returns all tools when no allowed_tools is specified."""
        mock_tools = [self._make_tool(name) for name in ("tool1", "tool2", "tool3")]

        result = self.client._filter_tools(mock_tools, None)

        self.assertEqual(len(result), 3)
        self.assertEqual(result, mock_tools)

    def test_filter_tools_with_allowed_tools(self):
        """Test that _filter_tools correctly filters tools based on allowed_tools list."""
        mock_tool1 = self._make_tool("tool1")
        mock_tool2 = self._make_tool("tool2")
        mock_tool3 = self._make_tool("tool3")

        mock_tools = [mock_tool1, mock_tool2, mock_tool3]
        allowed_tools = ["tool1", "tool3"]

        result = self.client._filter_tools(mock_tools, allowed_tools)

        # Allowed tools are kept in the order the server listed them.
        self.assertEqual(result, [mock_tool1, mock_tool3])
        self.assertNotIn(mock_tool2, result)

    def test_filter_tools_with_empty_allowed_tools(self):
        """Test that _filter_tools returns empty list when allowed_tools is empty."""
        mock_tools = [self._make_tool(name) for name in ("tool1", "tool2")]

        result = self.client._filter_tools(mock_tools, [])

        self.assertEqual(result, [])

    def test_filter_tools_with_nonexistent_tools(self):
        """Test that _filter_tools handles non-existent tool names gracefully."""
        mock_tool1 = self._make_tool("tool1")
        mock_tools = [mock_tool1]

        # Include a non-existent tool in allowed_tools
        allowed_tools = ["tool1", "nonexistent_tool"]

        # NOTE(review): this captures on the root logger; if the library's
        # logger does not propagate to root, pass the logger name here - confirm.
        with self.assertLogs(level="WARNING") as log:
            result = self.client._filter_tools(mock_tools, allowed_tools)

        # Should only return existing tools
        self.assertEqual(result, [mock_tool1])

        # Should log a warning about missing tools
        self.assertIn("not found on server", log.output[0])

    def test_filter_tools_all_nonexistent_tools(self):
        """Test that _filter_tools returns empty list when all allowed_tools are non-existent."""
        mock_tools = [self._make_tool("tool1")]

        allowed_tools = ["nonexistent_tool1", "nonexistent_tool2"]

        with self.assertLogs(level="WARNING") as log:
            result = self.client._filter_tools(mock_tools, allowed_tools)

        self.assertEqual(result, [])
        self.assertIn("not found on server", log.output[0])


if __name__ == "__main__":
    unittest.main()