Skip to content

[Tiny Agents] Add tools to config #3242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,7 @@ dmypy.json
# Spell checker config
cspell.json

tmp*
tmp*

# Claude Code
CLAUDE.md
41 changes: 39 additions & 2 deletions src/huggingface_hub/inference/_mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,27 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
- args (List[str], optional): Arguments for the command
- env (Dict[str, str], optional): Environment variables for the command
- cwd (Union[str, Path, None], optional): Working directory for the command
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For SSE servers:
- url (str): The URL of the SSE server
- headers (Dict[str, Any], optional): Headers for the SSE connection
- timeout (float, optional): Connection timeout
- sse_read_timeout (float, optional): SSE read timeout
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For StreamableHTTP servers:
- url (str): The URL of the StreamableHTTP server
- headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
- timeout (timedelta, optional): Connection timeout
- sse_read_timeout (timedelta, optional): SSE read timeout
- terminate_on_close (bool, optional): Whether to terminate on close
- allowed_tools (List[str], optional): List of tool names to allow from this server
"""
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types

# Extract allowed_tools configuration if provided
allowed_tools = params.pop("allowed_tools", None)

# Determine server type and create appropriate parameters
if type == "stdio":
# Handle stdio server
Expand Down Expand Up @@ -209,9 +215,18 @@ async def add_mcp_server(self, type: ServerType, **params: Any):

# List available tools
response = await session.list_tools()
logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])
all_tool_names = [tool.name for tool in response.tools]
logger.debug("Connected to server with tools:", all_tool_names)

# Filter tools based on allowed_tools configuration
filtered_tools = self._filter_tools(response.tools, allowed_tools)

if allowed_tools:
logger.info(
f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
)

for tool in response.tools:
for tool in filtered_tools:
if tool.name in self.sessions:
logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
continue
Expand All @@ -233,6 +248,28 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
)
)

def _filter_tools(self, tools: List[Any], allowed_tools: Optional[List[str]]) -> List[Any]:
"""Filter tools based on allowed_tools list.

Args:
tools: List of MCP tool objects
allowed_tools: Optional list of tool names to allow

Returns:
Filtered list of tools
"""
if allowed_tools is None:
return tools

# Validate that specified tools exist
all_tool_names = [tool.name for tool in tools]
missing_tools = set(allowed_tools) - set(all_tool_names)
if missing_tools:
logger.warning(f"Tools specified in 'allowed_tools' not found on server: {list(missing_tools)}")

# Filter tools using list comprehension
return [tool for tool in tools if tool.name in allowed_tools]

async def process_single_turn_with_tools(
self,
messages: List[Union[Dict, ChatCompletionInputMessage]],
Expand Down
3 changes: 3 additions & 0 deletions src/huggingface_hub/inference/_mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
args: List[str]
env: Dict[str, str]
cwd: str
allowed_tools: NotRequired[List[str]]


class HTTPServerConfig(TypedDict):
type: Literal["http"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


class SSEServerConfig(TypedDict):
type: Literal["sse"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
Expand Down
104 changes: 104 additions & 0 deletions tests/test_mcp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import MagicMock

from huggingface_hub.inference._mcp.mcp_client import MCPClient


class TestMCPClient(unittest.TestCase):
def setUp(self):
self.client = MCPClient(model="test-model", provider="test-provider")

def test_filter_tools_no_allowed_tools(self):
"""Test that _filter_tools returns all tools when no allowed_tools is specified."""
# Create mock tools
mock_tools = [
MagicMock(name="tool1"),
MagicMock(name="tool2"),
MagicMock(name="tool3"),
]

result = self.client._filter_tools(mock_tools, None)

self.assertEqual(len(result), 3)
self.assertEqual(result, mock_tools)

def test_filter_tools_with_allowed_tools(self):
"""Test that _filter_tools correctly filters tools based on allowed_tools list."""
# Create mock tools
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tool2 = MagicMock()
mock_tool2.name = "tool2"
mock_tool3 = MagicMock()
mock_tool3.name = "tool3"

mock_tools = [mock_tool1, mock_tool2, mock_tool3]
allowed_tools = ["tool1", "tool3"]

result = self.client._filter_tools(mock_tools, allowed_tools)

self.assertEqual(len(result), 2)
self.assertIn(mock_tool1, result)
self.assertIn(mock_tool3, result)
self.assertNotIn(mock_tool2, result)

def test_filter_tools_with_empty_allowed_tools(self):
"""Test that _filter_tools returns empty list when allowed_tools is empty."""
mock_tools = [
MagicMock(name="tool1"),
MagicMock(name="tool2"),
]

result = self.client._filter_tools(mock_tools, [])

self.assertEqual(len(result), 0)

def test_filter_tools_with_nonexistent_tools(self):
"""Test that _filter_tools handles non-existent tool names gracefully."""
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tools = [mock_tool1]

# Include a non-existent tool in allowed_tools
allowed_tools = ["tool1", "nonexistent_tool"]

with self.assertLogs(level="WARNING") as log:
result = self.client._filter_tools(mock_tools, allowed_tools)

# Should only return existing tools
self.assertEqual(len(result), 1)
self.assertEqual(result[0], mock_tool1)

# Should log a warning about missing tools
self.assertIn("not found on server", log.output[0])

def test_filter_tools_all_nonexistent_tools(self):
"""Test that _filter_tools returns empty list when all allowed_tools are non-existent."""
mock_tool1 = MagicMock()
mock_tool1.name = "tool1"
mock_tools = [mock_tool1]

allowed_tools = ["nonexistent_tool1", "nonexistent_tool2"]

with self.assertLogs(level="WARNING") as log:
result = self.client._filter_tools(mock_tools, allowed_tools)

self.assertEqual(len(result), 0)
self.assertIn("not found on server", log.output[0])


if __name__ == "__main__":
unittest.main()
Loading