Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 27 additions & 20 deletions src/huggingface_hub/inference/_mcp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
import signal
import traceback
from typing import Any, Dict, List, Optional
from typing import Optional

import typer
from rich import print
Expand Down Expand Up @@ -40,8 +40,8 @@ async def run_agent(

config, prompt = _load_agent_config(agent_path)

inputs: List[Dict[str, Any]] = config.get("inputs", [])
servers: List[Dict[str, Any]] = config.get("servers", [])
inputs = config.get("inputs", [])
servers = config.get("servers", [])

abort_event = asyncio.Event()
exit_event = asyncio.Event()
Expand Down Expand Up @@ -82,38 +82,45 @@ def _sigint_handler() -> None:
env_special_value = "${input:" + input_id + "}" # Special value to indicate env variable injection

# Check env variables that will use this input
input_vars = list(
{
key
for server in servers
for key, value in server.get("config", {}).get("env", {}).items()
if value == env_special_value
}
)
input_vars = set()
for server in servers:
# Check stdio's "env" and http/sse's "headers" mappings
env_or_headers = (
server["config"].get("env", {})
if server["type"] == "stdio"
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
)
for key, value in env_or_headers.items():
if env_special_value in value:
input_vars.add(key)

if not input_vars:
print(f"[yellow]Input {input_id} defined in config but not used by any server.[/yellow]")
continue

# Prompt user for input
print(
f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(input_vars)}).",
f"[blue] • {input_id}[/blue]: {description}. (default: load from {', '.join(sorted(input_vars))}).",
end=" ",
)
user_input = (await _async_prompt(exit_event=exit_event)).strip()
if exit_event.is_set():
return

# Inject user input (or env variable) into servers' env
# Inject user input (or env variable) into stdio's env or http/sse's headers
for server in servers:
env = server.get("config", {}).get("env", {})
for key, value in env.items():
env_or_headers = (
server["config"].get("env", {})
if server["type"] == "stdio"
else server["config"].get("options", {}).get("requestInit", {}).get("headers", {})
)
for key, value in env_or_headers.items():
if value == env_special_value:
if user_input:
env[key] = user_input
env_or_headers[key] = env_or_headers[key].replace(env_special_value, user_input)
else:
value_from_env = os.getenv(key, "")
env[key] = value_from_env
env_or_headers[key] = env_or_headers[key].replace(env_special_value, value_from_env)
if value_from_env:
print(f"[green]Value successfully loaded from '{key}'[/green]")
else:
Expand All @@ -125,10 +132,10 @@ def _sigint_handler() -> None:

# Main agent loop
async with Agent(
provider=config.get("provider"),
provider=config.get("provider"), # type: ignore[arg-type]
model=config.get("model"),
base_url=config.get("endpointUrl"),
servers=servers,
base_url=config.get("endpointUrl"), # type: ignore[arg-type]
servers=servers, # type: ignore[arg-type]
prompt=prompt,
) as agent:
await agent.load_tools()
Expand Down
65 changes: 65 additions & 0 deletions src/huggingface_hub/inference/_mcp/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Dict, List, Literal, TypedDict, Union


# Input config
class InputConfig(TypedDict, total=False):
id: str
description: str
type: str
password: bool


# stdio server config
class StdioServerConfig(TypedDict, total=False):
command: str
args: List[str]
env: Dict[str, str]
cwd: str


class StdioServer(TypedDict):
type: Literal["stdio"]
config: StdioServerConfig


# http server config
class HTTPRequestInit(TypedDict, total=False):
headers: Dict[str, str]


class HTTPServerOptions(TypedDict, total=False):
requestInit: HTTPRequestInit
sessionId: str


class HTTPServerConfig(TypedDict, total=False):
url: str
options: HTTPServerOptions


class HTTPServer(TypedDict):
type: Literal["http"]
config: HTTPServerConfig
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we should switch now to the standard server config schema https://code.visualstudio.com/docs/copilot/chat/mcp-servers and remove the nested config, wdyt?

cc @julien-c

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep probably... (no super strong opinion)



# sse server config
class SSEServerOptions(TypedDict, total=False):
requestInit: HTTPRequestInit


class SSEServerConfig(TypedDict):
url: str
options: SSEServerOptions


class SSEServer(TypedDict):
type: Literal["sse"]
config: SSEServerConfig


# AgentConfig root object
class AgentConfig(TypedDict):
model: str
provider: str
inputs: List[InputConfig]
servers: List[Union[StdioServer, HTTPServer, SSEServer]]
11 changes: 6 additions & 5 deletions src/huggingface_hub/inference/_mcp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@

import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

from huggingface_hub import snapshot_download
from huggingface_hub.errors import EntryNotFoundError

from .constants import DEFAULT_AGENT, DEFAULT_REPO_ID, FILENAME_CONFIG, FILENAME_PROMPT
from .types import AgentConfig


if TYPE_CHECKING:
Expand Down Expand Up @@ -83,21 +84,21 @@ def _get_base64_size(base64_str: str) -> int:
return (len(base64_str) * 3) // 4 - padding


def _load_agent_config(agent_path: Optional[str]) -> Tuple[Dict[str, Any], Optional[str]]:
def _load_agent_config(agent_path: Optional[str]) -> Tuple[AgentConfig, Optional[str]]:
"""Load server config and prompt."""

def _read_dir(directory: Path) -> Tuple[Dict[str, Any], Optional[str]]:
def _read_dir(directory: Path) -> Tuple[AgentConfig, Optional[str]]:
cfg_file = directory / FILENAME_CONFIG
if not cfg_file.exists():
raise FileNotFoundError(f" Config file not found in {directory}! Please make sure it exists locally")

config: Dict[str, Any] = json.loads(cfg_file.read_text(encoding="utf-8"))
config: AgentConfig = json.loads(cfg_file.read_text(encoding="utf-8"))
prompt_file = directory / FILENAME_PROMPT
prompt: Optional[str] = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else None
return config, prompt

if agent_path is None:
return DEFAULT_AGENT, None
return DEFAULT_AGENT, None # type: ignore[return-value]

path = Path(agent_path).expanduser()

Expand Down