Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/huggingface_hub/inference/_mcp/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .._providers import PROVIDER_OR_POLICY_T
from .constants import DEFAULT_SYSTEM_PROMPT, EXIT_LOOP_TOOLS, MAX_NUM_TURNS
from .types import ServerConfig


class Agent(MCPClient):
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(
self,
*,
model: Optional[str] = None,
servers: Iterable[Dict],
servers: Iterable[ServerConfig],
provider: Optional[PROVIDER_OR_POLICY_T] = None,
base_url: Optional[str] = None,
api_key: Optional[str] = None,
Expand All @@ -54,7 +55,7 @@ def __init__(

async def load_tools(self) -> None:
for cfg in self._servers_cfg:
await self.add_mcp_server(cfg["type"], **cfg["config"])
await self.add_mcp_server(**cfg)

async def run(
self,
Expand Down
21 changes: 7 additions & 14 deletions src/huggingface_hub/inference/_mcp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,7 @@ def _sigint_handler() -> None:
input_vars = set()
for server in servers:
# Check stdio's "env" and http/sse's "headers" mappings
env_or_headers = (
server["config"].get("env", {})
if server["type"] == "stdio"
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
)
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
for key, value in env_or_headers.items():
if env_special_value in value:
input_vars.add(key)
Expand All @@ -99,8 +95,9 @@ def _sigint_handler() -> None:
continue

# Prompt user for input
env_variable_key = input_id.replace("-", "_").upper()
print(
f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(sorted(input_vars))}).",
f"[blue] • {input_id}[/blue]: {description}. (default: load from {env_variable_key}).",
end=" ",
)
user_input = (await _async_prompt(exit_event=exit_event)).strip()
Expand All @@ -109,23 +106,19 @@ def _sigint_handler() -> None:

# Inject user input (or env variable) into stdio's env or http/sse's headers
for server in servers:
env_or_headers = (
server["config"].get("env", {})
if server["type"] == "stdio"
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
)
env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {})
for key, value in env_or_headers.items():
if env_special_value in value:
if user_input:
env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input)
else:
value_from_env = os.getenv(key, "")
value_from_env = os.getenv(env_variable_key, "")
env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env)
if value_from_env:
print(f"[green]Value successfully loaded from '{key}'[/green]")
print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]")
else:
print(
f"[yellow]No value found for '{key}' in environment variables. Continuing.[/yellow]"
f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]"
)

print()
Expand Down
44 changes: 9 additions & 35 deletions src/huggingface_hub/inference/_mcp/types.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,39 @@
from typing import Dict, List, Literal, TypedDict, Union


# Input config
class InputConfig(TypedDict, total=False):
id: str
description: str
type: str
password: bool


# stdio server config
class StdioServerConfig(TypedDict, total=False):
class StdioServerConfig(TypedDict):
type: Literal["stdio"]
command: str
args: List[str]
env: Dict[str, str]
cwd: str


class StdioServer(TypedDict):
type: Literal["stdio"]
config: StdioServerConfig


# http server config
class HTTPRequestInit(TypedDict, total=False):
headers: Dict[str, str]


class HTTPServerOptions(TypedDict, total=False):
requestInit: HTTPRequestInit
sessionId: str


class HTTPServerConfig(TypedDict, total=False):
url: str
options: HTTPServerOptions


class HTTPServer(TypedDict):
class HTTPServerConfig(TypedDict):
type: Literal["http"]
config: HTTPServerConfig


# sse server config
class SSEServerOptions(TypedDict, total=False):
requestInit: HTTPRequestInit
url: str
headers: Dict[str, str]


class SSEServerConfig(TypedDict):
type: Literal["sse"]
url: str
options: SSEServerOptions
headers: Dict[str, str]


class SSEServer(TypedDict):
type: Literal["sse"]
config: SSEServerConfig
ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig]


# AgentConfig root object
class AgentConfig(TypedDict):
model: str
provider: str
inputs: List[InputConfig]
servers: List[Union[StdioServer, HTTPServer, SSEServer]]
servers: List[ServerConfig]