Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,7 @@ dmypy.json
# Spell checker config
cspell.json

tmp*
tmp*

# Claude Code
CLAUDE.md
16 changes: 15 additions & 1 deletion src/huggingface_hub/inference/_mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,27 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
- args (List[str], optional): Arguments for the command
- env (Dict[str, str], optional): Environment variables for the command
- cwd (Union[str, Path, None], optional): Working directory for the command
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For SSE servers:
- url (str): The URL of the SSE server
- headers (Dict[str, Any], optional): Headers for the SSE connection
- timeout (float, optional): Connection timeout
- sse_read_timeout (float, optional): SSE read timeout
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For StreamableHTTP servers:
- url (str): The URL of the StreamableHTTP server
- headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
- timeout (timedelta, optional): Connection timeout
- sse_read_timeout (timedelta, optional): SSE read timeout
- terminate_on_close (bool, optional): Whether to terminate on close
- allowed_tools (List[str], optional): List of tool names to allow from this server
"""
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types

# Extract allowed_tools configuration if provided
allowed_tools = params.pop("allowed_tools", None)

# Determine server type and create appropriate parameters
if type == "stdio":
# Handle stdio server
Expand Down Expand Up @@ -211,7 +217,15 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
response = await session.list_tools()
logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])

for tool in response.tools:
# Filter tools based on allowed_tools configuration
filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]

if allowed_tools:
logger.debug(
f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
)

for tool in filtered_tools:
if tool.name in self.sessions:
logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
continue
Expand Down
3 changes: 3 additions & 0 deletions src/huggingface_hub/inference/_mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
args: List[str]
env: Dict[str, str]
cwd: str
allowed_tools: NotRequired[List[str]]


class HTTPServerConfig(TypedDict):
type: Literal["http"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


class SSEServerConfig(TypedDict):
type: Literal["sse"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
Expand Down
Loading