diff --git a/src/huggingface_hub/inference/_mcp/cli.py b/src/huggingface_hub/inference/_mcp/cli.py index a029b952eb..a2243c0619 100644 --- a/src/huggingface_hub/inference/_mcp/cli.py +++ b/src/huggingface_hub/inference/_mcp/cli.py @@ -2,7 +2,7 @@ import os import signal import traceback -from typing import Any, Dict, List, Optional +from typing import Optional import typer from rich import print @@ -40,8 +40,8 @@ async def run_agent( config, prompt = _load_agent_config(agent_path) - inputs: List[Dict[str, Any]] = config.get("inputs", []) - servers: List[Dict[str, Any]] = config.get("servers", []) + inputs = config.get("inputs", []) + servers = config.get("servers", []) abort_event = asyncio.Event() exit_event = asyncio.Event() @@ -82,14 +82,17 @@ def _sigint_handler() -> None: env_special_value = "${input:" + input_id + "}" # Special value to indicate env variable injection # Check env variables that will use this input - input_vars = list( - { - key - for server in servers - for key, value in server.get("config", {}).get("env", {}).items() - if value == env_special_value - } - ) + input_vars = set() + for server in servers: + # Check stdio's "env" and http/sse's "headers" mappings + env_or_headers = ( + server["config"].get("env", {}) + if server["type"] == "stdio" + else server["config"].get("options", {}).get("requestInit", {}).get("headers", {}) + ) + for key, value in env_or_headers.items(): + if env_special_value in value: + input_vars.add(key) if not input_vars: print(f"[yellow]Input {input_id} defined in config but not used by any server.[/yellow]") @@ -97,23 +100,27 @@ def _sigint_handler() -> None: # Prompt user for input print( - f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(input_vars)}).", + f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(sorted(input_vars))}).", end=" ", ) user_input = (await _async_prompt(exit_event=exit_event)).strip() if exit_event.is_set(): return - # Inject user input (or env variable) into servers' env + # Inject user input (or env variable) into stdio's env or http/sse's headers for server in servers: - env = server.get("config", {}).get("env", {}) - for key, value in env.items(): - if value == env_special_value: + env_or_headers = ( + server["config"].get("env", {}) + if server["type"] == "stdio" + else server["config"].get("options", {}).get("requestInit", {}).get("headers", {}) + ) + for key, value in env_or_headers.items(): + if env_special_value in value: if user_input: - env[key] = user_input + env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input) else: value_from_env = os.getenv(key, "") - env[key] = value_from_env + env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env) if value_from_env: print(f"[green]Value successfully loaded from '{key}'[/green]") else: @@ -125,10 +132,10 @@ def _sigint_handler() -> None: # Main agent loop async with Agent( - provider=config.get("provider"), + provider=config.get("provider"), # type: ignore[arg-type] model=config.get("model"), - base_url=config.get("endpointUrl"), - servers=servers, + base_url=config.get("endpointUrl"), # type: ignore[arg-type] + servers=servers, # type: ignore[arg-type] prompt=prompt, ) as agent: await agent.load_tools() diff --git a/src/huggingface_hub/inference/_mcp/types.py b/src/huggingface_hub/inference/_mcp/types.py new file mode 100644 index 0000000000..dd9c0c1de4 --- /dev/null +++ b/src/huggingface_hub/inference/_mcp/types.py @@ -0,0 +1,65 @@ +from typing import Dict, List, Literal, TypedDict, Union + + +# Input config +class InputConfig(TypedDict, total=False): + id: str + description: str + type: str + password: bool + + +# stdio server config +class StdioServerConfig(TypedDict, total=False): + command: str + args: List[str] + env: Dict[str, str] + cwd: str + + +class StdioServer(TypedDict): + type: Literal["stdio"] + config: StdioServerConfig + + +# http server config +class HTTPRequestInit(TypedDict, total=False): + headers: Dict[str, str] + + +class HTTPServerOptions(TypedDict, total=False): + requestInit: HTTPRequestInit + sessionId: str + + +class HTTPServerConfig(TypedDict, total=False): + url: str + options: HTTPServerOptions + + +class HTTPServer(TypedDict): + type: Literal["http"] + config: HTTPServerConfig + + +# sse server config +class SSEServerOptions(TypedDict, total=False): + requestInit: HTTPRequestInit + + +class SSEServerConfig(TypedDict): + url: str + options: SSEServerOptions + + +class SSEServer(TypedDict): + type: Literal["sse"] + config: SSEServerConfig + + +# AgentConfig root object +class AgentConfig(TypedDict): + model: str + provider: str + inputs: List[InputConfig] + servers: List[Union[StdioServer, HTTPServer, SSEServer]] diff --git a/src/huggingface_hub/inference/_mcp/utils.py b/src/huggingface_hub/inference/_mcp/utils.py index 0490eb600d..6e75eef45f 100644 --- a/src/huggingface_hub/inference/_mcp/utils.py +++ b/src/huggingface_hub/inference/_mcp/utils.py @@ -6,12 +6,13 @@ import json from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from huggingface_hub import snapshot_download from huggingface_hub.errors import EntryNotFoundError from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT +from .types import AgentConfig if TYPE_CHECKING: @@ -83,21 +84,21 @@ def _get_base64_size(base64_str: str) -> int: return (len(base64_str) * 3) // 4 - padding -def _load_agent_config(agent_path: Optional[str]) -> Tuple[Dict[str, Any], Optional[str]]: +def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]: """Load server config and prompt.""" - def _read_dir(directory: Path) -> Tuple[Dict[str, Any], Optional[str]]: + def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]: cfg_file = directory / FILENAME_CONFIG if not cfg_file.exists(): raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally") - config: Dict[str, Any] = json.loads(cfg_file.read_text(encoding="utf-8")) + config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8")) prompt_file = directory / FILENAME_PROMPT prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None return config, prompt if agent_path is None: - return DEFAULT_AGENT, None + return DEFAULT_AGENT, None # type: ignore[return-value] path = Path(agent_path).expanduser()