Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,7 @@ dmypy.json
# Spell checker config
cspell.json

tmp*
tmp*

# Claude Code
CLAUDE.md
19 changes: 17 additions & 2 deletions src/huggingface_hub/inference/_mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,27 @@ async def add_mcp_server(self, type: ServerType, **params: Any):
- args (List[str], optional): Arguments for the command
- env (Dict[str, str], optional): Environment variables for the command
- cwd (Union[str, Path, None], optional): Working directory for the command
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For SSE servers:
- url (str): The URL of the SSE server
- headers (Dict[str, Any], optional): Headers for the SSE connection
- timeout (float, optional): Connection timeout
- sse_read_timeout (float, optional): SSE read timeout
- allowed_tools (List[str], optional): List of tool names to allow from this server
- For StreamableHTTP servers:
- url (str): The URL of the StreamableHTTP server
- headers (Dict[str, Any], optional): Headers for the StreamableHTTP connection
- timeout (timedelta, optional): Connection timeout
- sse_read_timeout (timedelta, optional): SSE read timeout
- terminate_on_close (bool, optional): Whether to terminate on close
- allowed_tools (List[str], optional): List of tool names to allow from this server
"""
from mcp import ClientSession, StdioServerParameters
from mcp import types as mcp_types

# Extract allowed_tools configuration if provided
allowed_tools = params.pop("allowed_tools", None)

# Determine server type and create appropriate parameters
if type == "stdio":
# Handle stdio server
Expand Down Expand Up @@ -209,9 +215,18 @@ async def add_mcp_server(self, type: ServerType, **params: Any):

# List available tools
response = await session.list_tools()
logger.debug("Connected to server with tools:", [tool.name for tool in response.tools])
all_tool_names = [tool.name for tool in response.tools]
logger.debug("Connected to server with tools:", all_tool_names)

# Filter tools based on allowed_tools configuration
filtered_tools = [tool for tool in response.tools if tool.name in allowed_tools]

if allowed_tools:
logger.debug(
f"Tool filtering applied. Using {len(filtered_tools)} of {len(response.tools)} available tools: {[tool.name for tool in filtered_tools]}"
)

for tool in response.tools:
for tool in filtered_tools:
if tool.name in self.sessions:
logger.warning(f"Tool '{tool.name}' already defined by another server. Skipping.")
continue
Expand Down
3 changes: 3 additions & 0 deletions src/huggingface_hub/inference/_mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,21 @@ class StdioServerConfig(TypedDict):
args: List[str]
env: Dict[str, str]
cwd: str
allowed_tools: NotRequired[List[str]]


class HTTPServerConfig(TypedDict):
type: Literal["http"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


class SSEServerConfig(TypedDict):
type: Literal["sse"]
url: str
headers: Dict[str, str]
allowed_tools: NotRequired[List[str]]


ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]
Expand Down
Loading