diff --git a/src/huggingface_hub/inference/_mcp/agent.py b/src/huggingface_hub/inference/_mcp/agent.py index ceac07c791..4f88016ba7 100644 --- a/src/huggingface_hub/inference/_mcp/agent.py +++ b/src/huggingface_hub/inference/_mcp/agent.py @@ -7,6 +7,7 @@ from .._providers import PROVIDER_OR_POLICY_T from .constants import DEFAULT_SYSTEM_PROMPT, EXIT_LOOP_TOOLS, MAX_NUM_TURNS +from .types import ServerConfig class Agent(MCPClient): @@ -40,7 +41,7 @@ def __init__( self, *, model: Optional[str] = None, - servers: Iterable[Dict], + servers: Iterable[ServerConfig], provider: Optional[PROVIDER_OR_POLICY_T] = None, base_url: Optional[str] = None, api_key: Optional[str] = None, @@ -54,7 +55,7 @@ def __init__( async def load_tools(self) -> None: for cfg in self._servers_cfg: - await self.add_mcp_server(cfg["type"], **cfg["config"]) + await self.add_mcp_server(**cfg) async def run( self, diff --git a/src/huggingface_hub/inference/_mcp/cli.py b/src/huggingface_hub/inference/_mcp/cli.py index a2243c0619..5d599ee829 100644 --- a/src/huggingface_hub/inference/_mcp/cli.py +++ b/src/huggingface_hub/inference/_mcp/cli.py @@ -85,11 +85,7 @@ def _sigint_handler() -> None: input_vars = set() for server in servers: # Check stdio's "env" and http/sse's "headers" mappings - env_or_headers = ( - server["config"].get("env", {}) - if server["type"] == "stdio" - else server["config"].get("options", {}).get("requestInit", {}).get("headers", {}) - ) + env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {}) for key, value in env_or_headers.items(): if env_special_value in value: input_vars.add(key) @@ -99,8 +95,9 @@ def _sigint_handler() -> None: continue # Prompt user for input + env_variable_key = input_id.replace("-", "_").upper() print( - f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(sorted(input_vars))}).", + f"[blue] • {input_id}[/blue]: {description}. (default: load from {env_variable_key}).", end=" ", ) user_input = (await _async_prompt(exit_event=exit_event)).strip() @@ -109,23 +106,19 @@ def _sigint_handler() -> None: # Inject user input (or env variable) into stdio's env or http/sse's headers for server in servers: - env_or_headers = ( - server["config"].get("env", {}) - if server["type"] == "stdio" - else server["config"].get("options", {}).get("requestInit", {}).get("headers", {}) - ) + env_or_headers = server.get("env", {}) if server["type"] == "stdio" else server.get("headers", {}) for key, value in env_or_headers.items(): if env_special_value in value: if user_input: env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input) else: - value_from_env = os.getenv(key, "") + value_from_env = os.getenv(env_variable_key, "") env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env) if value_from_env: - print(f"[green]Value successfully loaded from '{key}'[/green]") + print(f"[green]Value successfully loaded from '{env_variable_key}'[/green]") else: print( - f"[yellow]No value found for '{key}' in environment variables. Continuing.[/yellow]" + f"[yellow]No value found for '{env_variable_key}' in environment variables. Continuing.[/yellow]" ) print() diff --git a/src/huggingface_hub/inference/_mcp/types.py b/src/huggingface_hub/inference/_mcp/types.py index dd9c0c1de4..7177695e7b 100644 --- a/src/huggingface_hub/inference/_mcp/types.py +++ b/src/huggingface_hub/inference/_mcp/types.py @@ -1,7 +1,6 @@ from typing import Dict, List, Literal, TypedDict, Union -# Input config class InputConfig(TypedDict, total=False): id: str description: str @@ -9,52 +8,27 @@ class InputConfig(TypedDict, total=False): password: bool -# stdio server config -class StdioServerConfig(TypedDict, total=False): +class StdioServerConfig(TypedDict): + type: Literal["stdio"] command: str args: List[str] env: Dict[str, str] cwd: str -class StdioServer(TypedDict): - type: Literal["stdio"] - config: StdioServerConfig - - -# http server config -class HTTPRequestInit(TypedDict, total=False): - headers: Dict[str, str] - - -class HTTPServerOptions(TypedDict, total=False): - requestInit: HTTPRequestInit - sessionId: str - - -class HTTPServerConfig(TypedDict, total=False): - url: str - options: HTTPServerOptions - - -class HTTPServer(TypedDict): +class HTTPServerConfig(TypedDict): type: Literal["http"] - config: HTTPServerConfig - - -# sse server config -class SSEServerOptions(TypedDict, total=False): - requestInit: HTTPRequestInit + url: str + headers: Dict[str, str] class SSEServerConfig(TypedDict): + type: Literal["sse"] url: str - options: SSEServerOptions + headers: Dict[str, str] -class SSEServer(TypedDict): - type: Literal["sse"] - config: SSEServerConfig +ServerConfig = Union[StdioServerConfig, HTTPServerConfig, SSEServerConfig] # AgentConfig root object @@ -62,4 +36,4 @@ class AgentConfig(TypedDict): model: str provider: str inputs: List[InputConfig] - servers: List[Union[StdioServer, HTTPServer, SSEServer]] + servers: List[ServerConfig]