Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,7 @@ dmypy.json
# Spell checker config
cspell.json

tmp*
tmp*

# Claude Code
CLAUDE.md
41 changes: 39 additions & 2 deletions src/huggingface_hub/inference/_mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,27 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
- args (List[str], optional): Arguments for the command
- env (Dict[str, str], optional): Environment variables for the command
- cwd (Union[str, Path, None], optional): Working directory for the command
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For SSE servers:
- url (str): The URL of the SSE server
- headers (Dict[str, Any], optional): Headers for the SSE connection
- timeout (float, optional): Connection timeout
- sse_read_timeout (float, optional): SSE read timeout
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For StreamableHTTP servers:
- url (str): The URL of the StreamableHTTP server
- headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
- timeout (timedelta, optional): Connection timeout
- sse_read_timeout (timedelta, optional): SSE read timeout
- terminate_on_close (bool, optional): Whether to terminate on close
- allowed_tools (List[str], optional): List of tool names to allow from this server
"""
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types

# Extract allowed_tools configuration if provided
allowed_tools = params.pop("allowed_tools", None)

# Determine server type and create appropriate parameters
if type == "stdio":
# Handle stdio server
Expand Down Expand Up @@ -209,9 +215,18 @@ async def add_mcp_server(self, type: ServerType, **params: Any):

# List available tools
response = await session.list_tools()
logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])
all_tool_names = [tool.name for tool in response.tools]
logger.debug("Connected to server with tools:", all_tool_names)

# Filter tools based on allowed_tools configuration
filtered_tools = self._filter_tools(response.tools, allowed_tools)

if allowed_tools:
logger.info(
f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
)

for tool in response.tools:
for tool in filtered_tools:
if tool.name in self.sessions:
logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
continue
Expand All @@ -233,6 +248,28 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
)
)

def _filter_tools(self, tools: List[Any], allowed_tools: Optional[List[str]]) -> List[Any]:
"""Filter tools based on allowed_tools list.

Args:
tools: List of MCP tool objects
allowed_tools: Optional list of tool names to allow

Returns:
Filtered list of tools
"""
if allowed_tools is None:
return tools

# Validate that specified tools exist
all_tool_names = [tool.name for tool in tools]
missing_tools = set(allowed_tools) - set(all_tool_names)
if missing_tools:
logger.warning(f"Tools specified in 'allowed_tools' not found on server: {list(missing_tools)}")

# Filter tools using list comprehension
return [tool for tool in tools if tool.name in allowed_tools]

async def process_single_turn_with_tools(
self,
messages: List[Union[Dict, ChatCompletionInputMessage]],
Expand Down
3 changes: 3 additions & 0 deletions src/huggingface_hub/inference/_mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
args: List[str]
env: Dict[str, str]
cwd: str
allowed_tools: NotRequired[List[str]]


class HTTPServerConfig(TypedDict):
type: Literal["http"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


class SSEServerConfig(TypedDict):
type: Literal["sse"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
Expand Down
104 changes: 104 additions & 0 deletions tests/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock

from huggingface_hub.inference._mcp.mcp_client import MCPClient


class TestMCPClient(unittest.TestCase):
def setUp(self):
self.client = MCPClient(model="test-model", provider="test-provider")

def test_filter_tools_no_allowed_tools(self):
"""Test that _filter_tools returns all tools when no allowed_tools is specified."""
# Create mock tools
mock_tools = [
MagicMock(name="tool1"),
MagicMock(name="tool2"),
MagicMock(name="tool3"),
]

result = self.client._filter_tools(mock_tools, None)

self.assertEqual(len(result), 3)
self.assertEqual(result, mock_tools)

def test_filter_tools_with_allowed_tools(self):
"""Test that _filter_tools correctly filters tools based on allowed_tools list."""
# Create mock tools
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool2 = MagicMock()
mock_tool2.name = "tool2"
mock_tool3 = MagicMock()
mock_tool3.name = "tool3"

mock_tools = [mock_tool1, mock_tool2, mock_tool3]
allowed_tools = ["tool1", "tool3"]

result = self.client._filter_tools(mock_tools, allowed_tools)

self.assertEqual(len(result), 2)
self.assertIn(mock_tool1, result)
self.assertIn(mock_tool3, result)
self.assertNotIn(mock_tool2, result)

def test_filter_tools_with_empty_allowed_tools(self):
"""Test that _filter_tools returns empty list when allowed_tools is empty."""
mock_tools = [
MagicMock(name="tool1"),
MagicMock(name="tool2"),
]

result = self.client._filter_tools(mock_tools, [])

self.assertEqual(len(result), 0)

def test_filter_tools_with_nonexistent_tools(self):
"""Test that _filter_tools handles non-existent tool names gracefully."""
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tools = [mock_tool1]

# Include a non-existent tool in allowed_tools
allowed_tools = ["tool1", "nonexistent_tool"]

with self.assertLogs(level="WARNING") as log:
result = self.client._filter_tools(mock_tools, allowed_tools)

# Should only return existing tools
self.assertEqual(len(result), 1)
self.assertEqual(result[0], mock_tool1)

# Should log a warning about missing tools
self.assertIn("not found on server", log.output[0])

def test_filter_tools_all_nonexistent_tools(self):
"""Test that _filter_tools returns empty list when all allowed_tools are non-existent."""
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tools = [mock_tool1]

allowed_tools = ["nonexistent_tool1", "nonexistent_tool2"]

with self.assertLogs(level="WARNING") as log:
result = self.client._filter_tools(mock_tools, allowed_tools)

self.assertEqual(len(result), 0)
self.assertIn("not found on server", log.output[0])


if __name__ == "__main__":
unittest.main()
Loading