From c77dd2cfad4d1ff4dbc5430093b5a9d19967c526 Mon Sep 17 00:00:00 2001 From: "wenhui.zhang" Date: Tue, 26 Aug 2025 14:49:26 -0700 Subject: [PATCH 1/2] add secure annotation --- src/mcp/mcp/__init__.py | 116 ++ src/mcp/mcp/cli/__init__.py | 6 + src/mcp/mcp/cli/claude.py | 148 ++ src/mcp/mcp/cli/cli.py | 488 ++++++ src/mcp/mcp/client/__init__.py | 0 src/mcp/mcp/client/__main__.py | 85 ++ src/mcp/mcp/client/auth.py | 551 +++++++ src/mcp/mcp/client/session.py | 458 ++++++ src/mcp/mcp/client/session_group.py | 366 +++++ src/mcp/mcp/client/sse.py | 144 ++ src/mcp/mcp/client/stdio/__init__.py | 277 ++++ src/mcp/mcp/client/streamable_http.py | 511 +++++++ src/mcp/mcp/client/websocket.py | 86 ++ src/mcp/mcp/os/__init__.py | 1 + src/mcp/mcp/os/posix/__init__.py | 1 + src/mcp/mcp/os/posix/utilities.py | 60 + src/mcp/mcp/os/win32/__init__.py | 1 + src/mcp/mcp/os/win32/utilities.py | 338 +++++ src/mcp/mcp/py.typed | 0 src/mcp/mcp/server/__init__.py | 5 + src/mcp/mcp/server/__main__.py | 50 + src/mcp/mcp/server/auth/__init__.py | 3 + src/mcp/mcp/server/auth/errors.py | 5 + src/mcp/mcp/server/auth/handlers/__init__.py | 3 + src/mcp/mcp/server/auth/handlers/authorize.py | 224 +++ src/mcp/mcp/server/auth/handlers/metadata.py | 29 + src/mcp/mcp/server/auth/handlers/register.py | 120 ++ src/mcp/mcp/server/auth/handlers/revoke.py | 94 ++ src/mcp/mcp/server/auth/handlers/token.py | 238 +++ src/mcp/mcp/server/auth/json_response.py | 10 + .../mcp/server/auth/middleware/__init__.py | 3 + .../server/auth/middleware/auth_context.py | 48 + .../mcp/server/auth/middleware/bearer_auth.py | 128 ++ .../mcp/server/auth/middleware/client_auth.py | 51 + src/mcp/mcp/server/auth/provider.py | 306 ++++ src/mcp/mcp/server/auth/routes.py | 227 +++ src/mcp/mcp/server/auth/settings.py | 30 + src/mcp/mcp/server/elicitation.py | 111 ++ src/mcp/mcp/server/fastmcp/__init__.py | 9 + src/mcp/mcp/server/fastmcp/exceptions.py | 21 + .../mcp/server/fastmcp/prompts/__init__.py | 4 + 
src/mcp/mcp/server/fastmcp/prompts/base.py | 153 ++ src/mcp/mcp/server/fastmcp/prompts/manager.py | 48 + .../mcp/server/fastmcp/resources/__init__.py | 23 + src/mcp/mcp/server/fastmcp/resources/base.py | 45 + .../fastmcp/resources/resource_manager.py | 97 ++ .../mcp/server/fastmcp/resources/templates.py | 83 ++ src/mcp/mcp/server/fastmcp/resources/types.py | 196 +++ src/mcp/mcp/server/fastmcp/secure/__init__.py | 54 + .../mcp/server/fastmcp/secure/annotations.py | 326 ++++ src/mcp/mcp/server/fastmcp/secure/identity.py | 390 +++++ src/mcp/mcp/server/fastmcp/secure/prompt.py | 380 +++++ src/mcp/mcp/server/fastmcp/secure/resource.py | 270 ++++ src/mcp/mcp/server/fastmcp/secure/session.py | 473 ++++++ src/mcp/mcp/server/fastmcp/secure/tool.py | 272 ++++ src/mcp/mcp/server/fastmcp/secure/utils.py | 628 ++++++++ src/mcp/mcp/server/fastmcp/server.py | 1188 +++++++++++++++ src/mcp/mcp/server/fastmcp/tools/__init__.py | 4 + src/mcp/mcp/server/fastmcp/tools/base.py | 119 ++ .../mcp/server/fastmcp/tools/tool_manager.py | 83 ++ .../mcp/server/fastmcp/utilities/__init__.py | 1 + .../server/fastmcp/utilities/func_metadata.py | 523 +++++++ .../mcp/server/fastmcp/utilities/logging.py | 43 + src/mcp/mcp/server/fastmcp/utilities/types.py | 101 ++ src/mcp/mcp/server/lowlevel/__init__.py | 3 + src/mcp/mcp/server/lowlevel/helper_types.py | 9 + src/mcp/mcp/server/lowlevel/server.py | 686 +++++++++ src/mcp/mcp/server/models.py | 17 + src/mcp/mcp/server/session.py | 359 +++++ src/mcp/mcp/server/sse.py | 249 ++++ src/mcp/mcp/server/stdio.py | 88 ++ src/mcp/mcp/server/streamable_http.py | 901 +++++++++++ src/mcp/mcp/server/streamable_http_manager.py | 279 ++++ .../mcp/server/streaming_asgi_transport.py | 203 +++ src/mcp/mcp/server/transport_security.py | 127 ++ src/mcp/mcp/server/websocket.py | 62 + src/mcp/mcp/shared/__init__.py | 0 src/mcp/mcp/shared/_httpx_utils.py | 83 ++ src/mcp/mcp/shared/auth.py | 155 ++ src/mcp/mcp/shared/auth_utils.py | 69 + src/mcp/mcp/shared/context.py | 20 + 
src/mcp/mcp/shared/exceptions.py | 14 + src/mcp/mcp/shared/memory.py | 99 ++ src/mcp/mcp/shared/message.py | 43 + src/mcp/mcp/shared/metadata_utils.py | 45 + src/mcp/mcp/shared/progress.py | 58 + src/mcp/mcp/shared/session.py | 470 ++++++ src/mcp/mcp/shared/version.py | 3 + src/mcp/mcp/types.py | 1320 +++++++++++++++++ 89 files changed, 16218 insertions(+) create mode 100644 src/mcp/mcp/__init__.py create mode 100644 src/mcp/mcp/cli/__init__.py create mode 100644 src/mcp/mcp/cli/claude.py create mode 100644 src/mcp/mcp/cli/cli.py create mode 100644 src/mcp/mcp/client/__init__.py create mode 100644 src/mcp/mcp/client/__main__.py create mode 100644 src/mcp/mcp/client/auth.py create mode 100644 src/mcp/mcp/client/session.py create mode 100644 src/mcp/mcp/client/session_group.py create mode 100644 src/mcp/mcp/client/sse.py create mode 100644 src/mcp/mcp/client/stdio/__init__.py create mode 100644 src/mcp/mcp/client/streamable_http.py create mode 100644 src/mcp/mcp/client/websocket.py create mode 100644 src/mcp/mcp/os/__init__.py create mode 100644 src/mcp/mcp/os/posix/__init__.py create mode 100644 src/mcp/mcp/os/posix/utilities.py create mode 100644 src/mcp/mcp/os/win32/__init__.py create mode 100644 src/mcp/mcp/os/win32/utilities.py create mode 100644 src/mcp/mcp/py.typed create mode 100644 src/mcp/mcp/server/__init__.py create mode 100644 src/mcp/mcp/server/__main__.py create mode 100644 src/mcp/mcp/server/auth/__init__.py create mode 100644 src/mcp/mcp/server/auth/errors.py create mode 100644 src/mcp/mcp/server/auth/handlers/__init__.py create mode 100644 src/mcp/mcp/server/auth/handlers/authorize.py create mode 100644 src/mcp/mcp/server/auth/handlers/metadata.py create mode 100644 src/mcp/mcp/server/auth/handlers/register.py create mode 100644 src/mcp/mcp/server/auth/handlers/revoke.py create mode 100644 src/mcp/mcp/server/auth/handlers/token.py create mode 100644 src/mcp/mcp/server/auth/json_response.py create mode 100644 
src/mcp/mcp/server/auth/middleware/__init__.py create mode 100644 src/mcp/mcp/server/auth/middleware/auth_context.py create mode 100644 src/mcp/mcp/server/auth/middleware/bearer_auth.py create mode 100644 src/mcp/mcp/server/auth/middleware/client_auth.py create mode 100644 src/mcp/mcp/server/auth/provider.py create mode 100644 src/mcp/mcp/server/auth/routes.py create mode 100644 src/mcp/mcp/server/auth/settings.py create mode 100644 src/mcp/mcp/server/elicitation.py create mode 100644 src/mcp/mcp/server/fastmcp/__init__.py create mode 100644 src/mcp/mcp/server/fastmcp/exceptions.py create mode 100644 src/mcp/mcp/server/fastmcp/prompts/__init__.py create mode 100644 src/mcp/mcp/server/fastmcp/prompts/base.py create mode 100644 src/mcp/mcp/server/fastmcp/prompts/manager.py create mode 100644 src/mcp/mcp/server/fastmcp/resources/__init__.py create mode 100644 src/mcp/mcp/server/fastmcp/resources/base.py create mode 100644 src/mcp/mcp/server/fastmcp/resources/resource_manager.py create mode 100644 src/mcp/mcp/server/fastmcp/resources/templates.py create mode 100644 src/mcp/mcp/server/fastmcp/resources/types.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/__init__.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/annotations.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/identity.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/prompt.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/resource.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/session.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/tool.py create mode 100644 src/mcp/mcp/server/fastmcp/secure/utils.py create mode 100644 src/mcp/mcp/server/fastmcp/server.py create mode 100644 src/mcp/mcp/server/fastmcp/tools/__init__.py create mode 100644 src/mcp/mcp/server/fastmcp/tools/base.py create mode 100644 src/mcp/mcp/server/fastmcp/tools/tool_manager.py create mode 100644 src/mcp/mcp/server/fastmcp/utilities/__init__.py create mode 100644 
src/mcp/mcp/server/fastmcp/utilities/func_metadata.py create mode 100644 src/mcp/mcp/server/fastmcp/utilities/logging.py create mode 100644 src/mcp/mcp/server/fastmcp/utilities/types.py create mode 100644 src/mcp/mcp/server/lowlevel/__init__.py create mode 100644 src/mcp/mcp/server/lowlevel/helper_types.py create mode 100644 src/mcp/mcp/server/lowlevel/server.py create mode 100644 src/mcp/mcp/server/models.py create mode 100644 src/mcp/mcp/server/session.py create mode 100644 src/mcp/mcp/server/sse.py create mode 100644 src/mcp/mcp/server/stdio.py create mode 100644 src/mcp/mcp/server/streamable_http.py create mode 100644 src/mcp/mcp/server/streamable_http_manager.py create mode 100644 src/mcp/mcp/server/streaming_asgi_transport.py create mode 100644 src/mcp/mcp/server/transport_security.py create mode 100644 src/mcp/mcp/server/websocket.py create mode 100644 src/mcp/mcp/shared/__init__.py create mode 100644 src/mcp/mcp/shared/_httpx_utils.py create mode 100644 src/mcp/mcp/shared/auth.py create mode 100644 src/mcp/mcp/shared/auth_utils.py create mode 100644 src/mcp/mcp/shared/context.py create mode 100644 src/mcp/mcp/shared/exceptions.py create mode 100644 src/mcp/mcp/shared/memory.py create mode 100644 src/mcp/mcp/shared/message.py create mode 100644 src/mcp/mcp/shared/metadata_utils.py create mode 100644 src/mcp/mcp/shared/progress.py create mode 100644 src/mcp/mcp/shared/session.py create mode 100644 src/mcp/mcp/shared/version.py create mode 100644 src/mcp/mcp/types.py diff --git a/src/mcp/mcp/__init__.py b/src/mcp/mcp/__init__.py new file mode 100644 index 000000000..e93b95c90 --- /dev/null +++ b/src/mcp/mcp/__init__.py @@ -0,0 +1,116 @@ +from .client.session import ClientSession +from .client.session_group import ClientSessionGroup +from .client.stdio import StdioServerParameters, stdio_client +from .server.session import ServerSession +from .server.stdio import stdio_server +from .shared.exceptions import McpError +from .types import ( + CallToolRequest, + 
ClientCapabilities, + ClientNotification, + ClientRequest, + ClientResult, + CompleteRequest, + CreateMessageRequest, + CreateMessageResult, + ErrorData, + GetPromptRequest, + GetPromptResult, + Implementation, + IncludeContext, + InitializedNotification, + InitializeRequest, + InitializeResult, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + ListPromptsRequest, + ListPromptsResult, + ListResourcesRequest, + ListResourcesResult, + ListToolsResult, + LoggingLevel, + LoggingMessageNotification, + Notification, + PingRequest, + ProgressNotification, + PromptsCapability, + ReadResourceRequest, + ReadResourceResult, + Resource, + ResourcesCapability, + ResourceUpdatedNotification, + RootsCapability, + SamplingMessage, + ServerCapabilities, + ServerNotification, + ServerRequest, + ServerResult, + SetLevelRequest, + StopReason, + SubscribeRequest, + Tool, + ToolsCapability, + UnsubscribeRequest, +) +from .types import ( + Role as SamplingRole, +) + +__all__ = [ + "CallToolRequest", + "ClientCapabilities", + "ClientNotification", + "ClientRequest", + "ClientResult", + "ClientSession", + "ClientSessionGroup", + "CreateMessageRequest", + "CreateMessageResult", + "ErrorData", + "GetPromptRequest", + "GetPromptResult", + "Implementation", + "IncludeContext", + "InitializeRequest", + "InitializeResult", + "InitializedNotification", + "JSONRPCError", + "JSONRPCRequest", + "ListPromptsRequest", + "ListPromptsResult", + "ListResourcesRequest", + "ListResourcesResult", + "ListToolsResult", + "LoggingLevel", + "LoggingMessageNotification", + "McpError", + "Notification", + "PingRequest", + "ProgressNotification", + "PromptsCapability", + "ReadResourceRequest", + "ReadResourceResult", + "ResourcesCapability", + "ResourceUpdatedNotification", + "Resource", + "RootsCapability", + "SamplingMessage", + "SamplingRole", + "ServerCapabilities", + "ServerNotification", + "ServerRequest", + "ServerResult", + "ServerSession", + "SetLevelRequest", + "StdioServerParameters", + 
"StopReason", + "SubscribeRequest", + "Tool", + "ToolsCapability", + "UnsubscribeRequest", + "stdio_client", + "stdio_server", + "CompleteRequest", + "JSONRPCResponse", +] diff --git a/src/mcp/mcp/cli/__init__.py b/src/mcp/mcp/cli/__init__.py new file mode 100644 index 000000000..3ef56d806 --- /dev/null +++ b/src/mcp/mcp/cli/__init__.py @@ -0,0 +1,6 @@ +"""FastMCP CLI package.""" + +from .cli import app + +if __name__ == "__main__": + app() diff --git a/src/mcp/mcp/cli/claude.py b/src/mcp/mcp/cli/claude.py new file mode 100644 index 000000000..6a2effa3b --- /dev/null +++ b/src/mcp/mcp/cli/claude.py @@ -0,0 +1,148 @@ +"""Claude app integration utilities.""" + +import json +import os +import shutil +import sys +from pathlib import Path +from typing import Any + +from mcp.server.fastmcp.utilities.logging import get_logger + +logger = get_logger(__name__) + +MCP_PACKAGE = "mcp[cli]" + + +def get_claude_config_path() -> Path | None: + """Get the Claude config directory based on platform.""" + if sys.platform == "win32": + path = Path(Path.home(), "AppData", "Roaming", "Claude") + elif sys.platform == "darwin": + path = Path(Path.home(), "Library", "Application Support", "Claude") + elif sys.platform.startswith("linux"): + path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude") + else: + return None + + if path.exists(): + return path + return None + + +def get_uv_path() -> str: + """Get the full path to the uv executable.""" + uv_path = shutil.which("uv") + if not uv_path: + logger.error( + "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" + ) + return "uv" # Fall back to just "uv" if not found + return uv_path + + +def update_claude_config( + file_spec: str, + server_name: str, + *, + with_editable: Path | None = None, + with_packages: list[str] | None = None, + env_vars: dict[str, str] | None = None, +) -> bool: + """Add or update a FastMCP server in Claude's configuration. 
+ + Args: + file_spec: Path to the server file, optionally with :object suffix + server_name: Name for the server in Claude's config + with_editable: Optional directory to install in editable mode + with_packages: Optional list of additional packages to install + env_vars: Optional dictionary of environment variables. These are merged with + any existing variables, with new values taking precedence. + + Raises: + RuntimeError: If Claude Desktop's config directory is not found, indicating + Claude Desktop may not be installed or properly set up. + """ + config_dir = get_claude_config_path() + uv_path = get_uv_path() + if not config_dir: + raise RuntimeError( + "Claude Desktop config directory not found. Please ensure Claude Desktop" + " is installed and has been run at least once to initialize its config." + ) + + config_file = config_dir / "claude_desktop_config.json" + if not config_file.exists(): + try: + config_file.write_text("{}") + except Exception: + logger.exception( + "Failed to create Claude config file", + extra={ + "config_file": str(config_file), + }, + ) + return False + + try: + config = json.loads(config_file.read_text()) + if "mcpServers" not in config: + config["mcpServers"] = {} + + # Always preserve existing env vars and merge with new ones + if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: + existing_env = config["mcpServers"][server_name]["env"] + if env_vars: + # New vars take precedence over existing ones + env_vars = {**existing_env, **env_vars} + else: + env_vars = existing_env + + # Build uv run command + args = ["run"] + + # Collect all packages in a set to deduplicate + packages = {MCP_PACKAGE} + if with_packages: + packages.update(pkg for pkg in with_packages if pkg) + + # Add all packages with --with + for pkg in sorted(packages): + args.extend(["--with", pkg]) + + if with_editable: + args.extend(["--with-editable", str(with_editable)]) + + # Convert file path to absolute before adding to command 
+ # Split off any :object suffix first + if ":" in file_spec: + file_path, server_object = file_spec.rsplit(":", 1) + file_spec = f"{Path(file_path).resolve()}:{server_object}" + else: + file_spec = str(Path(file_spec).resolve()) + + # Add fastmcp run command + args.extend(["mcp", "run", file_spec]) + + server_config: dict[str, Any] = {"command": uv_path, "args": args} + + # Add environment variables if specified + if env_vars: + server_config["env"] = env_vars + + config["mcpServers"][server_name] = server_config + + config_file.write_text(json.dumps(config, indent=2)) + logger.info( + f"Added server '{server_name}' to Claude config", + extra={"config_file": str(config_file)}, + ) + return True + except Exception: + logger.exception( + "Failed to update Claude config", + extra={ + "config_file": str(config_file), + }, + ) + return False diff --git a/src/mcp/mcp/cli/cli.py b/src/mcp/mcp/cli/cli.py new file mode 100644 index 000000000..4a7257a11 --- /dev/null +++ b/src/mcp/mcp/cli/cli.py @@ -0,0 +1,488 @@ +"""MCP CLI tools.""" + +import importlib.metadata +import importlib.util +import os +import subprocess +import sys +from pathlib import Path +from typing import Annotated, Any + +from mcp.server import FastMCP +from mcp.server import Server as LowLevelServer + +try: + import typer +except ImportError: + print("Error: typer is required. 
Install with 'pip install mcp[cli]'") + sys.exit(1) + +try: + from mcp.cli import claude + from mcp.server.fastmcp.utilities.logging import get_logger +except ImportError: + print("Error: mcp.server.fastmcp is not installed or not in PYTHONPATH") + sys.exit(1) + +try: + import dotenv +except ImportError: + dotenv = None + +logger = get_logger("cli") + +app = typer.Typer( + name="mcp", + help="MCP development tools", + add_completion=False, + no_args_is_help=True, # Show help if no args provided +) + + +def _get_npx_command(): + """Get the correct npx command for the current platform.""" + if sys.platform == "win32": + # Try both npx.cmd and npx.exe on Windows + for cmd in ["npx.cmd", "npx.exe", "npx"]: + try: + subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True) + return cmd + except subprocess.CalledProcessError: + continue + return None + return "npx" # On Unix-like systems, just use npx + + +def _parse_env_var(env_var: str) -> tuple[str, str]: + """Parse environment variable string in format KEY=VALUE.""" + if "=" not in env_var: + logger.error(f"Invalid environment variable format: {env_var}. Must be KEY=VALUE") + sys.exit(1) + key, value = env_var.split("=", 1) + return key.strip(), value.strip() + + +def _build_uv_command( + file_spec: str, + with_editable: Path | None = None, + with_packages: list[str] | None = None, +) -> list[str]: + """Build the uv run command that runs a MCP server through mcp run.""" + cmd = ["uv"] + + cmd.extend(["run", "--with", "mcp"]) + + if with_editable: + cmd.extend(["--with-editable", str(with_editable)]) + + if with_packages: + for pkg in with_packages: + if pkg: + cmd.extend(["--with", pkg]) + + # Add mcp run command + cmd.extend(["mcp", "run", file_spec]) + return cmd + + +def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: + """Parse a file path that may include a server object specification. 
+ + Args: + file_spec: Path to file, optionally with :object suffix + + Returns: + Tuple of (file_path, server_object) + """ + # First check if we have a Windows path (e.g., C:\...) + has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" + + # Split on the last colon, but only if it's not part of the Windows drive letter + # and there's actually another colon in the string after the drive letter + if ":" in (file_spec[2:] if has_windows_drive else file_spec): + file_str, server_object = file_spec.rsplit(":", 1) + else: + file_str, server_object = file_spec, None + + # Resolve the file path + file_path = Path(file_str).expanduser().resolve() + if not file_path.exists(): + logger.error(f"File not found: {file_path}") + sys.exit(1) + if not file_path.is_file(): + logger.error(f"Not a file: {file_path}") + sys.exit(1) + + return file_path, server_object + + +def _import_server(file: Path, server_object: str | None = None): + """Import a MCP server from a file. + + Args: + file: Path to the file + server_object: Optional object name in format "module:object" or just "object" + + Returns: + The server object + """ + # Add parent directory to Python path so imports can be resolved + file_dir = str(file.parent) + if file_dir not in sys.path: + sys.path.insert(0, file_dir) + + # Import the module + spec = importlib.util.spec_from_file_location("server_module", file) + if not spec or not spec.loader: + logger.error("Could not load module", extra={"file": str(file)}) + sys.exit(1) + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + def _check_server_object(server_object: Any, object_name: str): + """Helper function to check that the server object is supported + + Args: + server_object: The server object to check. + + Returns: + True if it's supported. 
+ """ + if not isinstance(server_object, FastMCP): + logger.error(f"The server object {object_name} is of type {type(server_object)} (expecting {FastMCP}).") + if isinstance(server_object, LowLevelServer): + logger.warning( + "Note that only FastMCP server is supported. Low level Server class is not yet supported." + ) + return False + return True + + # If no object specified, try common server names + if not server_object: + # Look for the most common server object names + for name in ["mcp", "server", "app"]: + if hasattr(module, name): + if not _check_server_object(getattr(module, name), f"{file}:{name}"): + logger.error(f"Ignoring object '{file}:{name}' as it's not a valid server object") + continue + return getattr(module, name) + + logger.error( + f"No server object found in {file}. Please either:\n" + "1. Use a standard variable name (mcp, server, or app)\n" + "2. Specify the object name with file:object syntax" + "3. If the server creates the FastMCP object within main() " + " or another function, refactor the FastMCP object to be a " + " global variable named mcp, server, or app.", + extra={"file": str(file)}, + ) + sys.exit(1) + + # Handle module:object syntax + if ":" in server_object: + module_name, object_name = server_object.split(":", 1) + try: + server_module = importlib.import_module(module_name) + server = getattr(server_module, object_name, None) + except ImportError: + logger.error( + f"Could not import module '{module_name}'", + extra={"file": str(file)}, + ) + sys.exit(1) + else: + # Just object name + server = getattr(module, server_object, None) + + if server is None: + logger.error( + f"Server object '{server_object}' not found", + extra={"file": str(file)}, + ) + sys.exit(1) + + if not _check_server_object(server, server_object): + sys.exit(1) + + return server + + +@app.command() +def version() -> None: + """Show the MCP version.""" + try: + version = importlib.metadata.version("mcp") + print(f"MCP version {version}") + except 
importlib.metadata.PackageNotFoundError: + print("MCP version unknown (package not installed)") + sys.exit(1) + + +@app.command() +def dev( + file_spec: str = typer.Argument( + ..., + help="Python file to run, optionally with :object suffix", + ), + with_editable: Annotated[ + Path | None, + typer.Option( + "--with-editable", + "-e", + help="Directory containing pyproject.toml to install in editable mode", + exists=True, + file_okay=False, + resolve_path=True, + ), + ] = None, + with_packages: Annotated[ + list[str], + typer.Option( + "--with", + help="Additional packages to install", + ), + ] = [], +) -> None: + """Run a MCP server with the MCP Inspector.""" + file, server_object = _parse_file_path(file_spec) + + logger.debug( + "Starting dev server", + extra={ + "file": str(file), + "server_object": server_object, + "with_editable": str(with_editable) if with_editable else None, + "with_packages": with_packages, + }, + ) + + try: + # Import server to get dependencies + server = _import_server(file, server_object) + if hasattr(server, "dependencies"): + with_packages = list(set(with_packages + server.dependencies)) + + uv_cmd = _build_uv_command(file_spec, with_editable, with_packages) + + # Get the correct npx command + npx_cmd = _get_npx_command() + if not npx_cmd: + logger.error( + "npx not found. Please ensure Node.js and npm are properly installed and added to your system PATH." 
+ ) + sys.exit(1) + + # Run the MCP Inspector command with shell=True on Windows + shell = sys.platform == "win32" + process = subprocess.run( + [npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd, + check=True, + shell=shell, + env=dict(os.environ.items()), # Convert to list of tuples for env update + ) + sys.exit(process.returncode) + except subprocess.CalledProcessError as e: + logger.error( + "Dev server failed", + extra={ + "file": str(file), + "error": str(e), + "returncode": e.returncode, + }, + ) + sys.exit(e.returncode) + except FileNotFoundError: + logger.error( + "npx not found. Please ensure Node.js and npm are properly installed " + "and added to your system PATH. You may need to restart your terminal " + "after installation.", + extra={"file": str(file)}, + ) + sys.exit(1) + + +@app.command() +def run( + file_spec: str = typer.Argument( + ..., + help="Python file to run, optionally with :object suffix", + ), + transport: Annotated[ + str | None, + typer.Option( + "--transport", + "-t", + help="Transport protocol to use (stdio or sse)", + ), + ] = None, +) -> None: + """Run a MCP server. + + The server can be specified in two ways:\n + 1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n + 2. Import approach: server.py:app - imports and runs the specified server object.\n\n + + Note: This command runs the server directly. You are responsible for ensuring + all dependencies are available.\n + For dependency management, use `mcp install` or `mcp dev` instead. 
+ """ # noqa: E501 + file, server_object = _parse_file_path(file_spec) + + logger.debug( + "Running server", + extra={ + "file": str(file), + "server_object": server_object, + "transport": transport, + }, + ) + + try: + # Import and get server object + server = _import_server(file, server_object) + + # Run the server + kwargs = {} + if transport: + kwargs["transport"] = transport + + server.run(**kwargs) + + except Exception: + logger.exception( + "Failed to run server", + extra={ + "file": str(file), + }, + ) + sys.exit(1) + + +@app.command() +def install( + file_spec: str = typer.Argument( + ..., + help="Python file to run, optionally with :object suffix", + ), + server_name: Annotated[ + str | None, + typer.Option( + "--name", + "-n", + help="Custom name for the server (defaults to server's name attribute or file name)", + ), + ] = None, + with_editable: Annotated[ + Path | None, + typer.Option( + "--with-editable", + "-e", + help="Directory containing pyproject.toml to install in editable mode", + exists=True, + file_okay=False, + resolve_path=True, + ), + ] = None, + with_packages: Annotated[ + list[str], + typer.Option( + "--with", + help="Additional packages to install", + ), + ] = [], + env_vars: Annotated[ + list[str], + typer.Option( + "--env-var", + "-v", + help="Environment variables in KEY=VALUE format", + ), + ] = [], + env_file: Annotated[ + Path | None, + typer.Option( + "--env-file", + "-f", + help="Load environment variables from a .env file", + exists=True, + file_okay=True, + dir_okay=False, + resolve_path=True, + ), + ] = None, +) -> None: + """Install a MCP server in the Claude desktop app. + + Environment variables are preserved once added and only updated if new values + are explicitly provided. 
+ """ + file, server_object = _parse_file_path(file_spec) + + logger.debug( + "Installing server", + extra={ + "file": str(file), + "server_name": server_name, + "server_object": server_object, + "with_editable": str(with_editable) if with_editable else None, + "with_packages": with_packages, + }, + ) + + if not claude.get_claude_config_path(): + logger.error("Claude app not found") + sys.exit(1) + + # Try to import server to get its name, but fall back to file name if dependencies + # missing + name = server_name + server = None + if not name: + try: + server = _import_server(file, server_object) + name = server.name + except (ImportError, ModuleNotFoundError) as e: + logger.debug( + "Could not import server (likely missing dependencies), using file name", + extra={"error": str(e)}, + ) + name = file.stem + + # Get server dependencies if available + server_dependencies = getattr(server, "dependencies", []) if server else [] + if server_dependencies: + with_packages = list(set(with_packages + server_dependencies)) + + # Process environment variables if provided + env_dict: dict[str, str] | None = None + if env_file or env_vars: + env_dict = {} + # Load from .env file if specified + if env_file: + if dotenv: + try: + env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None} + except (OSError, ValueError): + logger.exception("Failed to load .env file") + sys.exit(1) + else: + logger.error("python-dotenv is not installed. 
Cannot load .env file.") + sys.exit(1) + + # Add command line environment variables + for env_var in env_vars: + key, value = _parse_env_var(env_var) + env_dict[key] = value + + if claude.update_claude_config( + file_spec, + name, + with_editable=with_editable, + with_packages=with_packages, + env_vars=env_dict, + ): + logger.info(f"Successfully installed {name} in Claude app") + else: + logger.error(f"Failed to install {name} in Claude app") + sys.exit(1) diff --git a/src/mcp/mcp/client/__init__.py b/src/mcp/mcp/client/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mcp/mcp/client/__main__.py b/src/mcp/mcp/client/__main__.py new file mode 100644 index 000000000..2efe05d53 --- /dev/null +++ b/src/mcp/mcp/client/__main__.py @@ -0,0 +1,85 @@ +import argparse +import logging +import sys +from functools import partial +from urllib.parse import urlparse + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.shared.message import SessionMessage +from mcp.shared.session import RequestResponder + +if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("client") + + +async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, +) -> None: + if isinstance(message, Exception): + logger.error("Error: %s", message) + return + + logger.info("Received message from server: %s", message) + + +async def run_session( + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + client_info: types.Implementation | None = None, +): + async with ClientSession( + read_stream, + write_stream, + 
message_handler=message_handler, + client_info=client_info, + ) as session: + logger.info("Initializing session") + await session.initialize() + logger.info("Initialized") + + +async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]): + env_dict = dict(env) + + if urlparse(command_or_url).scheme in ("http", "https"): + # Use SSE client for HTTP(S) URLs + async with sse_client(command_or_url) as streams: + await run_session(*streams) + else: + # Use stdio client for commands + server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict) + async with stdio_client(server_parameters) as streams: + await run_session(*streams) + + +def cli(): + parser = argparse.ArgumentParser() + parser.add_argument("command_or_url", help="Command or URL to connect to") + parser.add_argument("args", nargs="*", help="Additional arguments") + parser.add_argument( + "-e", + "--env", + nargs=2, + action="append", + metavar=("KEY", "VALUE"), + help="Environment variables to set. Can be used multiple times.", + default=[], + ) + + args = parser.parse_args() + anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio") + + +if __name__ == "__main__": + cli() diff --git a/src/mcp/mcp/client/auth.py b/src/mcp/mcp/client/auth.py new file mode 100644 index 000000000..376036e8c --- /dev/null +++ b/src/mcp/mcp/client/auth.py @@ -0,0 +1,551 @@ +""" +OAuth2 Authentication implementation for HTTPX. + +Implements authorization code flow with PKCE and automatic token refresh. 
+""" + +import base64 +import hashlib +import logging +import re +import secrets +import string +import time +from collections.abc import AsyncGenerator, Awaitable, Callable +from dataclasses import dataclass, field +from typing import Protocol +from urllib.parse import urlencode, urljoin, urlparse + +import anyio +import httpx +from pydantic import BaseModel, Field, ValidationError + +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) +from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url +from mcp.types import LATEST_PROTOCOL_VERSION + +logger = logging.getLogger(__name__) + + +class OAuthFlowError(Exception): + """Base exception for OAuth flow errors.""" + + +class OAuthTokenError(OAuthFlowError): + """Raised when token operations fail.""" + + +class OAuthRegistrationError(OAuthFlowError): + """Raised when client registration fails.""" + + +class PKCEParameters(BaseModel): + """PKCE (Proof Key for Code Exchange) parameters.""" + + code_verifier: str = Field(..., min_length=43, max_length=128) + code_challenge: str = Field(..., min_length=43, max_length=128) + + @classmethod + def generate(cls) -> "PKCEParameters": + """Generate new PKCE parameters.""" + code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") + return cls(code_verifier=code_verifier, code_challenge=code_challenge) + + +class TokenStorage(Protocol): + """Protocol for token storage implementations.""" + + async def get_tokens(self) -> OAuthToken | None: + """Get stored tokens.""" + ... + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens.""" + ... 
+ + async def get_client_info(self) -> OAuthClientInformationFull | None: + """Get stored client information.""" + ... + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information.""" + ... + + +@dataclass +class OAuthContext: + """OAuth flow context.""" + + server_url: str + client_metadata: OAuthClientMetadata + storage: TokenStorage + redirect_handler: Callable[[str], Awaitable[None]] + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] + timeout: float = 300.0 + + # Discovered metadata + protected_resource_metadata: ProtectedResourceMetadata | None = None + oauth_metadata: OAuthMetadata | None = None + auth_server_url: str | None = None + protocol_version: str | None = None + + # Client registration + client_info: OAuthClientInformationFull | None = None + + # Token management + current_tokens: OAuthToken | None = None + token_expiry_time: float | None = None + + # State + lock: anyio.Lock = field(default_factory=anyio.Lock) + + # Discovery state for fallback support + discovery_base_url: str | None = None + discovery_pathname: str | None = None + + def get_authorization_base_url(self, server_url: str) -> str: + """Extract base URL by removing path component.""" + parsed = urlparse(server_url) + return f"{parsed.scheme}://{parsed.netloc}" + + def update_token_expiry(self, token: OAuthToken) -> None: + """Update token expiry time.""" + if token.expires_in: + self.token_expiry_time = time.time() + token.expires_in + else: + self.token_expiry_time = None + + def is_token_valid(self) -> bool: + """Check if current token is valid.""" + return bool( + self.current_tokens + and self.current_tokens.access_token + and (not self.token_expiry_time or time.time() <= self.token_expiry_time) + ) + + def can_refresh_token(self) -> bool: + """Check if token can be refreshed.""" + return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) + + def clear_tokens(self) -> None: + 
"""Clear current tokens.""" + self.current_tokens = None + self.token_expiry_time = None + + def get_resource_url(self) -> str: + """Get resource URL for RFC 8707. + + Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. + """ + resource = resource_url_from_server_url(self.server_url) + + # If PRM provides a resource that's a valid parent, use it + if self.protected_resource_metadata and self.protected_resource_metadata.resource: + prm_resource = str(self.protected_resource_metadata.resource) + if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): + resource = prm_resource + + return resource + + def should_include_resource_param(self, protocol_version: str | None = None) -> bool: + """Determine if the resource parameter should be included in OAuth requests. + + Returns True if: + - Protected resource metadata is available, OR + - MCP-Protocol-Version header is 2025-06-18 or later + """ + # If we have protected resource metadata, include the resource param + if self.protected_resource_metadata is not None: + return True + + # If no protocol version provided, don't include resource param + if not protocol_version: + return False + + # Check if protocol version is 2025-06-18 or later + # Version format is YYYY-MM-DD, so string comparison works + return protocol_version >= "2025-06-18" + + +class OAuthClientProvider(httpx.Auth): + """ + OAuth2 authentication for httpx. + Handles OAuth flow with automatic client registration and token storage. 
+ """ + + requires_response_body = True + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + redirect_handler: Callable[[str], Awaitable[None]], + callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], + timeout: float = 300.0, + ): + """Initialize OAuth2 authentication.""" + self.context = OAuthContext( + server_url=server_url, + client_metadata=client_metadata, + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + timeout=timeout, + ) + self._initialized = False + + def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not init_response or init_response.status_code != 401: + return None + + www_auth_header = init_response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: + # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response + url = self._extract_resource_metadata_from_www_auth(init_response) + + if not url: + # Fallback to well-known discovery + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + async def _handle_protected_resource_response(self, 
response: httpx.Response) -> None: + """Handle discovery response.""" + if response.status_code == 200: + try: + content = await response.aread() + metadata = ProtectedResourceMetadata.model_validate_json(content) + self.context.protected_resource_metadata = metadata + if metadata.authorization_servers: + self.context.auth_server_url = str(metadata.authorization_servers[0]) + except ValidationError: + pass + + def _get_discovery_urls(self) -> list[str]: + """Generate ordered list of (url, type) tuples for discovery attempts.""" + urls: list[str] = [] + auth_server_url = self.context.auth_server_url or self.context.server_url + parsed = urlparse(auth_server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) + + # OAuth root fallback + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) + + # RFC 8414 section 5: Path-aware OIDC discovery + # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + if parsed.path and parsed.path != "/": + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) + + # OIDC 1.0 fallback (appends to full URL per OIDC spec) + oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" + urls.append(oidc_fallback) + + return urls + + async def _register_client(self) -> httpx.Request | None: + """Build registration request or skip if already registered.""" + if self.context.client_info: + return None + + if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: + registration_url = str(self.context.oauth_metadata.registration_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + registration_url = urljoin(auth_base_url, "/register") + + 
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + + return httpx.Request( + "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} + ) + + async def _handle_registration_response(self, response: httpx.Response) -> None: + """Handle registration response.""" + if response.status_code not in (200, 201): + await response.aread() + raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") + + try: + content = await response.aread() + client_info = OAuthClientInformationFull.model_validate_json(content) + self.context.client_info = client_info + await self.context.storage.set_client_info(client_info) + except ValidationError as e: + raise OAuthRegistrationError(f"Invalid registration response: {e}") + + async def _perform_authorization(self) -> tuple[str, str]: + """Perform the authorization redirect and get auth code.""" + if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + auth_endpoint = urljoin(auth_base_url, "/authorize") + + if not self.context.client_info: + raise OAuthFlowError("No client info available for authorization") + + # Generate PKCE parameters + pkce_params = PKCEParameters.generate() + state = secrets.token_urlsafe(32) + + auth_params = { + "response_type": "code", + "client_id": self.context.client_info.client_id, + "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), + "state": state, + "code_challenge": pkce_params.code_challenge, + "code_challenge_method": "S256", + } + + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 + + if 
self.context.client_metadata.scope: + auth_params["scope"] = self.context.client_metadata.scope + + authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" + await self.context.redirect_handler(authorization_url) + + # Wait for callback + auth_code, returned_state = await self.context.callback_handler() + + if returned_state is None or not secrets.compare_digest(returned_state, state): + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") + + if not auth_code: + raise OAuthFlowError("No authorization code received") + + # Return auth code and code verifier for token exchange + return auth_code, pkce_params.code_verifier + + async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request: + """Build token exchange request.""" + if not self.context.client_info: + raise OAuthFlowError("Missing client info") + + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + token_url = urljoin(auth_base_url, "/token") + + token_data = { + "grant_type": "authorization_code", + "code": auth_code, + "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), + "client_id": self.context.client_info.client_id, + "code_verifier": code_verifier, + } + + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() # RFC 8707 + + if self.context.client_info.client_secret: + token_data["client_secret"] = self.context.client_info.client_secret + + return httpx.Request( + "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} + ) + + async def _handle_token_response(self, response: httpx.Response) -> None: + """Handle token exchange response.""" + if response.status_code != 200: + 
raise OAuthTokenError(f"Token exchange failed: {response.status_code}") + + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + + # Validate scopes + if token_response.scope and self.context.client_metadata.scope: + requested_scopes = set(self.context.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}") + + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) + except ValidationError as e: + raise OAuthTokenError(f"Invalid token response: {e}") + + async def _refresh_token(self) -> httpx.Request: + """Build token refresh request.""" + if not self.context.current_tokens or not self.context.current_tokens.refresh_token: + raise OAuthTokenError("No refresh token available") + + if not self.context.client_info: + raise OAuthTokenError("No client info available") + + if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: + token_url = str(self.context.oauth_metadata.token_endpoint) + else: + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + token_url = urljoin(auth_base_url, "/token") + + refresh_data = { + "grant_type": "refresh_token", + "refresh_token": self.context.current_tokens.refresh_token, + "client_id": self.context.client_info.client_id, + } + + # Only include resource param if conditions are met + if self.context.should_include_resource_param(self.context.protocol_version): + refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 + + if self.context.client_info.client_secret: + refresh_data["client_secret"] = self.context.client_info.client_secret + + return httpx.Request( + "POST", token_url, data=refresh_data, headers={"Content-Type": 
"application/x-www-form-urlencoded"} + ) + + async def _handle_refresh_response(self, response: httpx.Response) -> bool: + """Handle token refresh response. Returns True if successful.""" + if response.status_code != 200: + logger.warning(f"Token refresh failed: {response.status_code}") + self.context.clear_tokens() + return False + + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) + + return True + except ValidationError: + logger.exception("Invalid refresh response") + self.context.clear_tokens() + return False + + async def _initialize(self) -> None: + """Load stored tokens and client info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = await self.context.storage.get_client_info() + self._initialized = True + + def _add_auth_header(self, request: httpx.Request) -> None: + """Add authorization header to request if we have valid tokens.""" + if self.context.current_tokens and self.context.current_tokens.access_token: + request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + + def _create_oauth_metadata_request(self, url: str) -> httpx.Request: + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: + content = await response.aread() + metadata = OAuthMetadata.model_validate_json(content) + self.context.oauth_metadata = metadata + # Apply default scope if needed + if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: + self.context.client_metadata.scope = " ".join(metadata.scopes_supported) + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + """HTTPX auth 
flow integration.""" + async with self.context.lock: + if not self._initialized: + await self._initialize() + + # Capture protocol version from request headers + self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) + + if not self.context.is_token_valid() and self.context.can_refresh_token(): + # Try to refresh token + refresh_request = await self._refresh_token() + refresh_response = yield refresh_request + + if not await self._handle_refresh_response(refresh_response): + # Refresh failed, need full re-authentication + self._initialized = False + + if self.context.is_token_valid(): + self._add_auth_header(request) + + response = yield request + + if response.status_code == 401: + # Perform full OAuth flow + try: + # OAuth flow must be inline due to generator constraints + # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) + discovery_request = await self._discover_protected_resource(response) + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) + + # Step 2: Discover OAuth metadata (with fallback for legacy servers) + discovery_urls = self._get_discovery_urls() + for url in discovery_urls: + oauth_metadata_request = self._create_oauth_metadata_request(url) + oauth_metadata_response = yield oauth_metadata_request + + if oauth_metadata_response.status_code == 200: + try: + await self._handle_oauth_metadata_response(oauth_metadata_response) + break + except ValidationError: + continue + elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: + break # Non-4XX error, stop trying + + # Step 3: Register client if needed + registration_request = await self._register_client() + if registration_request: + registration_response = yield registration_request + await self._handle_registration_response(registration_response) + + # Step 4: Perform authorization + auth_code, code_verifier = await 
self._perform_authorization() + + # Step 5: Exchange authorization code for tokens + token_request = await self._exchange_token(auth_code, code_verifier) + token_response = yield token_request + await self._handle_token_response(token_response) + except Exception: + logger.exception("OAuth flow error") + raise + + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/src/mcp/mcp/client/session.py b/src/mcp/mcp/client/session.py new file mode 100644 index 000000000..1853ce7c1 --- /dev/null +++ b/src/mcp/mcp/client/session.py @@ -0,0 +1,458 @@ +import logging +from datetime import timedelta +from typing import Any, Protocol + +import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from jsonschema import SchemaError, ValidationError, validate +from pydantic import AnyUrl, TypeAdapter + +import mcp.types as types +from mcp.shared.context import RequestContext +from mcp.shared.message import SessionMessage +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS + +DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") + +logger = logging.getLogger("client") + + +class SamplingFnT(Protocol): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, + ) -> types.CreateMessageResult | types.ErrorData: ... + + +class ElicitationFnT(Protocol): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + ) -> types.ElicitResult | types.ErrorData: ... + + +class ListRootsFnT(Protocol): + async def __call__( + self, context: RequestContext["ClientSession", Any] + ) -> types.ListRootsResult | types.ErrorData: ... + + +class LoggingFnT(Protocol): + async def __call__( + self, + params: types.LoggingMessageNotificationParams, + ) -> None: ... 
+ + +class MessageHandlerFnT(Protocol): + async def __call__( + self, + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: ... + + +async def _default_message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, +) -> None: + await anyio.lowlevel.checkpoint() + + +async def _default_sampling_callback( + context: RequestContext["ClientSession", Any], + params: types.CreateMessageRequestParams, +) -> types.CreateMessageResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Sampling not supported", + ) + + +async def _default_elicitation_callback( + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, +) -> types.ElicitResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Elicitation not supported", + ) + + +async def _default_list_roots_callback( + context: RequestContext["ClientSession", Any], +) -> types.ListRootsResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="List roots not supported", + ) + + +async def _default_logging_callback( + params: types.LoggingMessageNotificationParams, +) -> None: + pass + + +ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) + + +class ClientSession( + BaseSession[ + types.ClientRequest, + types.ClientNotification, + types.ClientResult, + types.ServerRequest, + types.ServerNotification, + ] +): + def __init__( + self, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + 
message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, + ) -> None: + super().__init__( + read_stream, + write_stream, + types.ServerRequest, + types.ServerNotification, + read_timeout_seconds=read_timeout_seconds, + ) + self._client_info = client_info or DEFAULT_CLIENT_INFO + self._sampling_callback = sampling_callback or _default_sampling_callback + self._elicitation_callback = elicitation_callback or _default_elicitation_callback + self._list_roots_callback = list_roots_callback or _default_list_roots_callback + self._logging_callback = logging_callback or _default_logging_callback + self._message_handler = message_handler or _default_message_handler + self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} + + async def initialize(self) -> types.InitializeResult: + sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None + elicitation = ( + types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None + ) + roots = ( + # TODO: Should this be based on whether we + # _will_ send notifications, or only whether + # they're supported? 
+ types.RootsCapability(listChanged=True) + if self._list_roots_callback is not _default_list_roots_callback + else None + ) + + result = await self.send_request( + types.ClientRequest( + types.InitializeRequest( + method="initialize", + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + sampling=sampling, + elicitation=elicitation, + experimental=None, + roots=roots, + ), + clientInfo=self._client_info, + ), + ) + ), + types.InitializeResult, + ) + + if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: + raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") + + await self.send_notification( + types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) + ) + + return result + + async def send_ping(self) -> types.EmptyResult: + """Send a ping request.""" + return await self.send_request( + types.ClientRequest( + types.PingRequest( + method="ping", + ) + ), + types.EmptyResult, + ) + + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """Send a progress notification.""" + await self.send_notification( + types.ClientNotification( + types.ProgressNotification( + method="notifications/progress", + params=types.ProgressNotificationParams( + progressToken=progress_token, + progress=progress, + total=total, + message=message, + ), + ), + ) + ) + + async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: + """Send a logging/setLevel request.""" + return await self.send_request( + types.ClientRequest( + types.SetLevelRequest( + method="logging/setLevel", + params=types.SetLevelRequestParams(level=level), + ) + ), + types.EmptyResult, + ) + + async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult: + """Send a resources/list request.""" + return 
await self.send_request( + types.ClientRequest( + types.ListResourcesRequest( + method="resources/list", + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + ) + ), + types.ListResourcesResult, + ) + + async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult: + """Send a resources/templates/list request.""" + return await self.send_request( + types.ClientRequest( + types.ListResourceTemplatesRequest( + method="resources/templates/list", + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + ) + ), + types.ListResourceTemplatesResult, + ) + + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Send a resources/read request.""" + return await self.send_request( + types.ClientRequest( + types.ReadResourceRequest( + method="resources/read", + params=types.ReadResourceRequestParams(uri=uri), + ) + ), + types.ReadResourceResult, + ) + + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + return await self.send_request( + types.ClientRequest( + types.SubscribeRequest( + method="resources/subscribe", + params=types.SubscribeRequestParams(uri=uri), + ) + ), + types.EmptyResult, + ) + + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + return await self.send_request( + types.ClientRequest( + types.UnsubscribeRequest( + method="resources/unsubscribe", + params=types.UnsubscribeRequestParams(uri=uri), + ) + ), + types.EmptyResult, + ) + + async def call_tool( + self, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.CallToolResult: + """Send a tools/call request with optional progress callback support.""" + + result = await self.send_request( + types.ClientRequest( + types.CallToolRequest( + 
method="tools/call", + params=types.CallToolRequestParams( + name=name, + arguments=arguments, + ), + ) + ), + types.CallToolResult, + request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, + ) + + if not result.isError: + await self._validate_tool_result(name, result) + + return result + + async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: + """Validate the structured content of a tool result against its output schema.""" + if name not in self._tool_output_schemas: + # refresh output schema cache + await self.list_tools() + + output_schema = None + if name in self._tool_output_schemas: + output_schema = self._tool_output_schemas.get(name) + else: + logger.warning(f"Tool {name} not listed by server, cannot validate any structured content") + + if output_schema is not None: + if result.structuredContent is None: + raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") + try: + validate(result.structuredContent, output_schema) + except ValidationError as e: + raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") + except SchemaError as e: + raise RuntimeError(f"Invalid schema for tool {name}: {e}") + + async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: + """Send a prompts/list request.""" + return await self.send_request( + types.ClientRequest( + types.ListPromptsRequest( + method="prompts/list", + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + ) + ), + types.ListPromptsResult, + ) + + async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: + """Send a prompts/get request.""" + return await self.send_request( + types.ClientRequest( + types.GetPromptRequest( + method="prompts/get", + params=types.GetPromptRequestParams(name=name, arguments=arguments), + ) + ), + types.GetPromptResult, + ) + + async def complete( + 
self, + ref: types.ResourceTemplateReference | types.PromptReference, + argument: dict[str, str], + context_arguments: dict[str, str] | None = None, + ) -> types.CompleteResult: + """Send a completion/complete request.""" + context = None + if context_arguments is not None: + context = types.CompletionContext(arguments=context_arguments) + + return await self.send_request( + types.ClientRequest( + types.CompleteRequest( + method="completion/complete", + params=types.CompleteRequestParams( + ref=ref, + argument=types.CompletionArgument(**argument), + context=context, + ), + ) + ), + types.CompleteResult, + ) + + async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: + """Send a tools/list request.""" + result = await self.send_request( + types.ClientRequest( + types.ListToolsRequest( + method="tools/list", + params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, + ) + ), + types.ListToolsResult, + ) + + # Cache tool output schemas for future validation + # Note: don't clear the cache, as we may be using a cursor + for tool in result.tools: + self._tool_output_schemas[tool.name] = tool.outputSchema + + return result + + async def send_roots_list_changed(self) -> None: + """Send a roots/list_changed notification.""" + await self.send_notification( + types.ClientNotification( + types.RootsListChangedNotification( + method="notifications/roots/list_changed", + ) + ) + ) + + async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: + ctx = RequestContext[ClientSession, Any]( + request_id=responder.request_id, + meta=responder.request_meta, + session=self, + lifespan_context=None, + ) + + match responder.request.root: + case types.CreateMessageRequest(params=params): + with responder: + response = await self._sampling_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + + case 
types.ElicitRequest(params=params): + with responder: + response = await self._elicitation_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + + case types.ListRootsRequest(): + with responder: + response = await self._list_roots_callback(ctx) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + + case types.PingRequest(): + with responder: + return await responder.respond(types.ClientResult(root=types.EmptyResult())) + + async def _handle_incoming( + self, + req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + """Handle incoming messages by forwarding to the message handler.""" + await self._message_handler(req) + + async def _received_notification(self, notification: types.ServerNotification) -> None: + """Handle notifications from the server.""" + # Process specific notification types + match notification.root: + case types.LoggingMessageNotification(params=params): + await self._logging_callback(params) + case _: + pass diff --git a/src/mcp/mcp/client/session_group.py b/src/mcp/mcp/client/session_group.py new file mode 100644 index 000000000..700b5417f --- /dev/null +++ b/src/mcp/mcp/client/session_group.py @@ -0,0 +1,366 @@ +""" +SessionGroup concurrently manages multiple MCP session connections. + +Tools, resources, and prompts are aggregated across servers. Servers may +be connected to or disconnected from at any point after initialization. + +This abstractions can handle naming collisions using a custom user-provided +hook. 
+""" + +import contextlib +import logging +from collections.abc import Callable +from datetime import timedelta +from types import TracebackType +from typing import Any, TypeAlias + +import anyio +from pydantic import BaseModel +from typing_extensions import Self + +import mcp +from mcp import types +from mcp.client.sse import sse_client +from mcp.client.stdio import StdioServerParameters +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.exceptions import McpError + + +class SseServerParameters(BaseModel): + """Parameters for intializing a sse_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: float = 5 + + # Timeout for SSE read operations. + sse_read_timeout: float = 60 * 5 + + +class StreamableHttpParameters(BaseModel): + """Parameters for intializing a streamablehttp_client.""" + + # The endpoint URL. + url: str + + # Optional headers to include in requests. + headers: dict[str, Any] | None = None + + # HTTP timeout for regular operations. + timeout: timedelta = timedelta(seconds=30) + + # Timeout for SSE read operations. + sse_read_timeout: timedelta = timedelta(seconds=60 * 5) + + # Close the client session when the transport closes. + terminate_on_close: bool = True + + +ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters + + +class ClientSessionGroup: + """Client for managing connections to multiple MCP servers. + + This class is responsible for encapsulating management of server connections. + It aggregates tools, resources, and prompts from all connected servers. + + For auxiliary handlers, such as resource subscription, this is delegated to + the client and can be accessed via the session. 
+ + Example Usage: + name_fn = lambda name, server_info: f"{(server_info.name)}_{name}" + async with ClientSessionGroup(component_name_hook=name_fn) as group: + for server_params in server_params: + await group.connect_to_server(server_param) + ... + + """ + + class _ComponentNames(BaseModel): + """Used for reverse index to find components.""" + + prompts: set[str] = set() + resources: set[str] = set() + tools: set[str] = set() + + # Standard MCP components. + _prompts: dict[str, types.Prompt] + _resources: dict[str, types.Resource] + _tools: dict[str, types.Tool] + + # Client-server connection management. + _sessions: dict[mcp.ClientSession, _ComponentNames] + _tool_to_session: dict[str, mcp.ClientSession] + _exit_stack: contextlib.AsyncExitStack + _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] + + # Optional fn consuming (component_name, serverInfo) for custom names. + # This is provide a means to mitigate naming conflicts across servers. + # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}" + _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] + _component_name_hook: _ComponentNameHook | None + + def __init__( + self, + exit_stack: contextlib.AsyncExitStack | None = None, + component_name_hook: _ComponentNameHook | None = None, + ) -> None: + """Initializes the MCP client.""" + + self._tools = {} + self._resources = {} + self._prompts = {} + + self._sessions = {} + self._tool_to_session = {} + if exit_stack is None: + self._exit_stack = contextlib.AsyncExitStack() + self._owns_exit_stack = True + else: + self._exit_stack = exit_stack + self._owns_exit_stack = False + self._session_exit_stacks = {} + self._component_name_hook = component_name_hook + + async def __aenter__(self) -> Self: + # Enter the exit stack only if we created it ourselves + if self._owns_exit_stack: + await self._exit_stack.__aenter__() + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | 
None, + _exc_val: BaseException | None, + _exc_tb: TracebackType | None, + ) -> bool | None: + """Closes session exit stacks and main exit stack upon completion.""" + + # Only close the main exit stack if we created it + if self._owns_exit_stack: + await self._exit_stack.aclose() + + # Concurrently close session stacks. + async with anyio.create_task_group() as tg: + for exit_stack in self._session_exit_stacks.values(): + tg.start_soon(exit_stack.aclose) + + @property + def sessions(self) -> list[mcp.ClientSession]: + """Returns the list of sessions being managed.""" + return list(self._sessions.keys()) + + @property + def prompts(self) -> dict[str, types.Prompt]: + """Returns the prompts as a dictionary of names to prompts.""" + return self._prompts + + @property + def resources(self) -> dict[str, types.Resource]: + """Returns the resources as a dictionary of names to resources.""" + return self._resources + + @property + def tools(self) -> dict[str, types.Tool]: + """Returns the tools as a dictionary of names to tools.""" + return self._tools + + async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: + """Executes a tool given its name and arguments.""" + session = self._tool_to_session[name] + session_tool_name = self.tools[name].name + return await session.call_tool(session_tool_name, args) + + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: + """Disconnects from a single MCP server.""" + + session_known_for_components = session in self._sessions + session_known_for_stack = session in self._session_exit_stacks + + if not session_known_for_components and not session_known_for_stack: + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message="Provided session is not managed or already disconnected.", + ) + ) + + if session_known_for_components: + component_names = self._sessions.pop(session) # Pop from _sessions tracking + + # Remove prompts associated with the session. 
    async def _establish_session(
        self, server_params: ServerParameters
    ) -> tuple[types.Implementation, mcp.ClientSession]:
        """Establish a client session to an MCP server.

        Args:
            server_params: Transport-specific parameters; the concrete type
                (stdio, SSE, or streamable HTTP) selects which client is used.

        Returns:
            The server's implementation info and the initialized session.
        """

        session_stack = contextlib.AsyncExitStack()
        try:
            # Create read and write streams that facilitate io with the server.
            if isinstance(server_params, StdioServerParameters):
                client = mcp.stdio_client(server_params)
                read, write = await session_stack.enter_async_context(client)
            elif isinstance(server_params, SseServerParameters):
                client = sse_client(
                    url=server_params.url,
                    headers=server_params.headers,
                    timeout=server_params.timeout,
                    sse_read_timeout=server_params.sse_read_timeout,
                )
                read, write = await session_stack.enter_async_context(client)
            else:
                client = streamablehttp_client(
                    url=server_params.url,
                    headers=server_params.headers,
                    timeout=server_params.timeout,
                    sse_read_timeout=server_params.sse_read_timeout,
                    terminate_on_close=server_params.terminate_on_close,
                )
                # The streamable HTTP client also yields a get-session-id
                # callback, which is not needed here.
                read, write, _ = await session_stack.enter_async_context(client)

            session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
            result = await session.initialize()

            # Session successfully initialized.
            # Store its stack and register the stack with the main group stack.
            self._session_exit_stacks[session] = session_stack
            # session_stack itself becomes a resource managed by the
            # main _exit_stack.
            await self._exit_stack.enter_async_context(session_stack)

            return result.serverInfo, session
        except Exception:
            # If anything during this setup fails, ensure the session-specific
            # stack is closed.
            await session_stack.aclose()
            raise
+ prompts_temp: dict[str, types.Prompt] = {} + resources_temp: dict[str, types.Resource] = {} + tools_temp: dict[str, types.Tool] = {} + tool_to_session_temp: dict[str, mcp.ClientSession] = {} + + # Query the server for its prompts and aggregate to list. + try: + prompts = (await session.list_prompts()).prompts + for prompt in prompts: + name = self._component_name(prompt.name, server_info) + prompts_temp[name] = prompt + component_names.prompts.add(name) + except McpError as err: + logging.warning(f"Could not fetch prompts: {err}") + + # Query the server for its resources and aggregate to list. + try: + resources = (await session.list_resources()).resources + for resource in resources: + name = self._component_name(resource.name, server_info) + resources_temp[name] = resource + component_names.resources.add(name) + except McpError as err: + logging.warning(f"Could not fetch resources: {err}") + + # Query the server for its tools and aggregate to list. + try: + tools = (await session.list_tools()).tools + for tool in tools: + name = self._component_name(tool.name, server_info) + tools_temp[name] = tool + tool_to_session_temp[name] = session + component_names.tools.add(name) + except McpError as err: + logging.warning(f"Could not fetch tools: {err}") + + # Clean up exit stack for session if we couldn't retrieve anything + # from the server. + if not any((prompts_temp, resources_temp, tools_temp)): + del self._session_exit_stacks[session] + + # Check for duplicates. 
def remove_request_params(url: str) -> str:
    """Return *url* with its query string and fragment stripped."""
    path_only = urlparse(url).path
    # Resolving the bare path against the original URL keeps scheme and host
    # while dropping everything after the path.
    return urljoin(url, path_only)
5, + sse_read_timeout: float = 60 * 5, + httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, + auth: httpx.Auth | None = None, +): + """ + Client transport for SSE. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Args: + url: The SSE endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + auth: Optional HTTPX authentication handler. + """ + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async with anyio.create_task_group() as tg: + try: + logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx_client_factory( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) as client: + async with aconnect_sse( + client, + "GET", + url, + ) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + async def sse_reader( + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + try: + async for sse in event_source.aiter_sse(): + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.debug(f"Received endpoint URL: {endpoint_url}") + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( + f"Endpoint origin does not 
match connection origin: {endpoint_url}" + ) + logger.error(error_msg) + raise ValueError(error_msg) + + task_status.started(endpoint_url) + + case "message": + try: + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data + ) + logger.debug(f"Received server message: {message}") + except Exception as exc: + logger.exception("Error parsing server message") + await read_stream_writer.send(exc) + continue + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + case _: + logger.warning(f"Unknown SSE event: {sse.event}") + except Exception as exc: + logger.exception("Error in sse_reader") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + logger.debug(f"Sending client message: {session_message}") + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + except Exception: + logger.exception("Error in post_writer") + finally: + await write_stream.aclose() + + endpoint_url = await tg.start(sse_reader) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") + tg.start_soon(post_writer, endpoint_url) + + try: + yield read_stream, write_stream + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/mcp/client/stdio/__init__.py b/src/mcp/mcp/client/stdio/__init__.py new file mode 100644 index 000000000..e3532e988 --- /dev/null +++ b/src/mcp/mcp/client/stdio/__init__.py @@ -0,0 +1,277 @@ +import logging +import os +import sys +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Literal, TextIO + +import anyio 
def get_default_environment() -> dict[str, str]:
    """
    Returns a default environment object including only environment variables deemed
    safe to inherit.
    """
    # Values starting with "()" are exported shell functions, which are a
    # security risk, so they are dropped along with unset variables.
    candidates = ((key, os.environ.get(key)) for key in DEFAULT_INHERITED_ENV_VARS)
    return {key: value for key, value in candidates if value is not None and not value.startswith("()")}
+ """ + + cwd: str | Path | None = None + """The working directory to use when spawning the process.""" + + encoding: str = "utf-8" + """ + The text encoding used when sending/receiving messages to the server + + defaults to utf-8 + """ + + encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" + """ + The text encoding error handler. + + See https://docs.python.org/3/library/codecs.html#codec-base-classes for + explanations of possible values + """ + + +@asynccontextmanager +async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): + """ + Client transport for stdio: this will connect to a server by spawning a + process and communicating with it over stdin/stdout. + """ + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + try: + command = _get_executable_command(server.command) + + # Open process with stderr piped for capture + process = await _create_platform_compatible_process( + command=command, + args=server.args, + env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), + errlog=errlog, + cwd=server.cwd, + ) + except OSError: + # Clean up streams if process creation fails + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + raise + + async def stdout_reader(): + assert process.stdout, "Opened process is missing stdout" + + try: + async with read_stream_writer: + buffer = "" + async for chunk in TextReceiveStream( + process.stdout, + encoding=server.encoding, + errors=server.encoding_error_handler, + ): + lines = (buffer + 
chunk).split("\n") + buffer = lines.pop() + + for line in lines: + try: + message = types.JSONRPCMessage.model_validate_json(line) + except Exception as exc: + await read_stream_writer.send(exc) + continue + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async def stdin_writer(): + assert process.stdin, "Opened process is missing stdin" + + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + await process.stdin.send( + (json + "\n").encode( + encoding=server.encoding, + errors=server.encoding_error_handler, + ) + ) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async with ( + anyio.create_task_group() as tg, + process, + ): + tg.start_soon(stdout_reader) + tg.start_soon(stdin_writer) + try: + yield read_stream, write_stream + finally: + # MCP spec: stdio shutdown sequence + # 1. Close input stream to server + # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time + # 3. 
Send SIGKILL if still not exited + if process.stdin: + try: + await process.stdin.aclose() + except Exception: + # stdin might already be closed, which is fine + pass + + try: + # Give the process time to exit gracefully after stdin closes + with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): + await process.wait() + except TimeoutError: + # Process didn't exit from stdin closure, use platform-specific termination + # which handles SIGTERM -> SIGKILL escalation + await _terminate_process_tree(process) + except ProcessLookupError: + # Process already exited, which is fine + pass + await read_stream.aclose() + await write_stream.aclose() + await read_stream_writer.aclose() + await write_stream_reader.aclose() + + +def _get_executable_command(command: str) -> str: + """ + Get the correct executable command normalized for the current platform. + + Args: + command: Base command (e.g., 'uvx', 'npx') + + Returns: + str: Platform-appropriate command + """ + if sys.platform == "win32": + return get_windows_executable_command(command) + else: + return command + + +async def _create_platform_compatible_process( + command: str, + args: list[str], + env: dict[str, str] | None = None, + errlog: TextIO = sys.stderr, + cwd: Path | str | None = None, +): + """ + Creates a subprocess in a platform-compatible way. + + Unix: Creates process in a new session/process group for killpg support + Windows: Creates process in a Job Object for reliable child termination + """ + if sys.platform == "win32": + process = await create_windows_process(command, args, env, errlog, cwd) + else: + process = await anyio.open_process( + [command, *args], + env=env, + stderr=errlog, + cwd=cwd, + start_new_session=True, + ) + + return process + + +async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None: + """ + Terminate a process and all its children using platform-specific methods. 
+ + Unix: Uses os.killpg() for atomic process group termination + Windows: Uses Job Objects via pywin32 for reliable child process cleanup + + Args: + process: The process to terminate + timeout_seconds: Timeout in seconds before force killing (default: 2.0) + """ + if sys.platform == "win32": + await terminate_windows_process_tree(process, timeout_seconds) + else: + # FallbackProcess should only be used for Windows compatibility + assert isinstance(process, Process) + await terminate_posix_process_tree(process, timeout_seconds) diff --git a/src/mcp/mcp/client/streamable_http.py b/src/mcp/mcp/client/streamable_http.py new file mode 100644 index 000000000..b1ab2c079 --- /dev/null +++ b/src/mcp/mcp/client/streamable_http.py @@ -0,0 +1,511 @@ +""" +StreamableHTTP Client Transport Module + +This module implements the StreamableHTTP transport for MCP clients, +providing support for HTTP POST requests with optional SSE streaming responses +and session management. +""" + +import logging +from collections.abc import AsyncGenerator, Awaitable, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass +from datetime import timedelta + +import anyio +import httpx +from anyio.abc import TaskGroup +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpx_sse import EventSource, ServerSentEvent, aconnect_sse + +from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client +from mcp.shared.message import ClientMessageMetadata, SessionMessage +from mcp.types import ( + ErrorData, + InitializeResult, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +logger = logging.getLogger(__name__) + + +SessionMessageOrError = SessionMessage | Exception +StreamWriter = MemoryObjectSendStream[SessionMessageOrError] +StreamReader = MemoryObjectReceiveStream[SessionMessage] +GetSessionIdCallback = Callable[[], str | None] + +MCP_SESSION_ID = 
"mcp-session-id" +MCP_PROTOCOL_VERSION = "mcp-protocol-version" +LAST_EVENT_ID = "last-event-id" +CONTENT_TYPE = "content-type" +ACCEPT = "accept" + + +JSON = "application/json" +SSE = "text/event-stream" + + +class StreamableHTTPError(Exception): + """Base exception for StreamableHTTP transport errors.""" + + +class ResumptionError(StreamableHTTPError): + """Raised when resumption request is invalid.""" + + +@dataclass +class RequestContext: + """Context for a request operation.""" + + client: httpx.AsyncClient + headers: dict[str, str] + session_id: str | None + session_message: SessionMessage + metadata: ClientMessageMetadata | None + read_stream_writer: StreamWriter + sse_read_timeout: float + + +class StreamableHTTPTransport: + """StreamableHTTP client transport implementation.""" + + def __init__( + self, + url: str, + headers: dict[str, str] | None = None, + timeout: float | timedelta = 30, + sse_read_timeout: float | timedelta = 60 * 5, + auth: httpx.Auth | None = None, + ) -> None: + """Initialize the StreamableHTTP transport. + + Args: + url: The endpoint URL. + headers: Optional headers to include in requests. + timeout: HTTP timeout for regular operations. + sse_read_timeout: Timeout for SSE read operations. + auth: Optional HTTPX authentication handler. 
+ """ + self.url = url + self.headers = headers or {} + self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout + self.sse_read_timeout = ( + sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout + ) + self.auth = auth + self.session_id = None + self.protocol_version = None + self.request_headers = { + ACCEPT: f"{JSON}, {SSE}", + CONTENT_TYPE: JSON, + **self.headers, + } + + def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: + """Update headers with session ID and protocol version if available.""" + headers = base_headers.copy() + if self.session_id: + headers[MCP_SESSION_ID] = self.session_id + if self.protocol_version: + headers[MCP_PROTOCOL_VERSION] = self.protocol_version + return headers + + def _is_initialization_request(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialization request.""" + return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + + def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: + """Check if the message is an initialized notification.""" + return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" + + def _maybe_extract_session_id_from_response( + self, + response: httpx.Response, + ) -> None: + """Extract and store session ID from response headers.""" + new_session_id = response.headers.get(MCP_SESSION_ID) + if new_session_id: + self.session_id = new_session_id + logger.info(f"Received session ID: {self.session_id}") + + def _maybe_extract_protocol_version_from_message( + self, + message: JSONRPCMessage, + ) -> None: + """Extract protocol version from initialization response message.""" + if isinstance(message.root, JSONRPCResponse) and message.root.result: + try: + # Parse the result as InitializeResult for type safety + init_result = InitializeResult.model_validate(message.root.result) + 
    async def _handle_sse_event(
        self,
        sse: ServerSentEvent,
        read_stream_writer: StreamWriter,
        original_request_id: RequestId | None = None,
        resumption_callback: Callable[[str], Awaitable[None]] | None = None,
        is_initialization: bool = False,
    ) -> bool:
        """Handle an SSE event, returning True if the response is complete.

        Args:
            sse: The server-sent event to process; only "message" events carry
                JSON-RPC payloads.
            read_stream_writer: Stream used to forward the parsed message (or a
                parse error) to the session.
            original_request_id: When set (resumption path), responses/errors
                have their id rewritten to this value so they match the
                caller's outstanding request.
            resumption_callback: Invoked with the SSE event id, which serves as
                the resumption token for this stream.
            is_initialization: Whether this stream belongs to the initialize
                request (enables protocol-version extraction).

        Returns:
            True when a JSONRPCResponse/JSONRPCError was delivered (the
            request is complete); False to keep listening.
        """
        if sse.event == "message":
            try:
                message = JSONRPCMessage.model_validate_json(sse.data)
                logger.debug(f"SSE message: {message}")

                # Extract protocol version from initialization response
                if is_initialization:
                    self._maybe_extract_protocol_version_from_message(message)

                # If this is a response and we have original_request_id, replace it
                if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
                    message.root.id = original_request_id

                session_message = SessionMessage(message)
                await read_stream_writer.send(session_message)

                # Call resumption token callback if we have an ID
                if sse.id and resumption_callback:
                    await resumption_callback(sse.id)

                # If this is a response or error return True indicating completion
                # Otherwise, return False to continue listening
                return isinstance(message.root, JSONRPCResponse | JSONRPCError)

            except Exception as exc:
                logger.exception("Error parsing SSE message")
                await read_stream_writer.send(exc)
                return False
        else:
            logger.warning(f"Unknown SSE event: {sse.event}")
            return False
    async def _handle_resumption_request(self, ctx: RequestContext) -> None:
        """Handle a resumption request using GET with SSE.

        Reconnects to the SSE stream with the stored resumption token sent as
        the Last-Event-ID header, so the server can resume event delivery
        where it left off.

        Raises:
            ResumptionError: If no resumption token is present in the request
                metadata.
        """
        headers = self._prepare_request_headers(ctx.headers)
        if ctx.metadata and ctx.metadata.resumption_token:
            headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
        else:
            raise ResumptionError("Resumption request requires a resumption token")

        # Extract original request ID to map responses
        original_request_id = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):
            original_request_id = ctx.session_message.message.root.id

        async with aconnect_sse(
            ctx.client,
            "GET",
            self.url,
            headers=headers,
            timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
        ) as event_source:
            event_source.response.raise_for_status()
            logger.debug("Resumption GET SSE connection established")

            async for sse in event_source.aiter_sse():
                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    original_request_id,
                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                )
                if is_complete:
                    # Close the SSE response eagerly once the matching
                    # response/error has been delivered.
                    await event_source.response.aclose()
                    break
    async def _handle_json_response(
        self,
        response: httpx.Response,
        read_stream_writer: StreamWriter,
        is_initialization: bool = False,
    ) -> None:
        """Handle JSON response from the server.

        Reads the full response body, validates it as a JSON-RPC message, and
        forwards either the parsed message or the validation error to the
        read stream so the session can react.

        Args:
            response: The HTTP response carrying a JSON body.
            read_stream_writer: Stream used to deliver the result.
            is_initialization: True when this response answers `initialize`;
                enables protocol-version extraction.
        """
        try:
            content = await response.aread()
            message = JSONRPCMessage.model_validate_json(content)

            # Extract protocol version from initialization response
            if is_initialization:
                self._maybe_extract_protocol_version_from_message(message)

            session_message = SessionMessage(message)
            await read_stream_writer.send(session_message)
        except Exception as exc:
            logger.exception("Error parsing JSON response")
            # Deliver the failure to the reader instead of swallowing it.
            await read_stream_writer.send(exc)
is_initialization=is_initialization, + ) + # If the SSE event indicates completion, like returning respose/error + # break the loop + if is_complete: + await response.aclose() + break + except Exception as e: + logger.exception("Error reading SSE stream:") + await ctx.read_stream_writer.send(e) + + async def _handle_unexpected_content_type( + self, + content_type: str, + read_stream_writer: StreamWriter, + ) -> None: + """Handle unexpected content type in response.""" + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + async def _send_session_terminated_error( + self, + read_stream_writer: StreamWriter, + request_id: RequestId, + ) -> None: + """Send a session terminated error response.""" + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=32600, message="Session terminated"), + ) + session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + await read_stream_writer.send(session_message) + + async def post_writer( + self, + client: httpx.AsyncClient, + write_stream_reader: StreamReader, + read_stream_writer: StreamWriter, + write_stream: MemoryObjectSendStream[SessionMessage], + start_get_stream: Callable[[], None], + tg: TaskGroup, + ) -> None: + """Handle writing requests to the server.""" + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + message = session_message.message + metadata = ( + session_message.metadata + if isinstance(session_message.metadata, ClientMessageMetadata) + else None + ) + + # Check if this is a resumption request + is_resumption = bool(metadata and metadata.resumption_token) + + logger.debug(f"Sending client message: {message}") + + # Handle initialized notification + if self._is_initialized_notification(message): + start_get_stream() + + ctx = RequestContext( + client=client, + headers=self.request_headers, + session_id=self.session_id, + 
@asynccontextmanager
async def streamablehttp_client(
    url: str,
    headers: dict[str, str] | None = None,
    timeout: float | timedelta = 30,
    sse_read_timeout: float | timedelta = 60 * 5,
    terminate_on_close: bool = True,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
) -> AsyncGenerator[
    tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
        GetSessionIdCallback,
    ],
    None,
]:
    """
    Client transport for StreamableHTTP.

    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: The MCP endpoint URL to connect to.
        headers: Extra HTTP headers merged into every request.
        timeout: Timeout (seconds or timedelta) for regular HTTP operations.
        sse_read_timeout: Timeout (seconds or timedelta) between SSE events.
        terminate_on_close: When True, send a DELETE to end the session on exit.
        httpx_client_factory: Factory used to build the underlying httpx client.
        auth: Optional httpx authentication hook.

    Yields:
        Tuple containing:
        - read_stream: Stream for reading messages from the server
        - write_stream: Stream for sending messages to the server
        - get_session_id_callback: Function to retrieve the current session ID
    """
    transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)

    # Zero-capacity streams: each send blocks until the consumer receives.
    read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

            async with httpx_client_factory(
                headers=transport.request_headers,
                timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
                auth=transport.auth,
            ) as client:
                # Define callbacks that need access to tg
                def start_get_stream() -> None:
                    # Opens the server-to-client GET SSE stream in the background;
                    # triggered by post_writer once the session is initialized.
                    tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

                tg.start_soon(
                    transport.post_writer,
                    client,
                    write_stream_reader,
                    read_stream_writer,
                    write_stream,
                    start_get_stream,
                    tg,
                )

                try:
                    yield (
                        read_stream,
                        write_stream,
                        transport.get_session_id,
                    )
                finally:
                    # Best-effort session teardown, then cancel background tasks.
                    if transport.session_id and terminate_on_close:
                        await transport.terminate_session(client)
                    tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
@asynccontextmanager
async def websocket_client(
    url: str,
) -> AsyncGenerator[
    tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
    None,
]:
    """
    WebSocket client transport for MCP, symmetrical to the server version.

    Connects to 'url' using the 'mcp' subprotocol, then yields:
        (read_stream, write_stream)

    - read_stream: As you read from this stream, you'll receive either valid
      JSONRPCMessage objects or Exception objects (when validation fails).
    - write_stream: Write JSONRPCMessage objects to this stream to send them
      over the WebSocket to the server.
    """

    # Create two in-memory streams:
    # - One for incoming messages (read_stream, written by ws_reader)
    # - One for outgoing messages (write_stream, read by ws_writer)
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams: each send blocks until the consumer receives.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    # Connect using websockets, requesting the "mcp" subprotocol
    async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:

        async def ws_reader():
            """
            Reads text messages from the WebSocket, parses them as JSON-RPC messages,
            and sends them into read_stream_writer.
            """
            async with read_stream_writer:
                async for raw_text in ws:
                    try:
                        message = types.JSONRPCMessage.model_validate_json(raw_text)
                        session_message = SessionMessage(message)
                        await read_stream_writer.send(session_message)
                    except ValidationError as exc:
                        # If JSON parse or model validation fails, send the exception
                        # to the consumer; other exceptions propagate and cancel the
                        # task group.
                        await read_stream_writer.send(exc)

        async def ws_writer():
            """
            Reads JSON-RPC messages from write_stream_reader and
            sends them to the server.
            """
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    # Convert to a dict, then to JSON
                    msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
                    await ws.send(json.dumps(msg_dict))

        async with anyio.create_task_group() as tg:
            # Start reader and writer tasks
            tg.start_soon(ws_reader)
            tg.start_soon(ws_writer)

            # Yield the receive/send streams
            yield (read_stream, write_stream)

            # Once the caller's 'async with' block exits, we shut down
            tg.cancel_scope.cancel()
+""" + +import logging +import os +import signal + +import anyio +from anyio.abc import Process + +logger = logging.getLogger(__name__) + + +async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None: + """ + Terminate a process and all its children on POSIX systems. + + Uses os.killpg() for atomic process group termination. + + Args: + process: The process to terminate + timeout_seconds: Timeout in seconds before force killing (default: 2.0) + """ + pid = getattr(process, "pid", None) or getattr(getattr(process, "popen", None), "pid", None) + if not pid: + # No PID means there's no process to terminate - it either never started, + # already exited, or we have an invalid process object + return + + try: + pgid = os.getpgid(pid) + os.killpg(pgid, signal.SIGTERM) + + with anyio.move_on_after(timeout_seconds): + while True: + try: + # Check if process group still exists (signal 0 = check only) + os.killpg(pgid, 0) + await anyio.sleep(0.1) + except ProcessLookupError: + return + + try: + os.killpg(pgid, signal.SIGKILL) + except ProcessLookupError: + pass + + except (ProcessLookupError, PermissionError, OSError) as e: + logger.warning(f"Process group termination failed for PID {pid}: {e}, falling back to simple terminate") + try: + process.terminate() + with anyio.fail_after(timeout_seconds): + await process.wait() + except Exception: + logger.warning(f"Process termination failed for PID {pid}, attempting force kill") + try: + process.kill() + except Exception: + logger.exception(f"Failed to kill process {pid}") diff --git a/src/mcp/mcp/os/win32/__init__.py b/src/mcp/mcp/os/win32/__init__.py new file mode 100644 index 000000000..f1ebab98d --- /dev/null +++ b/src/mcp/mcp/os/win32/__init__.py @@ -0,0 +1 @@ +"""Windows-specific utilities for MCP.""" diff --git a/src/mcp/mcp/os/win32/utilities.py b/src/mcp/mcp/os/win32/utilities.py new file mode 100644 index 000000000..962be0229 --- /dev/null +++ b/src/mcp/mcp/os/win32/utilities.py @@ 
def get_windows_executable_command(command: str) -> str:
    """
    Resolve *command* to the concrete executable Windows should run.

    Windows commands often ship with a specific extension (.exe, .cmd, .bat,
    .ps1) that must be located explicitly for subprocess launching to work.
    The bare command is checked first, then each known extension in turn.

    Args:
        command: Base command (e.g., 'uvx', 'npx')

    Returns:
        str: The resolved executable path, or the original command when no
        match is found or path resolution fails.
    """
    try:
        # "" checks the command as-is before trying Windows-specific suffixes.
        for suffix in ("", ".cmd", ".bat", ".exe", ".ps1"):
            resolved = shutil.which(f"{command}{suffix}")
            if resolved:
                return resolved
    except OSError:
        # File-system errors during lookup (permissions, broken symlinks,
        # etc.): fall through and return the bare command unchanged.
        pass
    return command
    async def __aexit__(
        self,
        exc_type: BaseException | None,
        exc_val: BaseException | None,
        exc_tb: object | None,
    ) -> None:
        """Terminate and wait on process exit inside a thread.

        Termination is requested first; the blocking Popen.wait() then runs on
        a worker thread so the event loop is not stalled. All pipe handles are
        closed afterwards.
        """
        self.popen.terminate()
        # popen.wait() blocks, so run it off the event loop.
        await to_thread.run_sync(self.popen.wait)

        # Close the file handles to prevent ResourceWarning
        if self.stdin:
            await self.stdin.aclose()
        if self.stdout:
            await self.stdout.aclose()
        if self.stdin_raw:
            self.stdin_raw.close()
        if self.stdout_raw:
            self.stdout_raw.close()
        if self.stderr:
            self.stderr.close()
async def _create_windows_fallback_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO | None = sys.stderr,
    cwd: Path | str | None = None,
) -> FallbackProcess:
    """
    Create a subprocess using subprocess.Popen as a fallback when anyio fails.

    Wraps the synchronous Popen object in FallbackProcess so callers get an
    async-compatible interface with stdin/stdout streams.
    """
    argv = [command, *args]
    # Keyword arguments shared by both launch attempts.
    popen_kwargs: dict = {
        "stdin": subprocess.PIPE,
        "stdout": subprocess.PIPE,
        "stderr": errlog,
        "env": env,
        "cwd": cwd,
        "bufsize": 0,  # Unbuffered output
    }
    try:
        # Prefer CREATE_NO_WINDOW so no console window is opened per process.
        child = subprocess.Popen(
            argv,
            creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0),
            **popen_kwargs,
        )
    except Exception:
        # If creationflags failed for any reason, retry without them.
        child = subprocess.Popen(argv, **popen_kwargs)
    return FallbackProcess(child)
async def terminate_windows_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
    """
    Terminate a process and all its children on Windows.

    If the process has an associated job object, it will be terminated.
    Otherwise, falls back to basic process termination.

    Args:
        process: The process to terminate
        timeout_seconds: Timeout in seconds before force killing (default: 2.0)
            NOTE(review): currently unused here - job termination is
            immediate; confirm whether a graceful wait was intended.
    """
    if sys.platform != "win32":
        return

    # _job_object is attached by _maybe_assign_process_to_job at spawn time.
    job = getattr(process, "_job_object", None)
    if job and win32job:
        try:
            win32job.TerminateJobObject(job, 1)
        except Exception:
            # Job might already be terminated
            pass
        finally:
            # Release the job handle regardless of termination outcome.
            if win32api:
                try:
                    win32api.CloseHandle(job)
                except Exception:
                    pass

    # Always try to terminate the process itself as well
    try:
        process.terminate()
    except Exception:
        pass
async def main():
    """Run a minimal MCP server over stdio that logs incoming client messages.

    Advertises empty capabilities, so it is mainly useful for smoke-testing
    transport wiring; serves until the stdio streams close.
    """
    # Report the installed package version in the initialization options.
    version = importlib.metadata.version("mcp")
    async with stdio_server() as (read_stream, write_stream):
        async with (
            ServerSession(
                read_stream,
                write_stream,
                InitializationOptions(
                    server_name="mcp",
                    server_version=version,
                    capabilities=ServerCapabilities(),
                ),
            ) as session,
            write_stream,
        ):
            await receive_loop(session)
def stringify_pydantic_error(validation_error: ValidationError) -> str:
    """Render a pydantic ValidationError as one 'dotted.location: message' line per error."""
    formatted: list[str] = []
    for err in validation_error.errors():
        location = ".".join(str(part) for part in err["loc"])
        formatted.append(f"{location}: {err['msg']}")
    return "\n".join(formatted)
class AuthorizationRequest(BaseModel):
    """Parameters accepted by the OAuth authorization endpoint.

    Restricted to the authorization-code flow (`response_type=code`) with
    mandatory PKCE using the S256 challenge method.
    """

    # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
    client_id: str = Field(..., description="The client ID")
    redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")

    # see OAuthClientMetadata; we only support `code`
    response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
    code_challenge: str = Field(..., description="PKCE code challenge")
    code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
    state: str | None = Field(None, description="Optional state parameter")
    scope: str | None = Field(
        None,
        description="Optional scope; if specified, should be a space-separated list of scope strings",
    )
    resource: str | None = Field(
        None,
        description="RFC 8707 resource indicator - the MCP server this token will be used with",
    )
AuthorizationErrorCode + error_description: str | None + error_uri: AnyUrl | None = None + # must be set if provided in the request + state: str | None = None + + +def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None: + if params is None: + return None + value = params.get(key) + if isinstance(value, str): + return value + return None + + +class AnyUrlModel(RootModel[AnyUrl]): + root: AnyUrl + + +@dataclass +class AuthorizationHandler: + provider: OAuthAuthorizationServerProvider[Any, Any, Any] + + async def handle(self, request: Request) -> Response: + # implements authorization requests for grant_type=code; + # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 + + state = None + redirect_uri = None + client = None + params = None + + async def error_response( + error: AuthorizationErrorCode, + error_description: str | None, + attempt_load_client: bool = True, + ): + # Error responses take two different formats: + # 1. The request has a valid client ID & redirect_uri: we issue a redirect + # back to the redirect_uri with the error response fields as query + # parameters. This allows the client to be notified of the error. + # 2. Otherwise, we return an error response directly to the end user; + # we choose to do so in JSON, but this is left undefined in the + # specification. + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 + # + # This logic is a bit awkward to handle, because the error might be thrown + # very early in request validation, before we've done the usual Pydantic + # validation, loaded the client, etc. To handle this, error_response() + # contains fallback logic which attempts to load the parameters directly + # from the request. 
+ + nonlocal client, redirect_uri, state + if client is None and attempt_load_client: + # make last-ditch attempt to load the client + client_id = best_effort_extract_string("client_id", params) + client = client_id and await self.provider.get_client(client_id) + if redirect_uri is None and client: + # make last-ditch effort to load the redirect uri + try: + if params is not None and "redirect_uri" not in params: + raw_redirect_uri = None + else: + raw_redirect_uri = AnyUrlModel.model_validate( + best_effort_extract_string("redirect_uri", params) + ).root + redirect_uri = client.validate_redirect_uri(raw_redirect_uri) + except (ValidationError, InvalidRedirectUriError): + # if the redirect URI is invalid, ignore it & just return the + # initial error + pass + + # the error response MUST contain the state specified by the client, if any + if state is None: + # make last-ditch effort to load state + state = best_effort_extract_string("state", params) + + error_resp = AuthorizationErrorResponse( + error=error, + error_description=error_description, + state=state, + ) + + if redirect_uri and client: + return RedirectResponse( + url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + else: + return PydanticJSONResponse( + status_code=400, + content=error_resp, + headers={"Cache-Control": "no-store"}, + ) + + try: + # Parse request parameters + if request.method == "GET": + # Convert query_params to dict for pydantic validation + params = request.query_params + else: + # Parse form data for POST requests + params = await request.form() + + # Save state if it exists, even before validation + state = best_effort_extract_string("state", params) + + try: + auth_request = AuthorizationRequest.model_validate(params) + state = auth_request.state # Update with validated state + except ValidationError as validation_error: + error: AuthorizationErrorCode = "invalid_request" + for e 
in validation_error.errors(): + if e["loc"] == ("response_type",) and e["type"] == "literal_error": + error = "unsupported_response_type" + break + return await error_response(error, stringify_pydantic_error(validation_error)) + + # Get client information + client = await self.provider.get_client( + auth_request.client_id, + ) + if not client: + # For client_id validation errors, return direct error (no redirect) + return await error_response( + error="invalid_request", + error_description=f"Client ID '{auth_request.client_id}' not found", + attempt_load_client=False, + ) + + # Validate redirect_uri against client's registered URIs + try: + redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri) + except InvalidRedirectUriError as validation_error: + # For redirect_uri validation errors, return direct error (no redirect) + return await error_response( + error="invalid_request", + error_description=validation_error.message, + ) + + # Validate scope - for scope errors, we can redirect + try: + scopes = client.validate_scope(auth_request.scope) + except InvalidScopeError as validation_error: + # For scope errors, redirect with error parameters + return await error_response( + error="invalid_scope", + error_description=validation_error.message, + ) + + # Setup authorization parameters + auth_params = AuthorizationParams( + state=state, + scopes=scopes, + code_challenge=auth_request.code_challenge, + redirect_uri=redirect_uri, + redirect_uri_provided_explicitly=auth_request.redirect_uri is not None, + resource=auth_request.resource, # RFC 8707 + ) + + try: + # Let the provider pick the next URI to redirect to + return RedirectResponse( + url=await self.provider.authorize( + client, + auth_params, + ), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + except AuthorizeError as e: + # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 + return await error_response(error=e.error, error_description=e.error_description) + + 
@dataclass
class MetadataHandler:
    """Serves OAuth authorization server metadata as cacheable JSON."""

    metadata: OAuthMetadata

    async def handle(self, request: Request) -> Response:
        """Return the server metadata; clients may cache the response."""
        cache_headers = {"Cache-Control": "public, max-age=3600"}  # Cache for 1 hour
        return PydanticJSONResponse(content=self.metadata, headers=cache_headers)


@dataclass
class ProtectedResourceMetadataHandler:
    """Serves protected resource metadata as cacheable JSON."""

    metadata: ProtectedResourceMetadata

    async def handle(self, request: Request) -> Response:
        """Return the resource metadata; clients may cache the response."""
        cache_headers = {"Cache-Control": "public, max-age=3600"}  # Cache for 1 hour
        return PydanticJSONResponse(content=self.metadata, headers=cache_headers)
class RegistrationRequest(RootModel[OAuthClientMetadata]):
    # this wrapper is a no-op; it's just to separate out the types exposed to the
    # provider from what we use in the HTTP handler
    root: OAuthClientMetadata


class RegistrationErrorResponse(BaseModel):
    # Error body for failed registrations (RFC 7591 section 3.2.2).
    error: RegistrationErrorCode
    error_description: str | None


@dataclass
class RegistrationHandler:
    """HTTP handler for OAuth 2.0 Dynamic Client Registration."""

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    options: ClientRegistrationOptions

    async def handle(self, request: Request) -> Response:
        # Implements dynamic client registration as defined in
        # https://datatracker.ietf.org/doc/html/rfc7591#section-3.1

        def metadata_error(description: str) -> Response:
            # All metadata validation failures share the same RFC 7591 error code.
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(
                    error="invalid_client_metadata",
                    error_description=description,
                ),
                status_code=400,
            )

        try:
            # Parse request body as JSON and validate it as client metadata.
            body = await request.json()
            client_metadata = OAuthClientMetadata.model_validate(body)
        except ValidationError as validation_error:
            return metadata_error(stringify_pydantic_error(validation_error))

        client_id = str(uuid4())
        client_secret = None
        if client_metadata.token_endpoint_auth_method != "none":
            # cryptographically secure random 32-byte hex string
            client_secret = secrets.token_hex(32)

        # Scope handling: fill in defaults when none requested, otherwise check
        # every requested scope against the configured allow-list (if any).
        if client_metadata.scope is None and self.options.default_scopes is not None:
            client_metadata.scope = " ".join(self.options.default_scopes)
        elif client_metadata.scope is not None and self.options.valid_scopes is not None:
            requested_scopes = set(client_metadata.scope.split())
            valid_scopes = set(self.options.valid_scopes)
            if not requested_scopes.issubset(valid_scopes):
                return metadata_error(
                    "Requested scopes are not valid: " f"{', '.join(requested_scopes - valid_scopes)}"
                )

        # Only the authorization_code + refresh_token pair is supported.
        if set(client_metadata.grant_types) != {"authorization_code", "refresh_token"}:
            return metadata_error("grant_types must be authorization_code and refresh_token")

        client_id_issued_at = int(time.time())
        expiry_seconds = self.options.client_secret_expiry_seconds
        client_secret_expires_at = client_id_issued_at + expiry_seconds if expiry_seconds is not None else None

        client_info = OAuthClientInformationFull(
            client_id=client_id,
            client_id_issued_at=client_id_issued_at,
            client_secret=client_secret,
            client_secret_expires_at=client_secret_expires_at,
            # passthrough information from the client request
            redirect_uris=client_metadata.redirect_uris,
            token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
            grant_types=client_metadata.grant_types,
            response_types=client_metadata.response_types,
            client_name=client_metadata.client_name,
            client_uri=client_metadata.client_uri,
            logo_uri=client_metadata.logo_uri,
            scope=client_metadata.scope,
            contacts=client_metadata.contacts,
            tos_uri=client_metadata.tos_uri,
            policy_uri=client_metadata.policy_uri,
            jwks_uri=client_metadata.jwks_uri,
            jwks=client_metadata.jwks,
            software_id=client_metadata.software_id,
            software_version=client_metadata.software_version,
        )
        try:
            # Register client and echo the stored record back with a 201.
            await self.provider.register_client(client_info)
            return PydanticJSONResponse(content=client_info, status_code=201)
        except RegistrationError as e:
            # Handle registration errors as defined in RFC 7591 Section 3.2.2
            return PydanticJSONResponse(
                content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
                status_code=400,
            )
class RevocationRequest(BaseModel):
    """
    # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
    """

    token: str
    token_type_hint: Literal["access_token", "refresh_token"] | None = None
    client_id: str
    # Default to None so public clients (registered with
    # token_endpoint_auth_method "none", hence no secret) can revoke tokens;
    # this also matches the token-endpoint request models.
    client_secret: str | None = None


class RevocationErrorResponse(BaseModel):
    # Error body for failed revocations.
    error: Literal["invalid_request", "unauthorized_client"]
    error_description: str | None = None


@dataclass
class RevocationHandler:
    """HTTP handler for OAuth 2.0 Token Revocation (RFC 7009)."""

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    client_authenticator: ClientAuthenticator

    async def handle(self, request: Request) -> Response:
        """
        Handler for the OAuth 2.0 Token Revocation endpoint.
        """
        try:
            form_data = await request.form()
            revocation_request = RevocationRequest.model_validate(dict(form_data))
        except ValidationError as e:
            return PydanticJSONResponse(
                status_code=400,
                content=RevocationErrorResponse(
                    error="invalid_request",
                    error_description=stringify_pydantic_error(e),
                ),
            )

        # Authenticate client
        try:
            client = await self.client_authenticator.authenticate(
                revocation_request.client_id, revocation_request.client_secret
            )
        except AuthenticationError as e:
            return PydanticJSONResponse(
                status_code=401,
                content=RevocationErrorResponse(
                    error="unauthorized_client",
                    error_description=e.message,
                ),
            )

        # token_type_hint only orders the lookup; a miss on the hinted type
        # still falls through to the other loader.
        loaders = [
            self.provider.load_access_token,
            partial(self.provider.load_refresh_token, client),
        ]
        if revocation_request.token_type_hint == "refresh_token":
            loaders = reversed(loaders)

        token: None | AccessToken | RefreshToken = None
        for loader in loaders:
            token = await loader(revocation_request.token)
            if token is not None:
                break

        # if token is not found, just return HTTP 200 per the RFC
        if token and token.client_id == client.client_id:
            # Revoke token; provider is not meant to be able to do validation
            # at this point that would result in an error
            await self.provider.revoke_token(token)

        # Return successful empty response
        return Response(
            status_code=200,
            headers={
                "Cache-Control": "no-store",
                "Pragma": "no-cache",
            },
        )
class AuthorizationCodeRequest(BaseModel):
    """Token request body for grant_type=authorization_code."""

    # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
    grant_type: Literal["authorization_code"]
    code: str = Field(..., description="The authorization code")
    redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
    code_verifier: str = Field(..., description="PKCE code verifier")
    # RFC 8707 resource indicator
    resource: str | None = Field(None, description="Resource indicator for the token")


class RefreshTokenRequest(BaseModel):
    """Token request body for grant_type=refresh_token."""

    # See https://datatracker.ietf.org/doc/html/rfc6749#section-6
    grant_type: Literal["refresh_token"]
    refresh_token: str = Field(..., description="The refresh token")
    scope: str | None = Field(None, description="Optional scope parameter")
    client_id: str
    # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
    client_secret: str | None = None
    # RFC 8707 resource indicator
    resource: str | None = Field(None, description="Resource indicator for the token")


class TokenRequest(
    RootModel[
        Annotated[
            AuthorizationCodeRequest | RefreshTokenRequest,
            Field(discriminator="grant_type"),
        ]
    ]
):
    """Discriminated union over the supported grant types, keyed on grant_type."""

    root: Annotated[
        AuthorizationCodeRequest | RefreshTokenRequest,
        Field(discriminator="grant_type"),
    ]
class TokenSuccessResponse(RootModel[OAuthToken]):
    # this is just a wrapper over OAuthToken; the only reason we do this
    # is to have some separation between the HTTP response type, and the
    # type returned by the provider
    root: OAuthToken


@dataclass
class TokenHandler:
    """HTTP handler for the OAuth 2.0 token endpoint.

    Supports the authorization_code grant (with mandatory PKCE verification)
    and the refresh_token grant; everything else is rejected before reaching
    the provider.
    """

    provider: OAuthAuthorizationServerProvider[Any, Any, Any]
    client_authenticator: ClientAuthenticator

    def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
        # Serialize a success (200) or error (400) payload with the
        # cache-defeating headers required for token responses.
        status_code = 200
        if isinstance(obj, TokenErrorResponse):
            status_code = 400

        return PydanticJSONResponse(
            content=obj,
            status_code=status_code,
            headers={
                "Cache-Control": "no-store",
                "Pragma": "no-cache",
            },
        )

    async def handle(self, request: Request):
        """Validate the form body, authenticate the client, and dispatch on grant type."""
        try:
            form_data = await request.form()
            token_request = TokenRequest.model_validate(dict(form_data)).root
        except ValidationError as validation_error:
            return self.response(
                TokenErrorResponse(
                    error="invalid_request",
                    error_description=stringify_pydantic_error(validation_error),
                )
            )

        # Client authentication happens before any grant-specific processing.
        try:
            client_info = await self.client_authenticator.authenticate(
                client_id=token_request.client_id,
                client_secret=token_request.client_secret,
            )
        except AuthenticationError as e:
            return self.response(
                TokenErrorResponse(
                    error="unauthorized_client",
                    error_description=e.message,
                )
            )

        # The client may only use grant types it registered with.
        if token_request.grant_type not in client_info.grant_types:
            return self.response(
                TokenErrorResponse(
                    error="unsupported_grant_type",
                    error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
                )
            )

        tokens: OAuthToken

        match token_request:
            case AuthorizationCodeRequest():
                auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
                if auth_code is None or auth_code.client_id != token_request.client_id:
                    # if code belongs to different client, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="authorization code does not exist",
                        )
                    )

                # make auth codes expire after a deadline
                # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
                if auth_code.expires_at < time.time():
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="authorization code has expired",
                        )
                    )

                # verify redirect_uri doesn't change between /authorize and /tokens
                # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
                if auth_code.redirect_uri_provided_explicitly:
                    authorize_request_redirect_uri = auth_code.redirect_uri
                else:
                    authorize_request_redirect_uri = None

                # Convert both sides to strings for comparison to handle AnyUrl vs string issues
                token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
                auth_redirect_str = (
                    str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
                )

                if token_redirect_str != auth_redirect_str:
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_request",
                            error_description=("redirect_uri did not match the one used when creating auth code"),
                        )
                    )

                # Verify PKCE code verifier: S256 transform is
                # BASE64URL(SHA256(verifier)) without padding.
                sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
                hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")

                if hashed_code_verifier != auth_code.code_challenge:
                    # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="incorrect code_verifier",
                        )
                    )

                try:
                    # Exchange authorization code for tokens
                    tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
                except TokenError as e:
                    return self.response(
                        TokenErrorResponse(
                            error=e.error,
                            error_description=e.error_description,
                        )
                    )

            case RefreshTokenRequest():
                refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
                if refresh_token is None or refresh_token.client_id != token_request.client_id:
                    # if token belongs to different client, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="refresh token does not exist",
                        )
                    )

                if refresh_token.expires_at and refresh_token.expires_at < time.time():
                    # if the refresh token has expired, pretend it doesn't exist
                    return self.response(
                        TokenErrorResponse(
                            error="invalid_grant",
                            error_description="refresh token has expired",
                        )
                    )

                # Parse scopes if provided; requested scopes must be a subset
                # of what the refresh token originally granted.
                scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes

                for scope in scopes:
                    if scope not in refresh_token.scopes:
                        return self.response(
                            TokenErrorResponse(
                                error="invalid_scope",
                                error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
                            )
                        )

                try:
                    # Exchange refresh token for new tokens
                    tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
                except TokenError as e:
                    return self.response(
                        TokenErrorResponse(
                            error=e.error,
                            error_description=e.error_description,
                        )
                    )

        return self.response(TokenSuccessResponse(root=tokens))
# Holds the user authenticated for the current request; the default None means
# no authenticated user is present in this context.
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)


def get_access_token() -> AccessToken | None:
    """
    Get the access token from the current context.

    Returns:
        The access token if an authenticated user is available, None otherwise.
    """
    user = auth_context_var.get()
    if not user:
        return None
    return user.access_token


class AuthContextMiddleware:
    """
    ASGI middleware that copies the authenticated user (if any) from the ASGI
    scope into ``auth_context_var`` for the duration of the request, so it can
    be read anywhere in the request lifecycle via get_access_token().

    Must be added after the AuthenticationMiddleware, which is what populates
    ``scope["user"]``.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        user = scope.get("user")
        if not isinstance(user, AuthenticatedUser):
            # No authenticated user: run the app with the contextvar untouched.
            await self.app(scope, receive, send)
            return

        reset_token = auth_context_var.set(user)
        try:
            await self.app(scope, receive, send)
        finally:
            # Restore the previous value even if the app raised.
            auth_context_var.reset(reset_token)
+ """ + + def __init__(self, token_verifier: TokenVerifier): + self.token_verifier = token_verifier + + async def authenticate(self, conn: HTTPConnection): + auth_header = next( + (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), + None, + ) + if not auth_header or not auth_header.lower().startswith("bearer "): + return None + + token = auth_header[7:] # Remove "Bearer " prefix + + # Validate the token with the verifier + auth_info = await self.token_verifier.verify_token(token) + + if not auth_info: + return None + + if auth_info.expires_at and auth_info.expires_at < int(time.time()): + return None + + return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) + + +class RequireAuthMiddleware: + """ + Middleware that requires a valid Bearer token in the Authorization header. + + This will validate the token with the auth provider and store the resulting + auth info in the request state. + """ + + def __init__( + self, + app: Any, + required_scopes: list[str], + resource_metadata_url: AnyHttpUrl | None = None, + ): + """ + Initialize the middleware. 
+ + Args: + app: ASGI application + required_scopes: List of scopes that the token must have + resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header + """ + self.app = app + self.required_scopes = required_scopes + self.resource_metadata_url = resource_metadata_url + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + auth_user = scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + await self._send_auth_error( + send, status_code=401, error="invalid_token", description="Authentication required" + ) + return + + auth_credentials = scope.get("auth") + + for required_scope in self.required_scopes: + # auth_credentials should always be provided; this is just paranoia + if auth_credentials is None or required_scope not in auth_credentials.scopes: + await self._send_auth_error( + send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}" + ) + return + + await self.app(scope, receive, send) + + async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None: + """Send an authentication error response with WWW-Authenticate header.""" + # Build WWW-Authenticate header value + www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] + if self.resource_metadata_url: + www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') + + www_authenticate = f"Bearer {', '.join(www_auth_parts)}" + + # Send response + body = {"error": error, "error_description": description} + body_bytes = json.dumps(body).encode() + + await send( + { + "type": "http.response.start", + "status": status_code, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body_bytes)).encode()), + (b"www-authenticate", www_authenticate.encode()), + ], + } + ) + + await send( + { + "type": "http.response.body", + "body": body_bytes, + } + ) diff --git 
class AuthenticationError(Exception):
    """Raised when a client fails authentication at the token/revocation endpoints."""

    def __init__(self, message: str):
        self.message = message


class ClientAuthenticator:
    """
    ClientAuthenticator is a callable which validates requests from a client
    application, used to verify /token calls.
    If, during registration, the client requested to be issued a secret, the
    authenticator asserts that /token calls must be authenticated with
    that same token.
    NOTE: clients can opt for no authentication during registration, in which case this
    logic is skipped.
    """

    def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
        """
        Initialize the dependency.

        Args:
            provider: Provider to look up client information
        """
        self.provider = provider

    async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
        """
        Look up the client and, if it was registered with a secret, verify it.

        Args:
            client_id: The client identifier presented by the caller.
            client_secret: The secret presented by the caller, or None for
                public clients that registered without one.

        Returns:
            The registered client information.

        Raises:
            AuthenticationError: If the client is unknown, the secret is
                missing or wrong, or the secret has expired.
        """
        import hmac

        # Look up client information
        client = await self.provider.get_client(client_id)
        if not client:
            raise AuthenticationError("Invalid client_id")

        # If client from the store expects a secret, validate that the request provides
        # that secret
        if client.client_secret:
            if not client_secret:
                raise AuthenticationError("Client secret is required")

            # Compare in constant time: the caller controls client_secret, and a
            # plain `!=` short-circuits on the first differing byte, leaking
            # secret prefixes through response timing.
            if not hmac.compare_digest(client.client_secret.encode(), client_secret.encode()):
                raise AuthenticationError("Invalid client_secret")

            if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
                raise AuthenticationError("Client secret has expired")

        return client
# Error codes a provider may raise during dynamic client registration
# (RFC 7591 section 3.2.2).
RegistrationErrorCode = Literal[
    "invalid_redirect_uri",
    "invalid_client_metadata",
    "invalid_software_statement",
    "unapproved_software_statement",
]


@dataclass(frozen=True)
class RegistrationError(Exception):
    # Raised by OAuthAuthorizationServerProvider.register_client; the handler
    # maps it to a 400 response.
    error: RegistrationErrorCode
    error_description: str | None = None


# Error codes for the authorization endpoint (RFC 6749 section 4.1.2.1).
AuthorizationErrorCode = Literal[
    "invalid_request",
    "unauthorized_client",
    "access_denied",
    "unsupported_response_type",
    "invalid_scope",
    "server_error",
    "temporarily_unavailable",
]


@dataclass(frozen=True)
class AuthorizeError(Exception):
    # Raised by OAuthAuthorizationServerProvider.authorize; the handler maps it
    # to a redirect or error response.
    error: AuthorizationErrorCode
    error_description: str | None = None


# Error codes for the token endpoint (RFC 6749 section 5.2).
TokenErrorCode = Literal[
    "invalid_request",
    "invalid_client",
    "invalid_grant",
    "unauthorized_client",
    "unsupported_grant_type",
    "invalid_scope",
]


@dataclass(frozen=True)
class TokenError(Exception):
    # Raised by the provider's token-exchange methods; the handler maps it to a
    # 400 response.
    error: TokenErrorCode
    error_description: str | None = None


class TokenVerifier(Protocol):
    """Protocol for verifying bearer tokens."""

    async def verify_token(self, token: str) -> AccessToken | None:
        """Verify a bearer token and return access info if valid."""


# NOTE: FastMCP doesn't render any of these types in the user response, so it's
# OK to add fields to subclasses which should not be exposed externally.
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
+ + Args: + client_id: The ID of the client to retrieve. + + Returns: + The client information, or None if the client does not exist. + """ + ... + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + """ + Saves client information as part of registering it. + + Implementors MAY raise NotImplementedError if dynamic client registration is + disabled in ClientRegistrationOptions. + + Args: + client_info: The client metadata to register. + + Raises: + RegistrationError: If the client metadata is invalid. + """ + ... + + async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: + """ + Called as part of the /authorize endpoint, and returns a URL that the client + will be redirected to. + Many MCP implementations will redirect to a third-party provider to perform + a second OAuth exchange with that provider. In this sort of setup, the client + has an OAuth connection with the MCP server, and the MCP server has an OAuth + connection with the 3rd-party provider. At the end of this flow, the client + should be redirected to the redirect_uri from params.redirect_uri. + + +--------+ +------------+ +-------------------+ + | | | | | | + | Client | --> | MCP Server | --> | 3rd Party OAuth | + | | | | | Server | + +--------+ +------------+ +-------------------+ + | ^ | + +------------+ | | | + | | | | Redirect | + |redirect_uri|<-----+ +------------------+ + | | + +------------+ + + Implementations will need to define another handler on the MCP server return + flow to perform the second redirect, and generate and store an authorization + code as part of completing the OAuth authorization step. + + Implementations SHOULD generate an authorization code with at least 160 bits of + entropy, + and MUST generate an authorization code with at least 128 bits of entropy. + See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. + + Args: + client: The client requesting authorization. 
+ params: The parameters of the authorization request. + + Returns: + A URL to redirect the client to for authorization. + + Raises: + AuthorizeError: If the authorization request is invalid. + """ + ... + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCodeT | None: + """ + Loads an AuthorizationCode by its code. + + Args: + client: The client that requested the authorization code. + authorization_code: The authorization code to get the challenge for. + + Returns: + The AuthorizationCode, or None if not found + """ + ... + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT + ) -> OAuthToken: + """ + Exchanges an authorization code for an access token and refresh token. + + Args: + client: The client exchanging the authorization code. + authorization_code: The authorization code to exchange. + + Returns: + The OAuth token, containing access and refresh tokens. + + Raises: + TokenError: If the request is invalid + """ + ... + + async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None: + """ + Loads a RefreshToken by its token string. + + Args: + client: The client that is requesting to load the refresh token. + refresh_token: The refresh token string to load. + + Returns: + The RefreshToken object if found, or None if not found. + """ + + ... + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshTokenT, + scopes: list[str], + ) -> OAuthToken: + """ + Exchanges a refresh token for an access token and refresh token. + + Implementations SHOULD rotate both the access token and refresh token. + + Args: + client: The client exchanging the refresh token. + refresh_token: The refresh token to exchange. + scopes: Optional scopes to request with the new access token. 
+ + Returns: + The OAuth token, containing access and refresh tokens. + + Raises: + TokenError: If the request is invalid + """ + ... + + async def load_access_token(self, token: str) -> AccessTokenT | None: + """ + Loads an access token by its token. + + Args: + token: The access token to verify. + + Returns: + The AuthInfo, or None if the token is invalid. + """ + ... + + async def revoke_token( + self, + token: AccessTokenT | RefreshTokenT, + ) -> None: + """ + Revokes an access or refresh token. + + If the given token is invalid or already revoked, this method should do nothing. + + Implementations SHOULD revoke both the access token and its corresponding + refresh token, regardless of which of the access token or refresh token is + provided. + + Args: + token: the token to revoke + """ + ... + + +def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: + parsed_uri = urlparse(redirect_uri_base) + query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs] + for k, v in params.items(): + if v is not None: + query_params.append((k, v)) + + redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params))) + return redirect_uri + + +class ProviderTokenVerifier(TokenVerifier): + """Token verifier that uses an OAuthAuthorizationServerProvider. + + This is provided for backwards compatibility with existing auth_server_provider + configurations. For new implementations using AS/RS separation, consider using + the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier. 
+ """ + + def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"): + self.provider = provider + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify token using the provider's load_access_token method.""" + return await self.provider.load_access_token(token) diff --git a/src/mcp/mcp/server/auth/routes.py b/src/mcp/mcp/server/auth/routes.py new file mode 100644 index 000000000..bce32df52 --- /dev/null +++ b/src/mcp/mcp/server/auth/routes.py @@ -0,0 +1,227 @@ +from collections.abc import Awaitable, Callable +from typing import Any + +from pydantic import AnyHttpUrl +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Route, request_response # type: ignore +from starlette.types import ASGIApp + +from mcp.server.auth.handlers.authorize import AuthorizationHandler +from mcp.server.auth.handlers.metadata import MetadataHandler +from mcp.server.auth.handlers.register import RegistrationHandler +from mcp.server.auth.handlers.revoke import RevocationHandler +from mcp.server.auth.handlers.token import TokenHandler +from mcp.server.auth.middleware.client_auth import ClientAuthenticator +from mcp.server.auth.provider import OAuthAuthorizationServerProvider +from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions +from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER +from mcp.shared.auth import OAuthMetadata + + +def validate_issuer_url(url: AnyHttpUrl): + """ + Validate that the issuer URL meets OAuth 2.0 requirements. 
+ + Args: + url: The issuer URL to validate + + Raises: + ValueError: If the issuer URL is invalid + """ + + # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing + if ( + url.scheme != "https" + and url.host != "localhost" + and (url.host is not None and not url.host.startswith("127.0.0.1")) + ): + raise ValueError("Issuer URL must be HTTPS") + + # No fragments or query parameters allowed + if url.fragment: + raise ValueError("Issuer URL must not have a fragment") + if url.query: + raise ValueError("Issuer URL must not have a query string") + + +AUTHORIZATION_PATH = "/authorize" +TOKEN_PATH = "/token" +REGISTRATION_PATH = "/register" +REVOCATION_PATH = "/revoke" + + +def cors_middleware( + handler: Callable[[Request], Response | Awaitable[Response]], + allow_methods: list[str], +) -> ASGIApp: + cors_app = CORSMiddleware( + app=request_response(handler), + allow_origins="*", + allow_methods=allow_methods, + allow_headers=[MCP_PROTOCOL_VERSION_HEADER], + ) + return cors_app + + +def create_auth_routes( + provider: OAuthAuthorizationServerProvider[Any, Any, Any], + issuer_url: AnyHttpUrl, + service_documentation_url: AnyHttpUrl | None = None, + client_registration_options: ClientRegistrationOptions | None = None, + revocation_options: RevocationOptions | None = None, +) -> list[Route]: + validate_issuer_url(issuer_url) + + client_registration_options = client_registration_options or ClientRegistrationOptions() + revocation_options = revocation_options or RevocationOptions() + metadata = build_metadata( + issuer_url, + service_documentation_url, + client_registration_options, + revocation_options, + ) + client_authenticator = ClientAuthenticator(provider) + + # Create routes + # Allow CORS requests for endpoints meant to be hit by the OAuth client + # (with the client secret). This is intended to support things like MCP Inspector, + # where the client runs in a web browser. 
+ routes = [ + Route( + "/.well-known/oauth-authorization-server", + endpoint=cors_middleware( + MetadataHandler(metadata).handle, + ["GET", "OPTIONS"], + ), + methods=["GET", "OPTIONS"], + ), + Route( + AUTHORIZATION_PATH, + # do not allow CORS for authorization endpoint; + # clients should just redirect to this + endpoint=AuthorizationHandler(provider).handle, + methods=["GET", "POST"], + ), + Route( + TOKEN_PATH, + endpoint=cors_middleware( + TokenHandler(provider, client_authenticator).handle, + ["POST", "OPTIONS"], + ), + methods=["POST", "OPTIONS"], + ), + ] + + if client_registration_options.enabled: + registration_handler = RegistrationHandler( + provider, + options=client_registration_options, + ) + routes.append( + Route( + REGISTRATION_PATH, + endpoint=cors_middleware( + registration_handler.handle, + ["POST", "OPTIONS"], + ), + methods=["POST", "OPTIONS"], + ) + ) + + if revocation_options.enabled: + revocation_handler = RevocationHandler(provider, client_authenticator) + routes.append( + Route( + REVOCATION_PATH, + endpoint=cors_middleware( + revocation_handler.handle, + ["POST", "OPTIONS"], + ), + methods=["POST", "OPTIONS"], + ) + ) + + return routes + + +def build_metadata( + issuer_url: AnyHttpUrl, + service_documentation_url: AnyHttpUrl | None, + client_registration_options: ClientRegistrationOptions, + revocation_options: RevocationOptions, +) -> OAuthMetadata: + authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH) + token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH) + + # Create metadata + metadata = OAuthMetadata( + issuer=issuer_url, + authorization_endpoint=authorization_url, + token_endpoint=token_url, + scopes_supported=client_registration_options.valid_scopes, + response_types_supported=["code"], + response_modes_supported=None, + grant_types_supported=["authorization_code", "refresh_token"], + token_endpoint_auth_methods_supported=["client_secret_post"], + 
token_endpoint_auth_signing_alg_values_supported=None, + service_documentation=service_documentation_url, + ui_locales_supported=None, + op_policy_uri=None, + op_tos_uri=None, + introspection_endpoint=None, + code_challenge_methods_supported=["S256"], + ) + + # Add registration endpoint if supported + if client_registration_options.enabled: + metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH) + + # Add revocation endpoint if supported + if revocation_options.enabled: + metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) + metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] + + return metadata + + +def create_protected_resource_routes( + resource_url: AnyHttpUrl, + authorization_servers: list[AnyHttpUrl], + scopes_supported: list[str] | None = None, + resource_name: str | None = None, + resource_documentation: AnyHttpUrl | None = None, +) -> list[Route]: + """ + Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728). 
+ + Args: + resource_url: The URL of this resource server + authorization_servers: List of authorization servers that can issue tokens + scopes_supported: Optional list of scopes supported by this resource + + Returns: + List of Starlette routes for protected resource metadata + """ + from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler + from mcp.shared.auth import ProtectedResourceMetadata + + metadata = ProtectedResourceMetadata( + resource=resource_url, + authorization_servers=authorization_servers, + scopes_supported=scopes_supported, + resource_name=resource_name, + resource_documentation=resource_documentation, + # bearer_methods_supported defaults to ["header"] in the model + ) + + handler = ProtectedResourceMetadataHandler(metadata) + + return [ + Route( + "/.well-known/oauth-protected-resource", + endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]), + methods=["GET", "OPTIONS"], + ) + ] diff --git a/src/mcp/mcp/server/auth/settings.py b/src/mcp/mcp/server/auth/settings.py new file mode 100644 index 000000000..1649826db --- /dev/null +++ b/src/mcp/mcp/server/auth/settings.py @@ -0,0 +1,30 @@ +from pydantic import AnyHttpUrl, BaseModel, Field + + +class ClientRegistrationOptions(BaseModel): + enabled: bool = False + client_secret_expiry_seconds: int | None = None + valid_scopes: list[str] | None = None + default_scopes: list[str] | None = None + + +class RevocationOptions(BaseModel): + enabled: bool = False + + +class AuthSettings(BaseModel): + issuer_url: AnyHttpUrl = Field( + ..., + description="OAuth authorization server URL that issues tokens for this resource server.", + ) + service_documentation_url: AnyHttpUrl | None = None + client_registration_options: ClientRegistrationOptions | None = None + revocation_options: RevocationOptions | None = None + required_scopes: list[str] | None = None + + # Resource Server settings (when operating as RS only) + resource_server_url: AnyHttpUrl | None = Field( + ..., + 
# --- file: src/mcp/mcp/server/elicitation.py ---
"""Elicitation utilities for MCP servers."""

from __future__ import annotations

import types
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo

from mcp.server.session import ServerSession
from mcp.types import RequestId

ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)


class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
    """Result when user accepts the elicitation."""

    action: Literal["accept"] = "accept"
    data: ElicitSchemaModelT


class DeclinedElicitation(BaseModel):
    """Result when user declines the elicitation."""

    action: Literal["decline"] = "decline"


class CancelledElicitation(BaseModel):
    """Result when user cancels the elicitation."""

    action: Literal["cancel"] = "cancel"


# Discriminated union over the three possible outcomes; ``action`` is the tag.
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation


# Primitive types allowed in elicitation schemas
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)


def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
    """Validate that a Pydantic model only contains primitive field types.

    Raises:
        TypeError: If any field is not a primitive (or Optional primitive).
    """
    for field_name, field_info in schema.model_fields.items():
        if not _is_primitive_field(field_info):
            raise TypeError(
                f"Elicitation schema field '{field_name}' must be a primitive type "
                f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
                f"Complex types like lists, dicts, or nested models are not allowed."
            )


def _is_primitive_field(field_info: FieldInfo) -> bool:
    """Check if a field is a primitive type allowed in elicitation schemas."""
    annotation = field_info.annotation

    # Handle None type
    if annotation is types.NoneType:
        return True

    # Handle basic primitive types
    if annotation in _ELICITATION_PRIMITIVE_TYPES:
        return True

    # Handle Union types — both typing.Union and the X | Y (types.UnionType) form.
    origin = get_origin(annotation)
    if origin is Union or origin is types.UnionType:
        args = get_args(annotation)
        # All args must be primitive types or None
        return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)

    return False


async def elicit_with_validation(
    session: ServerSession,
    message: str,
    schema: type[ElicitSchemaModelT],
    related_request_id: RequestId | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
    """Elicit information from the client/user with schema validation.

    This method can be used to interactively ask for additional information from the
    client within a tool's execution. The client might display the message to the
    user and collect a response according to the provided schema. Or in case a
    client is an agent, it might decide how to handle the elicitation -- either by asking
    the user or automatically generating a response.

    Raises:
        TypeError: If ``schema`` contains non-primitive fields.
        ValueError: If the client returns an unknown action.
    """
    # Validate that schema only contains primitive types and fail loudly if not
    _validate_elicitation_schema(schema)

    json_schema = schema.model_json_schema()

    result = await session.elicit(
        message=message,
        requestedSchema=json_schema,
        related_request_id=related_request_id,
    )

    if result.action == "accept" and result.content:
        # Validate and parse the content using the schema
        validated_data = schema.model_validate(result.content)
        return AcceptedElicitation(data=validated_data)
    elif result.action == "decline":
        return DeclinedElicitation()
    elif result.action == "cancel":
        return CancelledElicitation()
    else:
        # This should never happen, but handle it just in case
        raise ValueError(f"Unexpected elicitation action: {result.action}")


# --- file: src/mcp/mcp/server/fastmcp/__init__.py ---
"""FastMCP - A more ergonomic interface for MCP servers."""

from importlib.metadata import version

from .server import Context, FastMCP
from .utilities.types import Audio, Image

# Version is taken from the installed ``mcp`` distribution metadata.
__version__ = version("mcp")
__all__ = ["FastMCP", "Context", "Image", "Audio"]


# --- file: src/mcp/mcp/server/fastmcp/exceptions.py ---
"""Custom exceptions for FastMCP."""


class FastMCPError(Exception):
    """Base error for FastMCP."""


class ValidationError(FastMCPError):
    """Error in validating parameters or return values."""


class ResourceError(FastMCPError):
    """Error in resource operations."""


class ToolError(FastMCPError):
    """Error in tool operations."""


class InvalidSignature(Exception):
    """Invalid signature for use with FastMCP."""
# --- file: src/mcp/mcp/server/fastmcp/prompts/__init__.py ---
from .base import Prompt
from .manager import PromptManager

__all__ = ["Prompt", "PromptManager"]


# --- file: src/mcp/mcp/server/fastmcp/prompts/base.py ---
"""Base classes for FastMCP prompts."""

import inspect
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, Literal

import pydantic_core
from pydantic import BaseModel, Field, TypeAdapter, validate_call

from mcp.types import ContentBlock, TextContent


class Message(BaseModel):
    """Base class for all prompt messages."""

    role: Literal["user", "assistant"]
    content: ContentBlock

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        # Convenience: a bare string becomes a TextContent block.
        if isinstance(content, str):
            content = TextContent(type="text", text=content)
        super().__init__(content=content, **kwargs)


class UserMessage(Message):
    """A message from the user."""

    role: Literal["user", "assistant"] = "user"

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        super().__init__(content=content, **kwargs)


class AssistantMessage(Message):
    """A message from the assistant."""

    role: Literal["user", "assistant"] = "assistant"

    def __init__(self, content: str | ContentBlock, **kwargs: Any):
        super().__init__(content=content, **kwargs)


# Validates a dict into either message subclass based on its "role" field.
message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage)

SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]]
PromptResult = SyncPromptResult | Awaitable[SyncPromptResult]


class PromptArgument(BaseModel):
    """An argument that can be passed to a prompt."""

    name: str = Field(description="Name of the argument")
    description: str | None = Field(None, description="Description of what the argument does")
    required: bool = Field(default=False, description="Whether the argument is required")


class Prompt(BaseModel):
    """A prompt template that can be rendered with parameters."""

    name: str = Field(description="Name of the prompt")
    title: str | None = Field(None, description="Human-readable title of the prompt")
    description: str | None = Field(None, description="Description of what the prompt does")
    arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt")
    fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True)

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., PromptResult | Awaitable[PromptResult]],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
    ) -> "Prompt":
        """Create a Prompt from a function.

        The function can return:
        - A string (converted to a message)
        - A Message object
        - A dict (converted to a message)
        - A sequence of any of the above

        Raises:
            ValueError: If ``fn`` is a lambda and no ``name`` is given.
        """
        func_name = name or fn.__name__

        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        # Get schema from TypeAdapter - will fail if function isn't properly typed
        parameters = TypeAdapter(fn).json_schema()

        # Convert parameters to PromptArguments
        arguments: list[PromptArgument] = []
        if "properties" in parameters:
            for param_name, param in parameters["properties"].items():
                required = param_name in parameters.get("required", [])
                arguments.append(
                    PromptArgument(
                        name=param_name,
                        description=param.get("description"),
                        required=required,
                    )
                )

        # ensure the arguments are properly cast
        fn = validate_call(fn)

        return cls(
            name=func_name,
            title=title,
            description=description or fn.__doc__ or "",
            arguments=arguments,
            fn=fn,
        )

    async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
        """Render the prompt with arguments.

        Raises:
            ValueError: If required arguments are missing, the function fails,
                or a result item cannot be converted to a message.
        """
        # Validate required arguments
        if self.arguments:
            required = {arg.name for arg in self.arguments if arg.required}
            provided = set(arguments or {})
            missing = required - provided
            if missing:
                raise ValueError(f"Missing required arguments: {missing}")

        try:
            # Call function and check if result is a coroutine
            result = self.fn(**(arguments or {}))
            if inspect.iscoroutine(result):
                result = await result

            # Validate messages — normalize a single return value to a list.
            if not isinstance(result, list | tuple):
                result = [result]

            # Convert result to messages
            messages: list[Message] = []
            for msg in result:  # type: ignore[reportUnknownVariableType]
                try:
                    if isinstance(msg, Message):
                        messages.append(msg)
                    elif isinstance(msg, dict):
                        messages.append(message_validator.validate_python(msg))
                    elif isinstance(msg, str):
                        content = TextContent(type="text", text=msg)
                        messages.append(UserMessage(content=content))
                    else:
                        # Fallback: serialize arbitrary objects to JSON text.
                        content = pydantic_core.to_json(msg, fallback=str, indent=2).decode()
                        messages.append(Message(role="user", content=content))
                except Exception:
                    raise ValueError(f"Could not convert prompt result to message: {msg}")

            return messages
        except Exception as e:
            raise ValueError(f"Error rendering prompt {self.name}: {e}")


# --- file: src/mcp/mcp/server/fastmcp/prompts/manager.py ---
"""Prompt management functionality."""

from typing import Any

from mcp.server.fastmcp.prompts.base import Message, Prompt
from mcp.server.fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class PromptManager:
    """Manages FastMCP prompts."""

    def __init__(self, warn_on_duplicate_prompts: bool = True):
        self._prompts: dict[str, Prompt] = {}
        self.warn_on_duplicate_prompts = warn_on_duplicate_prompts

    def get_prompt(self, name: str) -> Prompt | None:
        """Get prompt by name."""
        return self._prompts.get(name)

    def list_prompts(self) -> list[Prompt]:
        """List all registered prompts."""
        return list(self._prompts.values())

    def add_prompt(
        self,
        prompt: Prompt,
    ) -> Prompt:
        """Add a prompt to the manager.

        If a prompt with the same name exists, the existing prompt is kept
        and returned (optionally with a warning); it is NOT replaced.
        """

        # Check for duplicates
        existing = self._prompts.get(prompt.name)
        if existing:
            if self.warn_on_duplicate_prompts:
                logger.warning(f"Prompt already exists: {prompt.name}")
            return existing

        self._prompts[prompt.name] = prompt
        return prompt

    async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]:
        """Render a prompt by name with arguments.

        Raises:
            ValueError: If no prompt is registered under ``name``.
        """
        prompt = self.get_prompt(name)
        if not prompt:
            raise ValueError(f"Unknown prompt: {name}")

        return await prompt.render(arguments)
# --- file: src/mcp/mcp/server/fastmcp/resources/__init__.py ---
from .base import Resource
from .resource_manager import ResourceManager
from .templates import ResourceTemplate
from .types import (
    BinaryResource,
    DirectoryResource,
    FileResource,
    FunctionResource,
    HttpResource,
    TextResource,
)

__all__ = [
    "Resource",
    "TextResource",
    "BinaryResource",
    "FunctionResource",
    "FileResource",
    "HttpResource",
    "DirectoryResource",
    "ResourceTemplate",
    "ResourceManager",
]


# --- file: src/mcp/mcp/server/fastmcp/resources/base.py ---
"""Base classes and interfaces for FastMCP resources."""

import abc
from typing import Annotated

from pydantic import (
    AnyUrl,
    BaseModel,
    ConfigDict,
    Field,
    UrlConstraints,
    ValidationInfo,
    field_validator,
)


class Resource(BaseModel, abc.ABC):
    """Base class for all resources."""

    model_config = ConfigDict(validate_default=True)

    # host_required=False so scheme-only URIs like ``resource://...`` validate.
    uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource")
    name: str | None = Field(description="Name of the resource", default=None)
    title: str | None = Field(description="Human-readable title of the resource", default=None)
    description: str | None = Field(description="Description of the resource", default=None)
    mime_type: str = Field(
        default="text/plain",
        description="MIME type of the resource content",
        pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$",
    )

    @field_validator("name", mode="before")
    @classmethod
    def set_default_name(cls, name: str | None, info: ValidationInfo) -> str:
        """Set default name from URI if not provided."""
        if name:
            return name
        # ``uri`` is validated before ``name`` (field order), so it is
        # available in info.data here.
        if uri := info.data.get("uri"):
            return str(uri)
        raise ValueError("Either name or uri must be provided")

    @abc.abstractmethod
    async def read(self) -> str | bytes:
        """Read the resource content."""
        pass


# --- file: src/mcp/mcp/server/fastmcp/resources/resource_manager.py ---
"""Resource manager functionality."""

from collections.abc import Callable
from typing import Any

from pydantic import AnyUrl

from mcp.server.fastmcp.resources.base import Resource
from mcp.server.fastmcp.resources.templates import ResourceTemplate
from mcp.server.fastmcp.utilities.logging import get_logger

logger = get_logger(__name__)


class ResourceManager:
    """Manages FastMCP resources."""

    def __init__(self, warn_on_duplicate_resources: bool = True):
        self._resources: dict[str, Resource] = {}
        self._templates: dict[str, ResourceTemplate] = {}
        self.warn_on_duplicate_resources = warn_on_duplicate_resources

    def add_resource(self, resource: Resource) -> Resource:
        """Add a resource to the manager.

        Args:
            resource: A Resource instance to add

        Returns:
            The added resource. If a resource with the same URI already exists,
            returns the existing resource.
        """
        logger.debug(
            "Adding resource",
            extra={
                "uri": resource.uri,
                "type": type(resource).__name__,
                "resource_name": resource.name,
            },
        )
        existing = self._resources.get(str(resource.uri))
        if existing:
            if self.warn_on_duplicate_resources:
                logger.warning(f"Resource already exists: {resource.uri}")
            return existing
        self._resources[str(resource.uri)] = resource
        return resource

    def add_template(
        self,
        fn: Callable[..., Any],
        uri_template: str,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
    ) -> ResourceTemplate:
        """Add a template from a function."""
        template = ResourceTemplate.from_function(
            fn,
            uri_template=uri_template,
            name=name,
            title=title,
            description=description,
            mime_type=mime_type,
        )
        self._templates[template.uri_template] = template
        return template

    async def get_resource(self, uri: AnyUrl | str) -> Resource:
        """Get resource by URI, checking concrete resources first, then templates.

        Raises:
            ValueError: If no resource or template matches ``uri``, or a
                matching template fails to create the resource.
                (Annotation corrected: this never returns None.)
        """
        uri_str = str(uri)
        logger.debug("Getting resource", extra={"uri": uri_str})

        # First check concrete resources
        if resource := self._resources.get(uri_str):
            return resource

        # Then check templates
        for template in self._templates.values():
            if params := template.matches(uri_str):
                try:
                    return await template.create_resource(uri_str, params)
                except Exception as e:
                    raise ValueError(f"Error creating resource from template: {e}")

        raise ValueError(f"Unknown resource: {uri}")

    def list_resources(self) -> list[Resource]:
        """List all registered resources."""
        logger.debug("Listing resources", extra={"count": len(self._resources)})
        return list(self._resources.values())

    def list_templates(self) -> list[ResourceTemplate]:
        """List all registered templates."""
        logger.debug("Listing templates", extra={"count": len(self._templates)})
        return list(self._templates.values())
resources.""" + logger.debug("Listing resources", extra={"count": len(self._resources)}) + return list(self._resources.values()) + + def list_templates(self) -> list[ResourceTemplate]: + """List all registered templates.""" + logger.debug("Listing templates", extra={"count": len(self._templates)}) + return list(self._templates.values()) diff --git a/src/mcp/mcp/server/fastmcp/resources/templates.py b/src/mcp/mcp/server/fastmcp/resources/templates.py new file mode 100644 index 000000000..b1c7b2711 --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/resources/templates.py @@ -0,0 +1,83 @@ +"""Resource template functionality.""" + +from __future__ import annotations + +import inspect +import re +from collections.abc import Callable +from typing import Any + +from pydantic import BaseModel, Field, TypeAdapter, validate_call + +from mcp.server.fastmcp.resources.types import FunctionResource, Resource + + +class ResourceTemplate(BaseModel): + """A template for dynamically creating resources.""" + + uri_template: str = Field(description="URI template with parameters (e.g. 
weather://{city}/current)") + name: str = Field(description="Name of the resource") + title: str | None = Field(description="Human-readable title of the resource", default=None) + description: str | None = Field(description="Description of what the resource does") + mime_type: str = Field(default="text/plain", description="MIME type of the resource content") + fn: Callable[..., Any] = Field(exclude=True) + parameters: dict[str, Any] = Field(description="JSON schema for function parameters") + + @classmethod + def from_function( + cls, + fn: Callable[..., Any], + uri_template: str, + name: str | None = None, + title: str | None = None, + description: str | None = None, + mime_type: str | None = None, + ) -> ResourceTemplate: + """Create a template from a function.""" + func_name = name or fn.__name__ + if func_name == "": + raise ValueError("You must provide a name for lambda functions") + + # Get schema from TypeAdapter - will fail if function isn't properly typed + parameters = TypeAdapter(fn).json_schema() + + # ensure the arguments are properly cast + fn = validate_call(fn) + + return cls( + uri_template=uri_template, + name=func_name, + title=title, + description=description or fn.__doc__ or "", + mime_type=mime_type or "text/plain", + fn=fn, + parameters=parameters, + ) + + def matches(self, uri: str) -> dict[str, Any] | None: + """Check if URI matches template and extract parameters.""" + # Convert template to regex pattern + pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)") + match = re.match(f"^{pattern}$", uri) + if match: + return match.groupdict() + return None + + async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource: + """Create a resource from the template with the given parameters.""" + try: + # Call function and check if result is a coroutine + result = self.fn(**params) + if inspect.iscoroutine(result): + result = await result + + return FunctionResource( + uri=uri, # type: ignore + name=self.name, + 
class TextResource(Resource):
    """A resource that reads from a string."""

    text: str = Field(description="Text content of the resource")

    async def read(self) -> str:
        """Read the text content."""
        return self.text


class BinaryResource(Resource):
    """A resource that reads from bytes."""

    data: bytes = Field(description="Binary content of the resource")

    async def read(self) -> bytes:
        """Read the binary content."""
        return self.data


class FunctionResource(Resource):
    """A resource that defers data loading by wrapping a function.

    The function is only called when the resource is read, allowing for lazy loading
    of potentially expensive data. This is particularly useful when listing resources,
    as the function won't be called until the resource is actually accessed.

    The function can return:
    - str for text content (default)
    - bytes for binary content
    - other types will be converted to JSON
    """

    fn: Callable[[], Any] = Field(exclude=True)

    async def read(self) -> str | bytes:
        """Read the resource by calling the wrapped function.

        Raises:
            ValueError: wrapping any error raised by the wrapped function.
        """
        try:
            # Call the function first to see if it returns a coroutine
            result = self.fn()
            if inspect.iscoroutine(result):
                result = await result

            if isinstance(result, Resource):
                # A nested resource: delegate to its own read().
                return await result.read()
            elif isinstance(result, bytes):
                return result
            elif isinstance(result, str):
                return result
            else:
                # Anything else is serialized to pretty-printed JSON.
                return pydantic_core.to_json(result, fallback=str, indent=2).decode()
        except Exception as e:
            # Chain the original exception for easier debugging.
            raise ValueError(f"Error reading resource {self.uri}: {e}") from e

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        uri: str,
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        mime_type: str | None = None,
    ) -> "FunctionResource":
        """Create a FunctionResource from a function.

        Raises:
            ValueError: if ``fn`` is a lambda and no explicit name is given.
        """
        func_name = name or fn.__name__
        # Fix: a lambda's __name__ is "<lambda>"; the previous comparison
        # against the empty string never matched, so anonymous lambdas
        # slipped through without a usable name.
        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        # ensure the arguments are properly cast
        fn = validate_call(fn)

        return cls(
            uri=AnyUrl(uri),
            name=func_name,
            title=title,
            description=description or fn.__doc__ or "",
            mime_type=mime_type or "text/plain",
            fn=fn,
        )
+ """ + + path: Path = Field(description="Path to the file") + is_binary: bool = Field( + default=False, + description="Whether to read the file as binary data", + ) + mime_type: str = Field( + default="text/plain", + description="MIME type of the resource content", + ) + + @pydantic.field_validator("path") + @classmethod + def validate_absolute_path(cls, path: Path) -> Path: + """Ensure path is absolute.""" + if not path.is_absolute(): + raise ValueError("Path must be absolute") + return path + + @pydantic.field_validator("is_binary") + @classmethod + def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool: + """Set is_binary based on mime_type if not explicitly set.""" + if is_binary: + return True + mime_type = info.data.get("mime_type", "text/plain") + return not mime_type.startswith("text/") + + async def read(self) -> str | bytes: + """Read the file content.""" + try: + if self.is_binary: + return await anyio.to_thread.run_sync(self.path.read_bytes) + return await anyio.to_thread.run_sync(self.path.read_text) + except Exception as e: + raise ValueError(f"Error reading file {self.path}: {e}") + + +class HttpResource(Resource): + """A resource that reads from an HTTP endpoint.""" + + url: str = Field(description="URL to fetch content from") + mime_type: str = Field(default="application/json", description="MIME type of the resource content") + + async def read(self) -> str | bytes: + """Read the HTTP content.""" + async with httpx.AsyncClient() as client: + response = await client.get(self.url) + response.raise_for_status() + return response.text + + +class DirectoryResource(Resource): + """A resource that lists files in a directory.""" + + path: Path = Field(description="Path to the directory") + recursive: bool = Field(default=False, description="Whether to list files recursively") + pattern: str | None = Field(default=None, description="Optional glob pattern to filter files") + mime_type: str = Field(default="application/json", 
description="MIME type of the resource content") + + @pydantic.field_validator("path") + @classmethod + def validate_absolute_path(cls, path: Path) -> Path: + """Ensure path is absolute.""" + if not path.is_absolute(): + raise ValueError("Path must be absolute") + return path + + def list_files(self) -> list[Path]: + """List files in the directory.""" + if not self.path.exists(): + raise FileNotFoundError(f"Directory not found: {self.path}") + if not self.path.is_dir(): + raise NotADirectoryError(f"Not a directory: {self.path}") + + try: + if self.pattern: + return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern)) + return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*")) + except Exception as e: + raise ValueError(f"Error listing directory {self.path}: {e}") + + async def read(self) -> str: # Always returns JSON string + """Read the directory listing.""" + try: + files = await anyio.to_thread.run_sync(self.list_files) + file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()] + return json.dumps({"files": file_list}, indent=2) + except Exception as e: + raise ValueError(f"Error reading directory {self.path}: {e}") diff --git a/src/mcp/mcp/server/fastmcp/secure/__init__.py b/src/mcp/mcp/server/fastmcp/secure/__init__.py new file mode 100644 index 000000000..a20c932b8 --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/secure/__init__.py @@ -0,0 +1,54 @@ +""" +Secure annotations and decorators for MCP tools, resources, and prompts. 
+ +This module provides enhanced security features including: +- Bidirectional authentication (client ↔ tool) +- End-to-end encryption +- Tool attestation and signing +- Rate limiting and audit logging +""" + +from .annotations import ( + AuthMethod, + SecureAnnotations, + SecureToolAnnotations, + SecureResourceAnnotations, + SecurePromptAnnotations, +) +from .tool import SecureTool, secure_tool +from .resource import SecureResource, secure_resource +from .prompt import SecurePrompt, secure_prompt +from .identity import ToolIdentity, ClientIdentity, create_tool_identity +from .session import SecureSession, SessionManager +from .utils import SecureAnnotationProcessor, encrypt_data, decrypt_data + +__all__ = [ + # Annotations + "AuthMethod", + "SecureAnnotations", + "SecureToolAnnotations", + "SecureResourceAnnotations", + "SecurePromptAnnotations", + + # Secure wrappers + "SecureTool", + "SecureResource", + "SecurePrompt", + + # Decorators + "secure_tool", + "secure_resource", + "secure_prompt", + + # Identity & Session + "ToolIdentity", + "ClientIdentity", + "SecureSession", + "SessionManager", + "create_tool_identity", + + # Utils + "SecureAnnotationProcessor", + "encrypt_data", + "decrypt_data", +] \ No newline at end of file diff --git a/src/mcp/mcp/server/fastmcp/secure/annotations.py b/src/mcp/mcp/server/fastmcp/secure/annotations.py new file mode 100644 index 000000000..4c93eb36e --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/secure/annotations.py @@ -0,0 +1,326 @@ +""" +Secure annotations for MCP tools, resources, and prompts. + +These annotations extend the standard MCP annotations with security features. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Optional + +from mcp.types import ToolAnnotations, ResourceAnnotations, PromptAnnotations + + +class AuthMethod(Enum): + """Supported authentication methods.""" + JWT = "jwt" + CERTIFICATE = "certificate" + TEE_ATTESTATION = "tee" + OAUTH = "oauth" + API_KEY = "api_key" + MTLS = "mtls" # Mutual TLS + + +@dataclass +class SecureAnnotations: + """ + Base security annotations that can be attached to tools, resources, or prompts. + + These annotations enable security features like authentication, encryption, + and attestation for MCP operations. + """ + + # Authentication settings + require_auth: bool = False + auth_methods: list[AuthMethod] = field(default_factory=lambda: [AuthMethod.JWT]) + required_permissions: set[str] = field(default_factory=set) + require_mutual_auth: bool = False # Bidirectional authentication + + # Encryption settings + encrypt_input: bool = False + encrypt_output: bool = False + encryption_algorithm: str = "AES-256-GCM" + key_exchange_method: str = "ECDH" # ECDH, RSA, Pre-shared + + # Tool/Server attestation + require_tool_attestation: bool = False + tool_certificate_fingerprint: Optional[str] = None + attestation_type: Optional[str] = None # "software", "sgx", "sev", "trustzone" + tool_signature_required: bool = False + + # Client verification + verify_client_certificate: bool = False + trusted_client_issuers: list[str] = field(default_factory=list) + client_attestation_required: bool = False + + # Audit and compliance + audit_log: bool = True + audit_include_inputs: bool = False + audit_include_outputs: bool = False + audit_retention_days: int = 90 + + # Rate limiting + rate_limit: Optional[int] = None # requests per minute + rate_limit_per_client: bool = True + burst_limit: Optional[int] = None + + # Data handling + security_level: str = "standard" # "standard", "high", "critical" + data_classification: str 
= "public" # "public", "internal", "confidential", "secret" + compliance_tags: list[str] = field(default_factory=list) # ["HIPAA", "PCI-DSS", "GDPR", "SOC2"] + + # Session management + session_timeout_minutes: int = 60 + require_session_binding: bool = False # Bind session to client IP/fingerprint + max_concurrent_sessions: Optional[int] = None + + # Advanced security + require_replay_protection: bool = False + max_request_age_seconds: int = 300 # For replay protection + require_integrity_check: bool = True # Verify message integrity + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "require_auth": self.require_auth, + "auth_methods": [method.value for method in self.auth_methods], + "required_permissions": list(self.required_permissions), + "require_mutual_auth": self.require_mutual_auth, + "encrypt_input": self.encrypt_input, + "encrypt_output": self.encrypt_output, + "encryption_algorithm": self.encryption_algorithm, + "key_exchange_method": self.key_exchange_method, + "require_tool_attestation": self.require_tool_attestation, + "tool_certificate_fingerprint": self.tool_certificate_fingerprint, + "attestation_type": self.attestation_type, + "tool_signature_required": self.tool_signature_required, + "verify_client_certificate": self.verify_client_certificate, + "trusted_client_issuers": self.trusted_client_issuers, + "client_attestation_required": self.client_attestation_required, + "audit_log": self.audit_log, + "audit_include_inputs": self.audit_include_inputs, + "audit_include_outputs": self.audit_include_outputs, + "security_level": self.security_level, + "data_classification": self.data_classification, + "compliance_tags": self.compliance_tags, + "session_timeout_minutes": self.session_timeout_minutes, + "require_session_binding": self.require_session_binding, + "require_replay_protection": self.require_replay_protection, + "require_integrity_check": self.require_integrity_check, + } + + @classmethod + def 
from_dict(cls, data: dict[str, Any]) -> SecureAnnotations: + """Create from dictionary.""" + auth_methods = [AuthMethod(m) for m in data.get("auth_methods", ["jwt"])] + return cls( + require_auth=data.get("require_auth", False), + auth_methods=auth_methods, + required_permissions=set(data.get("required_permissions", [])), + require_mutual_auth=data.get("require_mutual_auth", False), + encrypt_input=data.get("encrypt_input", False), + encrypt_output=data.get("encrypt_output", False), + encryption_algorithm=data.get("encryption_algorithm", "AES-256-GCM"), + key_exchange_method=data.get("key_exchange_method", "ECDH"), + require_tool_attestation=data.get("require_tool_attestation", False), + tool_certificate_fingerprint=data.get("tool_certificate_fingerprint"), + attestation_type=data.get("attestation_type"), + tool_signature_required=data.get("tool_signature_required", False), + verify_client_certificate=data.get("verify_client_certificate", False), + trusted_client_issuers=data.get("trusted_client_issuers", []), + client_attestation_required=data.get("client_attestation_required", False), + audit_log=data.get("audit_log", True), + security_level=data.get("security_level", "standard"), + data_classification=data.get("data_classification", "public"), + compliance_tags=data.get("compliance_tags", []), + ) + + +class SecureToolAnnotations(ToolAnnotations): + """ + Tool annotations with integrated security features. + + This extends the standard ToolAnnotations with security metadata. 
+ """ + + def __init__( + self, + secure: Optional[SecureAnnotations] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure = secure or SecureAnnotations() + + # Store security annotations in extensions + if not hasattr(self, 'extensions'): + self.extensions = {} + self.extensions["security"] = self.secure.to_dict() + + @classmethod + def create( + cls, + # Security parameters + require_auth: bool = False, + auth_methods: Optional[list[AuthMethod]] = None, + required_permissions: Optional[set[str]] = None, + encrypt_io: bool = False, + require_mutual_auth: bool = False, + security_level: str = "standard", + + # Standard tool annotation parameters + audience: Optional[list[str]] = None, + capabilities: Optional[dict[str, Any]] = None, + **kwargs + ) -> SecureToolAnnotations: + """ + Factory method to create secure tool annotations. + + Args: + require_auth: Whether to require authentication + auth_methods: List of accepted authentication methods + required_permissions: Set of required permissions + encrypt_io: Whether to encrypt input/output + require_mutual_auth: Whether to require bidirectional authentication + security_level: Security level (standard/high/critical) + audience: Target audience for the tool + capabilities: Tool capabilities + **kwargs: Additional security parameters + """ + secure_annotations = SecureAnnotations( + require_auth=require_auth, + auth_methods=auth_methods or [AuthMethod.JWT], + required_permissions=required_permissions or set(), + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + require_mutual_auth=require_mutual_auth, + security_level=security_level, + **kwargs + ) + + return cls( + secure=secure_annotations, + audience=audience, + capabilities=capabilities + ) + + +class SecureResourceAnnotations(ResourceAnnotations): + """ + Resource annotations with integrated security features. + + This extends the standard ResourceAnnotations with security metadata. 
+ """ + + def __init__( + self, + secure: Optional[SecureAnnotations] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure = secure or SecureAnnotations() + + # Store security annotations in extensions + if not hasattr(self, 'extensions'): + self.extensions = {} + self.extensions["security"] = self.secure.to_dict() + + @classmethod + def create( + cls, + # Security parameters + require_auth: bool = False, + data_classification: str = "public", + encrypt_io: bool = False, + audit_access: bool = True, + + # Standard resource annotation parameters + content_type: Optional[str] = None, + cache_control: Optional[str] = None, + **kwargs + ) -> SecureResourceAnnotations: + """ + Factory method to create secure resource annotations. + + Args: + require_auth: Whether to require authentication + data_classification: Data classification level + encrypt_io: Whether to encrypt input/output + audit_access: Whether to audit resource access + content_type: Resource content type + cache_control: Cache control headers + **kwargs: Additional security parameters + """ + secure_annotations = SecureAnnotations( + require_auth=require_auth, + data_classification=data_classification, + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + audit_log=audit_access, + **kwargs + ) + + return cls( + secure=secure_annotations, + content_type=content_type, + cache_control=cache_control + ) + + +class SecurePromptAnnotations(PromptAnnotations): + """ + Prompt annotations with integrated security features. + + This extends the standard PromptAnnotations with security metadata. 
+ """ + + def __init__( + self, + secure: Optional[SecureAnnotations] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure = secure or SecureAnnotations() + + # Store security annotations in extensions + if not hasattr(self, 'extensions'): + self.extensions = {} + self.extensions["security"] = self.secure.to_dict() + + @classmethod + def create( + cls, + # Security parameters + require_auth: bool = False, + audit_usage: bool = True, + compliance_tags: Optional[list[str]] = None, + + # Standard prompt annotation parameters + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + **kwargs + ) -> SecurePromptAnnotations: + """ + Factory method to create secure prompt annotations. + + Args: + require_auth: Whether to require authentication + audit_usage: Whether to audit prompt usage + compliance_tags: Compliance tags (e.g., ["GDPR", "HIPAA"]) + max_tokens: Maximum tokens for prompt + temperature: Temperature for prompt generation + **kwargs: Additional security parameters + """ + secure_annotations = SecureAnnotations( + require_auth=require_auth, + audit_log=audit_usage, + compliance_tags=compliance_tags or [], + **kwargs + ) + + return cls( + secure=secure_annotations, + max_tokens=max_tokens, + temperature=temperature + ) diff --git a/src/mcp/mcp/server/fastmcp/secure/identity.py b/src/mcp/mcp/server/fastmcp/secure/identity.py new file mode 100644 index 000000000..3c29b838e --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/secure/identity.py @@ -0,0 +1,390 @@ +""" +Identity management for secure MCP operations. + +Handles both tool identity (server-side) and client identity verification. 
+""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any, Optional + +from cryptography import x509 +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes, PublicKeyTypes +from cryptography.x509.oid import NameOID + +from .annotations import AuthMethod + + +@dataclass +class ToolIdentity: + """ + Represents the cryptographic identity of a tool/server. + + This is used for: + - Tool attestation (proving the tool is legitimate) + - Response signing (ensuring response integrity) + - Mutual authentication (bidirectional auth with client) + """ + + tool_id: str + name: str + version: str + certificate: x509.Certificate + private_key: PrivateKeyTypes + trusted_issuers: list[x509.Certificate] + + # Optional attestation for secure enclaves + attestation_report: Optional[dict] = None + attestation_type: Optional[str] = None # "software", "sgx", "sev", "trustzone" + + # Tool capabilities and metadata + capabilities: set[str] = field(default_factory=set) + metadata: dict[str, Any] = field(default_factory=dict) + + @property + def fingerprint(self) -> str: + """Get SHA256 fingerprint of the tool's certificate.""" + cert_der = self.certificate.public_bytes(serialization.Encoding.DER) + return hashlib.sha256(cert_der).hexdigest() + + @property + def public_key(self) -> PublicKeyTypes: + """Get the public key from the certificate.""" + return self.certificate.public_key() + + def sign_data(self, data: bytes) -> bytes: + """ + Sign data with the tool's private key. 
+ + Args: + data: Data to sign + + Returns: + Digital signature + """ + if isinstance(self.private_key, rsa.RSAPrivateKey): + return self.private_key.sign( + data, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH + ), + hashes.SHA256() + ) + elif isinstance(self.private_key, ec.EllipticCurvePrivateKey): + return self.private_key.sign(data, ec.ECDSA(hashes.SHA256())) + else: + raise ValueError(f"Unsupported key type: {type(self.private_key)}") + + def verify_signature(self, data: bytes, signature: bytes, public_key: PublicKeyTypes) -> bool: + """ + Verify a signature using a public key. + + Args: + data: Original data + signature: Signature to verify + public_key: Public key to verify with + + Returns: + True if signature is valid + """ + try: + if isinstance(public_key, rsa.RSAPublicKey): + public_key.verify( + signature, + data, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH + ), + hashes.SHA256() + ) + elif isinstance(public_key, ec.EllipticCurvePublicKey): + public_key.verify(signature, data, ec.ECDSA(hashes.SHA256())) + else: + return False + return True + except Exception: + return False + + def to_attestation(self) -> dict[str, Any]: + """ + Generate attestation data for the tool. 
+ + Returns: + Dictionary containing tool attestation information + """ + import base64 + + attestation = { + "tool_id": self.tool_id, + "name": self.name, + "version": self.version, + "fingerprint": self.fingerprint, + "certificate": base64.b64encode( + self.certificate.public_bytes(serialization.Encoding.PEM) + ).decode(), + "capabilities": list(self.capabilities), + "timestamp": datetime.utcnow().isoformat(), + } + + # Add hardware attestation if available + if self.attestation_report: + attestation["attestation"] = { + "type": self.attestation_type, + "report": self.attestation_report, + } + + # Sign the attestation + attestation_bytes = json.dumps(attestation, sort_keys=True).encode() + signature = self.sign_data(attestation_bytes) + attestation["signature"] = base64.b64encode(signature).decode() + + return attestation + + def verify_client_signature(self, data: bytes, signature: bytes, client_cert: x509.Certificate) -> bool: + """ + Verify a signature from a client certificate. + + Args: + data: Data that was signed + signature: Client's signature + client_cert: Client's certificate + + Returns: + True if signature is valid + """ + client_public_key = client_cert.public_key() + return self.verify_signature(data, signature, client_public_key) + + +@dataclass +class ClientIdentity: + """ + Represents an authenticated client identity. + + This is created after successful authentication and contains + the client's permissions and metadata. + """ + + client_id: str + authentication_method: AuthMethod + credentials: Any # JWT token, certificate, attestation, etc. 
+ permissions: set[str] + + # Optional fields + session_id: Optional[str] = None + organization: Optional[str] = None + email: Optional[str] = None + + # Metadata + metadata: dict[str, Any] = field(default_factory=dict) + authenticated_at: datetime = field(default_factory=datetime.utcnow) + expires_at: Optional[datetime] = None + + # Certificate-based auth specifics + certificate: Optional[x509.Certificate] = None + certificate_fingerprint: Optional[str] = None + + # Rate limiting and quotas + rate_limit: Optional[int] = None + quota_remaining: Optional[int] = None + + def has_permission(self, permission: str) -> bool: + """ + Check if client has a specific permission. + + Args: + permission: Permission to check (e.g., "tool.execute", "resource.read") + + Returns: + True if client has the permission + """ + # Check exact permission + if permission in self.permissions: + return True + + # Check wildcard permissions + if "*" in self.permissions: + return True + + # Check hierarchical permissions (e.g., "tool.*" matches "tool.execute") + parts = permission.split(".") + for i in range(len(parts)): + wildcard_perm = ".".join(parts[:i+1]) + ".*" + if wildcard_perm in self.permissions: + return True + + return False + + def is_expired(self) -> bool: + """Check if the client identity has expired.""" + if self.expires_at is None: + return False + return datetime.utcnow() > self.expires_at + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "client_id": self.client_id, + "authentication_method": self.authentication_method.value, + "permissions": list(self.permissions), + "organization": self.organization, + "email": self.email, + "authenticated_at": self.authenticated_at.isoformat(), + "expires_at": self.expires_at.isoformat() if self.expires_at else None, + "certificate_fingerprint": self.certificate_fingerprint, + "metadata": self.metadata, + } + + +def create_tool_identity( + tool_name: str, + tool_version: str, + 
def create_tool_identity(
    tool_name: str,
    tool_version: str,
    organization: str = "MCP-Secure",
    country: str = "US",
    validity_days: int = 365,
    key_type: str = "EC"  # "EC" or "RSA"
) -> ToolIdentity:
    """
    Create a tool identity with a self-signed certificate.

    In production, you would use a proper CA-signed certificate.

    Args:
        tool_name: Name of the tool
        tool_version: Version of the tool
        organization: Organization name
        country: Country code
        validity_days: Certificate validity period
        key_type: Key type ("EC" for elliptic curve, "RSA" for RSA)

    Returns:
        ToolIdentity with generated certificate and key
    """
    # Generate key pair (P-256 for EC, 2048-bit for RSA).
    if key_type == "RSA":
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    else:  # EC
        private_key = ec.generate_private_key(ec.SECP256R1())

    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COUNTRY_NAME, country),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, organization),
        x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "MCP-Tools"),
        x509.NameAttribute(NameOID.COMMON_NAME, f"{tool_name}-v{tool_version}"),
    ])

    # Build certificate.
    # NOTE(review): datetime.utcnow() is naive and deprecated in 3.12+;
    # confirm whether the cryptography version in use expects aware datetimes.
    builder = x509.CertificateBuilder()
    builder = builder.subject_name(subject)
    builder = builder.issuer_name(issuer)
    builder = builder.public_key(private_key.public_key())
    builder = builder.serial_number(x509.random_serial_number())
    builder = builder.not_valid_before(datetime.utcnow())
    builder = builder.not_valid_after(datetime.utcnow() + timedelta(days=validity_days))

    # SANs for local development use.
    builder = builder.add_extension(
        x509.SubjectAlternativeName([
            x509.DNSName(f"{tool_name}.local"),
            x509.DNSName("localhost"),
        ]),
        critical=False,
    )

    builder = builder.add_extension(
        x509.KeyUsage(
            digital_signature=True,
            key_encipherment=True,
            content_commitment=True,
            data_encipherment=False,
            key_agreement=True,
            key_cert_sign=False,
            crl_sign=False,
            encipher_only=False,
            decipher_only=False,
        ),
        critical=True,
    )

    builder = builder.add_extension(
        x509.ExtendedKeyUsage([
            x509.oid.ExtendedKeyUsageOID.SERVER_AUTH,
            x509.oid.ExtendedKeyUsageOID.CLIENT_AUTH,
        ]),
        critical=True,
    )

    # Self-sign the certificate
    certificate = builder.sign(private_key, hashes.SHA256())

    return ToolIdentity(
        tool_id=f"{tool_name}-{tool_version}",
        name=tool_name,
        version=tool_version,
        certificate=certificate,
        private_key=private_key,
        trusted_issuers=[certificate],  # Self-signed
        capabilities={
            "authentication.mutual",
            "encryption.aes256",
            "signing.sha256",
        },
        metadata={
            "created_at": datetime.utcnow().isoformat(),
            "key_type": key_type,
            "organization": organization,
        }
    )


def verify_tool_certificate(
    certificate: x509.Certificate,
    trusted_cas: list[x509.Certificate],
    check_revocation: bool = True
) -> tuple[bool, Optional[str]]:
    """
    Verify a tool's certificate against trusted CAs.

    Args:
        certificate: Certificate to verify
        trusted_cas: List of trusted CA certificates
        check_revocation: Whether to check certificate revocation

    Returns:
        Tuple of (is_valid, error_message)
    """
    # Check certificate validity period
    now = datetime.utcnow()
    if now < certificate.not_valid_before:
        return False, "Certificate not yet valid"
    if now > certificate.not_valid_after:
        return False, "Certificate has expired"

    # Fix: the previous code called public_key().verify(sig, tbs,
    # signature_algorithm_oid._name), passing a string where the
    # cryptography API requires padding/algorithm objects (and touching a
    # private attribute). That always raised, so every certificate was
    # reported as untrusted. Verify per key type instead.
    hash_alg = certificate.signature_hash_algorithm
    for ca in trusted_cas:
        ca_key = ca.public_key()
        try:
            if isinstance(ca_key, rsa.RSAPublicKey):
                # X.509 RSA signatures use PKCS#1 v1.5 padding.
                ca_key.verify(
                    certificate.signature,
                    certificate.tbs_certificate_bytes,
                    padding.PKCS1v15(),
                    hash_alg,
                )
            elif isinstance(ca_key, ec.EllipticCurvePublicKey):
                ca_key.verify(
                    certificate.signature,
                    certificate.tbs_certificate_bytes,
                    ec.ECDSA(hash_alg),
                )
            else:
                continue  # Unsupported CA key type; try the next CA.
        except Exception:
            continue  # Signature did not verify under this CA.

        # Signature is valid under this CA.
        if check_revocation:
            # In production, check CRL or OCSP
            pass
        return True, None

    return False, "Certificate not signed by trusted CA"
class SecurePrompt(Prompt):
    """Prompt with authentication, compliance, and audit support.

    Extends the base Prompt class with security features for handling
    sensitive prompts and enforcing compliance requirements.
    """

    def __init__(
        self,
        secure_annotations: SecureAnnotations,
        tool_identity: Optional[ToolIdentity] = None,
        processor: Optional[SecureAnnotationProcessor] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.secure_annotations = secure_annotations
        self.tool_identity = tool_identity
        self.processor = processor if processor is not None else SecureAnnotationProcessor(tool_identity=tool_identity)

        # Compliance metadata consulted by the audit/compliance steps.
        self._compliance_metadata = {
            "compliance_tags": secure_annotations.compliance_tags,
            "data_classification": secure_annotations.data_classification,
            "audit_required": secure_annotations.audit_log,
        }

    async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]:
        """Render the prompt after authentication and compliance checks.

        Pipeline: authenticate (if required) → compliance checks →
        sanitize sensitive arguments → render → post-process messages →
        audit usage.
        """
        args = arguments or {}

        # NOTE(review): auth extraction from the request context is not
        # wired up yet, so authentication input is always None here —
        # confirm before relying on require_auth.
        auth_header = None

        try:
            session, checked_args = await self.processor.process_secure_request(
                annotations=self.secure_annotations,
                auth_header=auth_header,
                client_cert=None,
                request_data=args
            )
        except Error as exc:
            # Record the denied attempt, then propagate to the caller.
            self.processor._audit_log(
                session=None,
                action="prompt_access_denied",
                include_data=False,
                data={"prompt": self.name, "error": str(exc)}
            )
            raise

        # Enforce any configured compliance regimes (GDPR/HIPAA/PCI-DSS).
        if self.secure_annotations.compliance_tags:
            await self._check_compliance(session, checked_args)

        # High classifications get well-known sensitive keys masked.
        if self.secure_annotations.data_classification in ["confidential", "secret"]:
            checked_args = await self._sanitize_arguments(checked_args)

        # Render, then apply security post-processing to the messages.
        rendered = await super().render(checked_args)
        secured = await self._secure_messages(rendered, session)

        if self.secure_annotations.audit_log:
            await self._audit_prompt_usage(session, checked_args, secured)

        return secured
+ + Args: + session: Secure session + arguments: Prompt arguments + + Raises: + Error: If compliance requirements are not met + """ + for tag in self.secure_annotations.compliance_tags: + if tag == "GDPR": + # Check GDPR compliance (e.g., purpose limitation, data minimization) + if "personal_data" in arguments and not session.client_identity.has_permission("gdpr.process"): + raise Error(code=403, message="GDPR: Missing permission to process personal data") + + elif tag == "HIPAA": + # Check HIPAA compliance for health information + if "health_data" in arguments and not session.client_identity.has_permission("hipaa.access"): + raise Error(code=403, message="HIPAA: Not authorized to access health information") + + elif tag == "PCI-DSS": + # Check PCI-DSS compliance for payment card data + if "card_data" in arguments: + # Ensure card data is masked/tokenized + if not self._is_card_data_safe(arguments["card_data"]): + raise Error(code=400, message="PCI-DSS: Card data must be tokenized") + + async def _sanitize_arguments(self, arguments: dict) -> dict: + """ + Sanitize sensitive information from arguments. 
+ + Args: + arguments: Original arguments + + Returns: + Sanitized arguments + """ + sanitized = {} + for key, value in arguments.items(): + if key in ["ssn", "credit_card", "password", "api_key"]: + # Mask sensitive fields + sanitized[key] = self._mask_sensitive_data(str(value)) + elif isinstance(value, dict): + # Recursively sanitize nested data + sanitized[key] = await self._sanitize_arguments(value) + else: + sanitized[key] = value + + return sanitized + + def _mask_sensitive_data(self, data: str) -> str: + """Mask sensitive data while preserving format hints.""" + if len(data) <= 4: + return "*" * len(data) + + # Show first and last 2 characters only + return data[:2] + "*" * (len(data) - 4) + data[-2:] + + def _is_card_data_safe(self, card_data: str) -> bool: + """Check if card data is properly tokenized/masked.""" + # Check if it's a token (e.g., tok_xxxx) or masked number + return card_data.startswith("tok_") or "*" in card_data + + async def _secure_messages(self, messages: list[Message], session) -> list[Message]: + """ + Apply security transformations to messages. 
async def _secure_messages(self, messages: list[Message], session: Any) -> list[Message]:
    """
    Apply security transformations to messages.

    Args:
        messages: Original messages
        session: Secure session

    Returns:
        Secured messages
    """
    secure_msgs = []

    for msg in messages:
        # Shallow copy when possible so the caller's message objects are not
        # mutated. NOTE(review): only dict messages are actually modified
        # below — model/object messages pass through unchanged; confirm that
        # is intended for pydantic `Message` instances.
        secure_msg = msg.copy() if hasattr(msg, 'copy') else msg

        # Add security headers to system messages
        if isinstance(msg, dict) and msg.get("role") == "system":
            if self.secure_annotations.compliance_tags:
                compliance_notice = f"[Compliance: {', '.join(self.secure_annotations.compliance_tags)}] "
                # Prefix (not replace) the existing system content.
                secure_msg["content"] = compliance_notice + secure_msg.get("content", "")

        # Add classification labels for anything above "public"
        if self.secure_annotations.data_classification != "public":
            if isinstance(secure_msg, dict):
                secure_msg["metadata"] = secure_msg.get("metadata", {})
                secure_msg["metadata"]["classification"] = self.secure_annotations.data_classification

        secure_msgs.append(secure_msg)

    return secure_msgs


async def _audit_prompt_usage(self, session: Any, arguments: dict, messages: list) -> None:
    """
    Audit prompt usage for compliance and security monitoring.

    Args:
        session: Secure session
        arguments: Prompt arguments
        messages: Generated messages
    """
    audit_data = {
        "prompt_name": self.name,
        "client_id": session.client_identity.client_id if session.client_identity else "anonymous",
        "compliance_tags": self.secure_annotations.compliance_tags,
        "data_classification": self.secure_annotations.data_classification,
        "message_count": len(messages),
    }

    if self.secure_annotations.audit_include_inputs:
        # Hash sensitive arguments for audit — the audit trail proves *what*
        # was passed without storing the plaintext arguments themselves.
        audit_data["argument_hash"] = hashlib.sha256(
            str(arguments).encode()
        ).hexdigest()

    if self.secure_annotations.audit_include_outputs:
        # Include message metadata (not content)
        audit_data["message_roles"] = [
            msg.get("role") if isinstance(msg, dict) else "unknown"
            for msg in messages
        ]

    self.processor._audit_log(
        session=session,
        action="prompt_rendered",
        include_data=True,
        data=audit_data
    )
def secure_prompt(
    # Security parameters
    require_auth: bool = False,
    audit_usage: bool = True,
    compliance_tags: Optional[list[str]] = None,
    data_classification: str = "public",

    # Standard prompt parameters
    name: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
) -> Callable[[F], F]:
    """
    Decorator that turns a function into a secure MCP prompt.

    The decorated function gains security metadata describing:
    - whether authentication is required,
    - the compliance regimes that apply (e.g. ["GDPR", "HIPAA"]),
    - its data classification level,
    - whether usage should be audited.

    Args:
        require_auth: Require an authenticated caller.
        audit_usage: Record an audit entry each time the prompt is used.
        compliance_tags: Applicable compliance regimes.
        data_classification: Classification level of the rendered content.
        name: Prompt name.
        title: Prompt title.
        description: Prompt description.

    Returns:
        A decorator that attaches the security metadata to the function.

    NOTE: ``name``/``title``/``description`` are accepted for interface
    parity but are not consumed here — presumably read at registration
    time; confirm against the server integration.
    """
    annotations_for_prompt = SecureAnnotations(
        require_auth=require_auth,
        audit_log=audit_usage,
        compliance_tags=compliance_tags or [],
        data_classification=data_classification,
    )

    def attach_security(func: F) -> F:
        is_coroutine = inspect.iscoroutinefunction(func)

        @functools.wraps(func)
        async def secured(*args, **kwargs):
            # Passthrough placeholder — SecurePrompt.render performs the real
            # security processing in production.
            if is_coroutine:
                return await func(*args, **kwargs)
            return func(*args, **kwargs)

        # Security metadata consumed by the prompt registration machinery.
        secured._secure_annotations = annotations_for_prompt
        secured._is_secure_prompt = True
        secured._compliance_tags = compliance_tags or []
        secured._data_classification = data_classification

        return cast(F, secured)

    return attach_security
class ComplianceValidator:
    """
    Validator for ensuring prompts meet compliance requirements.
    """

    # Keyword heuristics used to decide whether a regime's extra metadata
    # requirements kick in. Matching is substring-based on lowered text.
    _GDPR_PERSONAL_DATA_HINTS = ("ssn", "email", "phone", "address", "name")
    _HIPAA_PHI_HINTS = ("patient", "diagnosis", "treatment", "medical", "health")

    @staticmethod
    def validate_gdpr(prompt_content: str, metadata: dict) -> tuple[bool, Optional[str]]:
        """
        Validate GDPR compliance for a prompt.

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Purpose limitation: a declared purpose is mandatory in all cases.
        if "purpose" not in metadata:
            return False, "GDPR requires explicit purpose declaration"

        # Data minimization: personal-data mentions require a legal basis.
        lowered = prompt_content.lower()
        mentions_personal_data = any(
            hint in lowered for hint in ComplianceValidator._GDPR_PERSONAL_DATA_HINTS
        )
        if mentions_personal_data and "legal_basis" not in metadata:
            return False, "GDPR requires legal basis for processing personal data"

        return True, None

    @staticmethod
    def validate_hipaa(prompt_content: str, metadata: dict) -> tuple[bool, Optional[str]]:
        """
        Validate HIPAA compliance for a prompt.

        Returns:
            Tuple of (is_valid, error_message)
        """
        # PHI safeguards: required whenever health-related terms appear.
        lowered = prompt_content.lower()
        mentions_phi = any(
            hint in lowered for hint in ComplianceValidator._HIPAA_PHI_HINTS
        )
        if mentions_phi and "hipaa_safeguards" not in metadata:
            return False, "HIPAA requires safeguards for Protected Health Information"

        return True, None

    @staticmethod
    def validate_pci_dss(prompt_content: str, metadata: dict) -> tuple[bool, Optional[str]]:
        """
        Validate PCI-DSS compliance for a prompt.

        Returns:
            Tuple of (is_valid, error_message)
        """
        import re

        # Simple 16-digit card-number pattern (not comprehensive).
        card_pattern = re.compile(r'\b\d{4}[\s-]?\d{4}[\s-]?\d{4}[\s-]?\d{4}\b')
        if card_pattern.search(prompt_content):
            return False, "PCI-DSS prohibits storage of unencrypted card numbers"

        return True, None
+""" + +from __future__ import annotations + +import functools +import inspect +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast + +from mcp.server.fastmcp.resources.base import Resource +from mcp.types import Error + +from .annotations import AuthMethod, SecureAnnotations, SecureResourceAnnotations +from .identity import ToolIdentity +from .utils import SecureAnnotationProcessor + +if TYPE_CHECKING: + from mcp.server.fastmcp.server import FastMCP + +F = TypeVar('F', bound=Callable[..., Any]) + + +class SecureResource(Resource): + """ + Secure resource with authentication, encryption, and access control. + + This extends the base Resource class with security features. + """ + + def __init__( + self, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + processor: Optional[SecureAnnotationProcessor] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure_annotations = secure_annotations + self.tool_identity = tool_identity + self.processor = processor or SecureAnnotationProcessor(tool_identity=tool_identity) + + # Add security metadata to the resource + self._secure_metadata = { + "data_classification": secure_annotations.data_classification, + "encryption_required": secure_annotations.encrypt_output, + "auth_required": secure_annotations.require_auth, + "compliance_tags": secure_annotations.compliance_tags, + } + + async def read(self) -> str | bytes: + """ + Read the secure resource with authentication and encryption. + + This method: + 1. Verifies client authentication (if required) + 2. Checks access permissions + 3. Reads the resource + 4. Encrypts the content (if required) + 5. 
Audits the access + """ + # In production, extract auth from request context + auth_header = None # Would come from request context + + # Process secure request + try: + session, _ = await self.processor.process_secure_request( + annotations=self.secure_annotations, + auth_header=auth_header, + client_cert=None, + request_data={} + ) + except Error as e: + # Audit failed access attempt + self.processor._audit_log( + session=None, + action="resource_access_denied", + include_data=False, + data={"uri": self.uri, "error": str(e)} + ) + raise + + # Check specific resource permissions + if session.client_identity: + resource_permission = f"resource.read.{self.name or self.uri}" + if not session.client_identity.has_permission(resource_permission) and \ + not session.client_identity.has_permission("resource.read.*"): + raise Error( + code=403, + message=f"Client lacks permission to read resource: {self.uri}" + ) + + # Read the actual resource content + content = await super().read() + + # Process secure response (encrypt if required) + if self.secure_annotations.encrypt_output: + secure_content = await self.processor.process_secure_response( + annotations=self.secure_annotations, + session=session, + response_data=content + ) + + # Convert encrypted response to string/bytes + if isinstance(secure_content, dict) and "data" in secure_content: + content = secure_content["data"] + + # Audit successful access + if self.secure_annotations.audit_log: + self.processor._audit_log( + session=session, + action="resource_accessed", + include_data=self.secure_annotations.audit_include_outputs, + data={ + "uri": self.uri, + "classification": self.secure_annotations.data_classification, + "size": len(content) if isinstance(content, (str, bytes)) else None + } + ) + + return content + + +def secure_resource( + uri: str, + # Security parameters + require_auth: bool = False, + data_classification: str = "public", + encrypt_io: bool = False, + audit_access: bool = True, + compliance_tags: 
def secure_resource(
    uri: str,
    # Security parameters
    require_auth: bool = False,
    data_classification: str = "public",
    encrypt_io: bool = False,
    audit_access: bool = True,
    compliance_tags: Optional[list[str]] = None,

    # Standard resource parameters
    name: Optional[str] = None,
    title: Optional[str] = None,
    description: Optional[str] = None,
    mime_type: Optional[str] = None,
) -> Callable[[F], F]:
    """
    Decorator that turns a function into a secure MCP resource.

    The decorated function gains security metadata covering access control,
    data classification, encryption, auditing and compliance tags.

    Args:
        uri: Resource URI (may contain template parameters).
        require_auth: Require an authenticated caller.
        data_classification: public/internal/confidential/secret.
        encrypt_io: Encrypt both input and output of the resource.
        audit_access: Record an audit entry per access.
        compliance_tags: Applicable compliance regimes (e.g. ["PCI-DSS"]).
        name: Resource name.
        title: Resource title.
        description: Resource description.
        mime_type: MIME type of the content.

    Returns:
        A decorator that attaches the security metadata to the function.

    NOTE: ``name``/``title``/``description``/``mime_type`` are accepted for
    interface parity but not consumed here — presumably read at registration
    time; confirm against the server integration.
    """
    annotations_for_resource = SecureAnnotations(
        require_auth=require_auth,
        data_classification=data_classification,
        encrypt_input=encrypt_io,
        encrypt_output=encrypt_io,
        audit_log=audit_access,
        compliance_tags=compliance_tags or [],
    )

    def attach_security(func: F) -> F:
        is_coroutine = inspect.iscoroutinefunction(func)

        @functools.wraps(func)
        async def secured(*args, **kwargs):
            # Passthrough placeholder — SecureResource.read performs the real
            # security processing in production.
            if is_coroutine:
                return await func(*args, **kwargs)
            return func(*args, **kwargs)

        # Security metadata consumed by the resource registration machinery.
        secured._secure_annotations = annotations_for_resource
        secured._resource_uri = uri
        secured._is_secure_resource = True
        secured._data_classification = data_classification
        secured._compliance_tags = compliance_tags or []

        return cast(F, secured)

    return attach_security
class SecureResourceTemplate:
    """
    Template for secure resources with dynamic URIs.

    Supports resources like "secure://data/{category}/{item_id}"
    """

    def __init__(
        self,
        uri_template: str,
        secure_annotations: SecureAnnotations,
        tool_identity: Optional[ToolIdentity] = None,
    ):
        """
        Args:
            uri_template: URI pattern with ``{param}`` placeholders.
            secure_annotations: Security policy shared by all resources
                created from this template.
            tool_identity: Optional identity used for attestation/auditing.
        """
        self.uri_template = uri_template
        self.secure_annotations = secure_annotations
        self.tool_identity = tool_identity
        # One processor is shared across every resource instantiated from
        # this template.
        self.processor = SecureAnnotationProcessor(tool_identity=tool_identity)

    def create_resource(self, **params) -> SecureResource:
        """
        Create a secure resource instance with the given parameters.

        Args:
            **params: Parameters to fill in the URI template

        Returns:
            SecureResource instance

        Raises:
            KeyError: If a placeholder in the template has no matching
                keyword argument (``str.format`` semantics).
        """
        # Format the URI with parameters
        uri = self.uri_template.format(**params)

        return SecureResource(
            uri=uri,
            secure_annotations=self.secure_annotations,
            tool_identity=self.tool_identity,
            processor=self.processor,
        )

    def validate_access(self, client_identity, params: dict) -> bool:
        """
        Validate if a client can access a resource with given parameters.

        Access requires the template-level read permission AND, for every
        parameter, either the exact ``resource.<param>.<value>`` permission
        or the ``resource.<param>.*`` wildcard.

        Args:
            client_identity: Client identity to validate
            params: Resource parameters

        Returns:
            True if access is allowed, False otherwise
        """
        # Check base permissions
        if not client_identity.has_permission(f"resource.read.{self.uri_template}"):
            return False

        # Check parameter-specific permissions
        # For example, for "secure://portfolio/{account_id}",
        # check if client can access that specific account
        for param_name, param_value in params.items():
            specific_perm = f"resource.{param_name}.{param_value}"
            if not client_identity.has_permission(specific_perm):
                # Check wildcard permission
                if not client_identity.has_permission(f"resource.{param_name}.*"):
                    return False

        return True
+ + Supports: + - Mutual authentication + - Key exchange and encryption + - Session binding and replay protection + """ + + session_id: str + client_identity: Optional[ClientIdentity] + tool_identity: Optional[ToolIdentity] + established_at: datetime + expires_at: datetime + + # Encryption + encryption_algorithm: str = "AES-256-GCM" # or "ChaCha20-Poly1305" + encryption_key: Optional[AESGCM | ChaCha20Poly1305] = None + client_public_key: Optional[PublicKeyTypes] = None + server_public_key: Optional[PublicKeyTypes] = None + + # Session binding + client_ip: Optional[str] = None + client_fingerprint: Optional[str] = None + bound_to_client: bool = False + + # Replay protection + nonce_counter: int = 0 + used_nonces: set[str] = field(default_factory=set) + max_nonce_age_seconds: int = 300 + + # Rate limiting + request_count: int = 0 + last_request_at: Optional[datetime] = None + + # Metadata + metadata: dict[str, Any] = field(default_factory=dict) + + def is_valid(self) -> bool: + """Check if session is still valid.""" + now = datetime.utcnow() + + # Check expiration + if now > self.expires_at: + return False + + # Check client identity expiration + if self.client_identity and self.client_identity.is_expired(): + return False + + return True + + def is_bound_to(self, client_ip: str, client_fingerprint: str) -> bool: + """ + Check if session is bound to the requesting client. + + Args: + client_ip: Client IP address + client_fingerprint: Client fingerprint (e.g., TLS fingerprint) + + Returns: + True if session binding matches + """ + if not self.bound_to_client: + return True + + if self.client_ip and self.client_ip != client_ip: + return False + + if self.client_fingerprint and self.client_fingerprint != client_fingerprint: + return False + + return True + + def encrypt(self, data: bytes, associated_data: Optional[bytes] = None) -> bytes: + """ + Encrypt data using session key. 
+ + Args: + data: Data to encrypt + associated_data: Additional authenticated data + + Returns: + Encrypted data with nonce prepended + """ + if not self.encryption_key: + raise ValueError("No encryption key established") + + nonce = os.urandom(12) # 96-bit nonce for AES-GCM + ciphertext = self.encryption_key.encrypt(nonce, data, associated_data) + + return nonce + ciphertext + + def decrypt( + self, + encrypted_data: bytes, + associated_data: Optional[bytes] = None + ) -> bytes: + """ + Decrypt data using session key. + + Args: + encrypted_data: Encrypted data with nonce prepended + associated_data: Additional authenticated data + + Returns: + Decrypted data + """ + if not self.encryption_key: + raise ValueError("No encryption key established") + + nonce, ciphertext = encrypted_data[:12], encrypted_data[12:] + + # Check for nonce reuse (replay protection) + nonce_b64 = base64.b64encode(nonce).decode() + if nonce_b64 in self.used_nonces: + raise ValueError("Nonce reuse detected - possible replay attack") + + self.used_nonces.add(nonce_b64) + self.nonce_counter += 1 + + return self.encryption_key.decrypt(nonce, ciphertext, associated_data) + + def generate_request_token(self) -> str: + """ + Generate a request token for replay protection. 
+ + Returns: + Base64-encoded request token + """ + timestamp = datetime.utcnow().isoformat() + nonce = secrets.token_bytes(16) + + token_data = f"{self.session_id}:{timestamp}:{base64.b64encode(nonce).decode()}" + + # Sign the token + if self.encryption_key and isinstance(self.encryption_key, AESGCM): + # Use HMAC with part of the session key + key_bytes = self.encryption_key._key[:16] # Use first 16 bytes for HMAC + signature = hmac.new(key_bytes, token_data.encode(), hashlib.sha256).digest() + + return base64.b64encode( + token_data.encode() + signature + ).decode() + + return base64.b64encode(token_data.encode()).decode() + + def verify_request_token(self, token: str) -> bool: + """ + Verify a request token for replay protection. + + Args: + token: Request token to verify + + Returns: + True if token is valid and fresh + """ + try: + decoded = base64.b64decode(token) + + if self.encryption_key and isinstance(self.encryption_key, AESGCM): + # Split token and signature + token_data = decoded[:-32] + signature = decoded[-32:] + + # Verify signature + key_bytes = self.encryption_key._key[:16] + expected_sig = hmac.new(key_bytes, token_data, hashlib.sha256).digest() + + if not hmac.compare_digest(signature, expected_sig): + return False + else: + token_data = decoded + + # Parse token + parts = token_data.decode().split(":") + if len(parts) != 3: + return False + + session_id, timestamp_str, nonce_b64 = parts + + # Verify session ID + if session_id != self.session_id: + return False + + # Check timestamp freshness + timestamp = datetime.fromisoformat(timestamp_str) + age = (datetime.utcnow() - timestamp).total_seconds() + + if age > self.max_nonce_age_seconds: + return False + + # Check nonce uniqueness + if nonce_b64 in self.used_nonces: + return False + + self.used_nonces.add(nonce_b64) + + return True + + except Exception: + return False + + def rotate_session_key(self) -> None: + """Rotate the session encryption key.""" + if not self.encryption_key: + return + + 
# Derive new key from old key + if isinstance(self.encryption_key, AESGCM): + old_key = self.encryption_key._key + + # Use HKDF to derive new key + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=None, + info=b'session-key-rotation', + ) + new_key = hkdf.derive(old_key + self.session_id.encode()) + + self.encryption_key = AESGCM(new_key) + + # Clear nonce history on key rotation + self.used_nonces.clear() + self.nonce_counter = 0 + + +class SessionManager: + """ + Manages secure sessions for MCP operations. + """ + + def __init__( + self, + tool_identity: Optional[ToolIdentity] = None, + session_timeout_minutes: int = 60, + max_sessions_per_client: int = 10, + ): + self.tool_identity = tool_identity + self.session_timeout_minutes = session_timeout_minutes + self.max_sessions_per_client = max_sessions_per_client + + # Session storage + self.sessions: Dict[str, SecureSession] = {} + self.client_sessions: Dict[str, list[str]] = {} # client_id -> [session_ids] + + # DH parameters for key exchange + self._dh_parameters = None + self._ecdh_curve = ec.SECP256R1() + + def create_session( + self, + client_identity: Optional[ClientIdentity] = None, + encryption_algorithm: str = "AES-256-GCM", + bind_to_client: bool = False, + client_ip: Optional[str] = None, + client_fingerprint: Optional[str] = None, + ) -> SecureSession: + """ + Create a new secure session. 
class SessionManager:
    """
    Manages secure sessions for MCP operations.

    Owns session storage, per-client session limits, key establishment
    (ECDH or pre-shared key) and expiry cleanup.
    """

    def __init__(
        self,
        tool_identity: Optional[ToolIdentity] = None,
        session_timeout_minutes: int = 60,
        max_sessions_per_client: int = 10,
    ):
        """
        Args:
            tool_identity: Identity attached to every created session.
            session_timeout_minutes: Lifetime of new sessions.
            max_sessions_per_client: Oldest session is evicted when a client
                exceeds this count.
        """
        self.tool_identity = tool_identity
        self.session_timeout_minutes = session_timeout_minutes
        self.max_sessions_per_client = max_sessions_per_client

        # Session storage
        self.sessions: Dict[str, SecureSession] = {}
        self.client_sessions: Dict[str, list[str]] = {}  # client_id -> [session_ids]

        # DH parameters for key exchange (classic DH currently unused;
        # ECDH over P-256 is the active mechanism).
        self._dh_parameters = None
        self._ecdh_curve = ec.SECP256R1()

    def create_session(
        self,
        client_identity: Optional[ClientIdentity] = None,
        encryption_algorithm: str = "AES-256-GCM",
        bind_to_client: bool = False,
        client_ip: Optional[str] = None,
        client_fingerprint: Optional[str] = None,
    ) -> SecureSession:
        """
        Create a new secure session.

        Args:
            client_identity: Authenticated client identity
            encryption_algorithm: Encryption algorithm to use
            bind_to_client: Whether to bind session to client
            client_ip: Client IP for session binding
            client_fingerprint: Client fingerprint for session binding

        Returns:
            New SecureSession instance (no encryption key yet — establish it
            via ECDH or a pre-shared key afterwards)
        """
        # Generate session ID: 256 bits of randomness, URL-safe, unpadded.
        session_id = base64.urlsafe_b64encode(os.urandom(32)).decode().rstrip("=")

        # Check session limit per client; evict the oldest on overflow.
        if client_identity:
            client_id = client_identity.client_id
            if client_id in self.client_sessions:
                if len(self.client_sessions[client_id]) >= self.max_sessions_per_client:
                    # Remove oldest session
                    oldest_session_id = self.client_sessions[client_id][0]
                    self.revoke_session(oldest_session_id)

        # Create session
        session = SecureSession(
            session_id=session_id,
            client_identity=client_identity,
            tool_identity=self.tool_identity,
            established_at=datetime.utcnow(),
            expires_at=datetime.utcnow() + timedelta(minutes=self.session_timeout_minutes),
            encryption_algorithm=encryption_algorithm,
            bound_to_client=bind_to_client,
            client_ip=client_ip,
            client_fingerprint=client_fingerprint,
        )

        # Store session
        self.sessions[session_id] = session

        # Track client sessions (append keeps insertion order, so index 0 is
        # always the oldest).
        if client_identity:
            client_id = client_identity.client_id
            if client_id not in self.client_sessions:
                self.client_sessions[client_id] = []
            self.client_sessions[client_id].append(session_id)

        return session

    def get_session(self, session_id: str) -> Optional[SecureSession]:
        """
        Get a session by ID.

        Args:
            session_id: Session ID

        Returns:
            SecureSession if found and valid; None otherwise (an expired
            session is revoked as a side effect)
        """
        session = self.sessions.get(session_id)

        if session and session.is_valid():
            return session

        # Remove invalid session
        if session:
            self.revoke_session(session_id)

        return None

    def revoke_session(self, session_id: str) -> None:
        """
        Revoke a session. Idempotent: unknown IDs are ignored.

        Args:
            session_id: Session ID to revoke
        """
        session = self.sessions.pop(session_id, None)

        if session and session.client_identity:
            # Remove from client sessions
            client_id = session.client_identity.client_id
            if client_id in self.client_sessions:
                self.client_sessions[client_id] = [
                    sid for sid in self.client_sessions[client_id]
                    if sid != session_id
                ]

    def perform_ecdh_key_exchange(
        self,
        session: SecureSession,
        client_public_key_pem: bytes
    ) -> bytes:
        """
        Perform ECDH key exchange to establish session key.

        Args:
            session: Session to establish key for
            client_public_key_pem: Client's public key in PEM format

        Returns:
            Server's public key in PEM format
        """
        # Generate server's ephemeral key pair
        server_private_key = ec.generate_private_key(self._ecdh_curve)
        server_public_key = server_private_key.public_key()

        # Load client's public key
        client_public_key = serialization.load_pem_public_key(client_public_key_pem)

        # Perform ECDH to get shared secret
        shared_secret = server_private_key.exchange(
            ec.ECDH(),
            client_public_key
        )

        # Derive session key using HKDF
        hkdf = HKDF(
            algorithm=hashes.SHA256(),
            length=32,  # 256-bit key
            salt=session.session_id.encode()[:16],  # Use session ID as salt
            info=b'mcp-session-key',
        )
        session_key = hkdf.derive(shared_secret)

        # Create cipher based on algorithm
        if session.encryption_algorithm == "ChaCha20-Poly1305":
            session.encryption_key = ChaCha20Poly1305(session_key)
        else:  # Default to AES-256-GCM
            session.encryption_key = AESGCM(session_key)

        # Store public keys
        session.client_public_key = client_public_key
        session.server_public_key = server_public_key

        # Return server's public key.
        # FIX: the cryptography API method is `public_bytes`, not
        # `public_key_bytes` — the original raised AttributeError on every
        # key exchange.
        return server_public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        )

    def establish_pre_shared_key(
        self,
        session: SecureSession,
        pre_shared_key: bytes
    ) -> None:
        """
        Establish session key from pre-shared key.

        Args:
            session: Session to establish key for
            pre_shared_key: Pre-shared key
        """
        # FIX: imported locally under its real name — the cryptography
        # package exports PBKDF2HMAC, not PBKDF2.
        from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

        # Derive session key from PSK using PBKDF2
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=session.session_id.encode()[:16],
            iterations=100000,
        )
        session_key = kdf.derive(pre_shared_key)

        # Create cipher
        if session.encryption_algorithm == "ChaCha20-Poly1305":
            session.encryption_key = ChaCha20Poly1305(session_key)
        else:
            session.encryption_key = AESGCM(session_key)

    def cleanup_expired_sessions(self) -> int:
        """
        Clean up expired sessions.

        Returns:
            Number of sessions removed
        """
        # Snapshot first: revoke_session mutates self.sessions, so we must
        # not remove while iterating it.
        expired_sessions = [
            session_id for session_id, session in self.sessions.items()
            if not session.is_valid()
        ]

        for session_id in expired_sessions:
            self.revoke_session(session_id)

        return len(expired_sessions)
+ """ + + def __init__( + self, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + processor: Optional[SecureAnnotationProcessor] = None, + **kwargs + ): + super().__init__(**kwargs) + self.secure_annotations = secure_annotations + self.tool_identity = tool_identity + self.processor = processor or SecureAnnotationProcessor(tool_identity=tool_identity) + + # Update annotations with security metadata + if not self.annotations: + self.annotations = SecureToolAnnotations(secure=secure_annotations) + elif hasattr(self.annotations, 'extensions'): + self.annotations.extensions["security"] = secure_annotations.to_dict() + + async def run( + self, + arguments: dict[str, Any], + context: Context | None = None, + convert_result: bool = False, + ) -> Any: + """ + Run the secure tool with authentication and encryption. + + This method: + 1. Authenticates the client (if required) + 2. Verifies tool identity (if mutual auth is enabled) + 3. Decrypts input (if encryption is enabled) + 4. Executes the tool + 5. Encrypts output (if encryption is enabled) + 6. Signs the result (if attestation is enabled) + """ + # Extract authentication information from context + auth_header = None + client_cert = None + if context and hasattr(context, 'request_context'): + request = getattr(context.request_context, 'request', None) + if request: + auth_header = request.headers.get('Authorization') + # In production, extract client cert from TLS connection + + # Process secure request (authenticate, decrypt, etc.) 
+ try: + session, processed_args = await self.processor.process_secure_request( + annotations=self.secure_annotations, + auth_header=auth_header, + client_cert=client_cert, + request_data=arguments + ) + except Error as e: + # Log authentication failure + if context: + await context.error(f"Security check failed: {e.message}") + raise + + # If mutual authentication is required, send tool attestation + if self.secure_annotations.require_mutual_auth and self.tool_identity: + attestation = self.tool_identity.to_attestation() + if context: + await context.info(f"Tool attestation: {self.name} (fingerprint: {attestation['fingerprint'][:16]}...)") + + # Log the authenticated execution + if context and session.client_identity: + await context.info( + f"Executing secure tool '{self.name}' for client '{session.client_identity.client_id}' " + f"(auth: {session.client_identity.authentication_method.value})" + ) + + # Execute the actual tool function with processed arguments + try: + # Inject session into arguments if function expects it + sig = inspect.signature(self.fn) + if '_secure_session' in sig.parameters: + processed_args['_secure_session'] = session + + result = await super().run( + arguments=processed_args, + context=context, + convert_result=convert_result + ) + except Exception as e: + # Audit the failure + if self.secure_annotations.audit_log: + if context: + await context.error(f"Tool execution failed: {str(e)}") + raise + + # Process secure response (encrypt, sign, etc.) 
+ secure_result = await self.processor.process_secure_response( + annotations=self.secure_annotations, + session=session, + response_data=result + ) + + return secure_result + + @classmethod + def from_function( + cls, + fn: Callable[..., Any], + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + **kwargs + ) -> SecureTool: + """Create a SecureTool from a function.""" + # Create base tool first + base_tool = Tool.from_function( + fn=fn, + name=name, + title=title, + description=description, + **kwargs + ) + + # Create secure tool with same properties + return cls( + fn=base_tool.fn, + name=base_tool.name, + title=base_tool.title, + description=base_tool.description, + parameters=base_tool.parameters, + fn_metadata=base_tool.fn_metadata, + is_async=base_tool.is_async, + context_kwarg=base_tool.context_kwarg, + secure_annotations=secure_annotations, + tool_identity=tool_identity, + annotations=SecureToolAnnotations(secure=secure_annotations) + ) + + +def secure_tool( + # Security parameters + require_auth: bool = False, + auth_methods: Optional[list] = None, + required_permissions: Optional[set[str]] = None, + encrypt_io: bool = False, + require_mutual_auth: bool = False, + security_level: str = "standard", + tool_identity: Optional[ToolIdentity] = None, + + # Standard tool parameters + name: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + structured_output: Optional[bool] = None, +) -> Callable[[F], F]: + """ + Decorator to create a secure tool with authentication and encryption. 
+ + This decorator wraps a function to create a secure MCP tool that supports: + - Client authentication (JWT, certificates, TEE attestation) + - Bidirectional authentication (tool ↔ client) + - Input/output encryption + - Audit logging and compliance + + Args: + require_auth: Whether to require authentication + auth_methods: List of accepted authentication methods + required_permissions: Permissions required to execute the tool + encrypt_io: Whether to encrypt input and output + require_mutual_auth: Whether to require bidirectional authentication + security_level: Security level (standard/high/critical) + tool_identity: Tool identity for attestation + name: Tool name + title: Tool title + description: Tool description + structured_output: Whether to use structured output + + Example: + ```python + @secure_tool( + require_auth=True, + required_permissions={"trade.execute"}, + encrypt_io=True, + require_mutual_auth=True + ) + async def execute_trade(symbol: str, amount: float, ctx: Context) -> str: + # Tool implementation + return f"Trade executed: {symbol} x {amount}" + ``` + """ + from .annotations import AuthMethod + + # Create secure annotations + secure_annotations = SecureAnnotations( + require_auth=require_auth, + auth_methods=auth_methods or [AuthMethod.JWT], + required_permissions=required_permissions or set(), + encrypt_input=encrypt_io, + encrypt_output=encrypt_io, + require_mutual_auth=require_mutual_auth, + security_level=security_level, + ) + + def decorator(func: F) -> F: + # Check if this is being used with FastMCP + # In production, this would be integrated with FastMCP.tool() + + @functools.wraps(func) + async def wrapper(*args, **kwargs): + # This wrapper would be replaced by SecureTool.run in production + # For now, just call the function + return await func(*args, **kwargs) if inspect.iscoroutinefunction(func) else func(*args, **kwargs) + + # Store security metadata on the function + wrapper._secure_annotations = secure_annotations + 
wrapper._tool_identity = tool_identity + wrapper._is_secure_tool = True + + return cast(F, wrapper) + + return decorator + + +def create_secure_tool_from_function( + fn: Callable[..., Any], + mcp: FastMCP, + secure_annotations: SecureAnnotations, + tool_identity: Optional[ToolIdentity] = None, + **kwargs +) -> None: + """ + Helper function to add a secure tool to a FastMCP instance. + + This would be called internally by FastMCP when a secure_tool decorator is used. + """ + secure_tool_instance = SecureTool.from_function( + fn=fn, + secure_annotations=secure_annotations, + tool_identity=tool_identity, + **kwargs + ) + + # Register with the tool manager + mcp._tool_manager._tools[secure_tool_instance.name] = secure_tool_instance \ No newline at end of file diff --git a/src/mcp/mcp/server/fastmcp/secure/utils.py b/src/mcp/mcp/server/fastmcp/secure/utils.py new file mode 100644 index 000000000..e0a935c9a --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/secure/utils.py @@ -0,0 +1,628 @@ +""" +Utility functions for secure MCP operations. + +Provides encryption, authentication, and security helper functions. +""" + +from __future__ import annotations + +import base64 +import hashlib +import hmac +import json +import os +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Tuple + +import jwt +from cryptography import x509 +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2 + +from mcp.types import Error + +from .annotations import AuthMethod, SecureAnnotations +from .identity import ClientIdentity, ToolIdentity +from .session import SecureSession + + +class SecureAnnotationProcessor: + """ + Processes secure annotations for tools, resources, and prompts. + + This class handles the actual authentication, encryption, and + attestation logic when secure annotations are present. 
+ """ + + def __init__( + self, + tool_identity: Optional[ToolIdentity] = None, + jwt_secret: Optional[str] = None, + trusted_cas: Optional[List[x509.Certificate]] = None, + api_keys: Optional[Dict[str, ClientIdentity]] = None, + ): + self.tool_identity = tool_identity + self.jwt_secret = jwt_secret or os.environ.get("MCP_JWT_SECRET") + self.trusted_cas = trusted_cas or [] + self.api_keys = api_keys or {} + + # Session and rate limit storage + self.sessions: Dict[str, SecureSession] = {} + self.rate_limits: Dict[str, List[datetime]] = {} + + # Audit log (in production, use proper logging system) + self.audit_log_entries: List[dict] = [] + + async def process_secure_request( + self, + annotations: SecureAnnotations, + auth_header: Optional[str] = None, + client_cert: Optional[x509.Certificate] = None, + request_data: Optional[dict[str, Any]] = None, + ) -> Tuple[SecureSession, dict[str, Any]]: + """ + Process a secure request with authentication and encryption. + + Args: + annotations: Security annotations + auth_header: Authorization header + client_cert: Client certificate + request_data: Request data + + Returns: + Tuple of (secure_session, processed_request_data) + + Raises: + Error: If security checks fail + """ + # 1. Authenticate client if required + client_identity = None + if annotations.require_auth: + client_identity = await self._authenticate_client( + annotations.auth_methods, + auth_header, + client_cert, + request_data + ) + + # Check permissions + missing_perms = annotations.required_permissions - client_identity.permissions + if missing_perms: + raise Error( + code=403, + message=f"Missing required permissions: {missing_perms}" + ) + + # 2. Perform mutual authentication if required + if annotations.require_mutual_auth: + if not self.tool_identity: + raise Error( + code=500, + message="Tool identity not configured for mutual authentication" + ) + # Tool attestation is provided through session + + # 3. 
Create or retrieve session + session = await self._establish_session(client_identity) + + # 4. Verify tool attestation if required + if annotations.require_tool_attestation: + if not self._verify_tool_attestation(annotations): + raise Error( + code=403, + message="Tool attestation verification failed" + ) + + # 5. Check rate limits + if annotations.rate_limit: + self._check_rate_limit( + session.client_identity.client_id if session.client_identity else "anonymous", + annotations.rate_limit, + annotations.rate_limit_per_client + ) + + # 6. Decrypt input if required + processed_data = request_data or {} + if annotations.encrypt_input and session.encryption_key: + processed_data = self._decrypt_request_data(session, processed_data) + + # 7. Verify message integrity if required + if annotations.require_integrity_check: + self._verify_message_integrity(processed_data) + + # 8. Check replay protection if required + if annotations.require_replay_protection: + if not self._check_replay_protection(session, processed_data): + raise Error(code=400, message="Replay attack detected") + + # 9. Audit log + if annotations.audit_log: + self._audit_log( + session=session, + action="request", + include_data=annotations.audit_include_inputs, + data=processed_data if annotations.audit_include_inputs else None + ) + + return session, processed_data + + async def process_secure_response( + self, + annotations: SecureAnnotations, + session: SecureSession, + response_data: Any, + ) -> Any: + """ + Process a secure response with encryption and signing. + + Args: + annotations: Security annotations + session: Secure session + response_data: Response data + + Returns: + Processed response data + """ + # 1. Audit log + if annotations.audit_log: + self._audit_log( + session=session, + action="response", + include_data=annotations.audit_include_outputs, + data=response_data if annotations.audit_include_outputs else None + ) + + # 2. 
Encrypt output if required + if annotations.encrypt_output and session.encryption_key: + response_data = self._encrypt_response_data(session, response_data) + + # 3. Add integrity signature if required + if annotations.require_integrity_check: + response_data = self._add_integrity_signature(response_data) + + # 4. Sign response if tool signature is required + if annotations.tool_signature_required and self.tool_identity: + response_data = self._sign_response(response_data) + + # 5. Add session metadata + if isinstance(response_data, dict): + response_data["_session"] = { + "id": session.session_id[:8] + "...", # Truncated for security + "authenticated": session.client_identity is not None, + "encrypted": annotations.encrypt_output, + } + + return response_data + + async def _authenticate_client( + self, + auth_methods: List[AuthMethod], + auth_header: Optional[str], + client_cert: Optional[x509.Certificate], + request_data: Optional[dict], + ) -> ClientIdentity: + """Authenticate client using available methods.""" + + # Try JWT authentication + if AuthMethod.JWT in auth_methods and auth_header: + identity = self._authenticate_jwt(auth_header) + if identity: + return identity + + # Try API key authentication + if AuthMethod.API_KEY in auth_methods: + api_key = None + if auth_header and auth_header.startswith("Bearer "): + api_key = auth_header[7:] + elif request_data and "api_key" in request_data: + api_key = request_data["api_key"] + + if api_key: + identity = self._authenticate_api_key(api_key) + if identity: + return identity + + # Try certificate authentication + if AuthMethod.CERTIFICATE in auth_methods and client_cert: + identity = self._authenticate_certificate(client_cert) + if identity: + return identity + + # Try mutual TLS + if AuthMethod.MTLS in auth_methods and client_cert: + identity = self._authenticate_mtls(client_cert) + if identity: + return identity + + raise Error(code=401, message="Authentication failed") + + def _authenticate_jwt(self, 
auth_header: str) -> Optional[ClientIdentity]: + """Authenticate using JWT token.""" + if not auth_header.startswith("Bearer "): + return None + + token = auth_header[7:] + + try: + # Decode and verify JWT + claims = jwt.decode( + token, + self.jwt_secret, + algorithms=["HS256", "RS256", "ES256"] + ) + + return ClientIdentity( + client_id=claims.get("sub", "unknown"), + authentication_method=AuthMethod.JWT, + credentials=token, + permissions=set(claims.get("permissions", [])), + email=claims.get("email"), + organization=claims.get("org"), + expires_at=datetime.fromtimestamp(claims.get("exp", 0)), + metadata={"claims": claims} + ) + except jwt.InvalidTokenError: + return None + + def _authenticate_api_key(self, api_key: str) -> Optional[ClientIdentity]: + """Authenticate using API key.""" + return self.api_keys.get(api_key) + + def _authenticate_certificate(self, cert: x509.Certificate) -> Optional[ClientIdentity]: + """Authenticate using X.509 certificate.""" + # Verify certificate against trusted CAs + for ca in self.trusted_cas: + try: + ca.public_key().verify( + cert.signature, + cert.tbs_certificate_bytes, + cert.signature_algorithm_oid._name + ) + + # Extract client info from certificate + from cryptography.x509.oid import NameOID + + common_name = cert.subject.get_attributes_for_oid( + NameOID.COMMON_NAME + )[0].value + + org = None + org_attrs = cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME) + if org_attrs: + org = org_attrs[0].value + + return ClientIdentity( + client_id=common_name, + authentication_method=AuthMethod.CERTIFICATE, + credentials=cert, + permissions={"read", "write", "execute"}, # Extract from cert extensions + organization=org, + certificate=cert, + certificate_fingerprint=hashlib.sha256( + cert.public_bytes(x509.Encoding.DER) + ).hexdigest(), + ) + except Exception: + continue + + return None + + def _authenticate_mtls(self, cert: x509.Certificate) -> Optional[ClientIdentity]: + """Authenticate using mutual TLS.""" + # 
Similar to certificate auth but with bidirectional verification + identity = self._authenticate_certificate(cert) + + if identity and self.tool_identity: + # Verify that client also verified our tool certificate + # This would be handled at the TLS layer in production + identity.metadata["mtls_verified"] = True + + return identity + + async def _establish_session( + self, + client_identity: Optional[ClientIdentity] + ) -> SecureSession: + """Establish or retrieve a secure session.""" + # For simplicity, create a new session each time + # In production, implement session caching + + session_id = base64.b64encode(os.urandom(32)).decode() + + # Create encryption key if we have a client + encryption_key = None + if client_identity: + key_bytes = AESGCM.generate_key(bit_length=256) + encryption_key = AESGCM(key_bytes) + + session = SecureSession( + session_id=session_id, + client_identity=client_identity, + tool_identity=self.tool_identity, + established_at=datetime.utcnow(), + expires_at=datetime.utcnow() + timedelta(hours=1), + encryption_key=encryption_key + ) + + self.sessions[session_id] = session + return session + + def _verify_tool_attestation(self, annotations: SecureAnnotations) -> bool: + """Verify tool attestation matches requirements.""" + if not self.tool_identity: + return False + + # Check certificate fingerprint if specified + if annotations.tool_certificate_fingerprint: + if self.tool_identity.fingerprint != annotations.tool_certificate_fingerprint: + return False + + # Check attestation type if specified + if annotations.attestation_type: + if self.tool_identity.attestation_type != annotations.attestation_type: + return False + + return True + + def _check_rate_limit( + self, + client_id: str, + limit: int, + per_client: bool + ) -> None: + """Check and enforce rate limits.""" + key = client_id if per_client else "global" + now = datetime.utcnow() + + # Clean old entries + if key in self.rate_limits: + self.rate_limits[key] = [ + t for t in 
self.rate_limits[key] + if (now - t).total_seconds() < 60 + ] + else: + self.rate_limits[key] = [] + + # Check limit + if len(self.rate_limits[key]) >= limit: + raise Error(code=429, message="Rate limit exceeded") + + # Add current request + self.rate_limits[key].append(now) + + def _decrypt_request_data( + self, + session: SecureSession, + data: dict[str, Any] + ) -> dict[str, Any]: + """Decrypt request data.""" + decrypted = {} + for key, value in data.items(): + if isinstance(value, str) and value.startswith("ENC:"): + encrypted_bytes = base64.b64decode(value[4:]) + decrypted_bytes = session.decrypt(encrypted_bytes) + decrypted[key] = json.loads(decrypted_bytes) + elif isinstance(value, dict): + decrypted[key] = self._decrypt_request_data(session, value) + else: + decrypted[key] = value + return decrypted + + def _encrypt_response_data( + self, + session: SecureSession, + data: Any + ) -> dict[str, Any]: + """Encrypt response data.""" + json_data = json.dumps(data) + encrypted = session.encrypt(json_data.encode()) + + return { + "encrypted": True, + "algorithm": session.encryption_algorithm, + "data": "ENC:" + base64.b64encode(encrypted).decode(), + "session_id": session.session_id + } + + def _verify_message_integrity(self, data: dict) -> bool: + """Verify message integrity signature.""" + if "_integrity" not in data: + return True # No integrity check provided + + integrity = data.pop("_integrity") + + # Compute expected hash + data_str = json.dumps(data, sort_keys=True) + expected_hash = hashlib.sha256(data_str.encode()).hexdigest() + + return hmac.compare_digest(integrity, expected_hash) + + def _add_integrity_signature(self, data: Any) -> dict: + """Add integrity signature to response.""" + if isinstance(data, dict): + data_copy = data.copy() + else: + data_copy = {"value": data} + + # Compute hash + data_str = json.dumps(data_copy, sort_keys=True) + integrity = hashlib.sha256(data_str.encode()).hexdigest() + + data_copy["_integrity"] = integrity + return 
data_copy + + def _check_replay_protection( + self, + session: SecureSession, + data: dict + ) -> bool: + """Check for replay attacks.""" + if "_request_token" not in data: + return False + + token = data.pop("_request_token") + return session.verify_request_token(token) + + def _sign_response(self, data: Any) -> dict[str, Any]: + """Sign response data with tool identity.""" + if not self.tool_identity: + return data if isinstance(data, dict) else {"value": data} + + # Prepare data for signing + if isinstance(data, dict): + sign_data = data.copy() + else: + sign_data = {"value": data} + + # Add timestamp + sign_data["_timestamp"] = datetime.utcnow().isoformat() + + # Sign the data + json_data = json.dumps(sign_data, sort_keys=True) + signature = self.tool_identity.sign_data(json_data.encode()) + + return { + "data": sign_data, + "signature": base64.b64encode(signature).decode(), + "tool_id": self.tool_identity.tool_id, + "tool_fingerprint": self.tool_identity.fingerprint[:16] + "..." + } + + def _audit_log( + self, + session: Optional[SecureSession], + action: str, + include_data: bool, + data: Any = None + ) -> None: + """Create audit log entry.""" + log_entry = { + "timestamp": datetime.utcnow().isoformat(), + "action": action, + } + + if session: + log_entry["session_id"] = session.session_id + if session.client_identity: + log_entry["client_id"] = session.client_identity.client_id + log_entry["auth_method"] = session.client_identity.authentication_method.value + if session.tool_identity: + log_entry["tool_id"] = session.tool_identity.tool_id + + if include_data and data is not None: + # Hash sensitive data for audit + if isinstance(data, (dict, list)): + log_entry["data_hash"] = hashlib.sha256( + json.dumps(data, sort_keys=True).encode() + ).hexdigest() + else: + log_entry["data_hash"] = hashlib.sha256( + str(data).encode() + ).hexdigest() + + self.audit_log_entries.append(log_entry) + + # In production, write to proper audit system + # For now, just keep in 
memory + + +# Convenience functions for encryption/decryption +def encrypt_data(data: str, key: bytes) -> str: + """ + Encrypt data using AES-256-GCM. + + Args: + data: Data to encrypt + key: 256-bit encryption key + + Returns: + Base64-encoded encrypted data + """ + cipher = AESGCM(key) + nonce = os.urandom(12) + ciphertext = cipher.encrypt(nonce, data.encode(), None) + return base64.b64encode(nonce + ciphertext).decode() + + +def decrypt_data(encrypted: str, key: bytes) -> str: + """ + Decrypt data encrypted with AES-256-GCM. + + Args: + encrypted: Base64-encoded encrypted data + key: 256-bit encryption key + + Returns: + Decrypted data + """ + cipher = AESGCM(key) + raw = base64.b64decode(encrypted) + nonce, ciphertext = raw[:12], raw[12:] + plaintext = cipher.decrypt(nonce, ciphertext, None) + return plaintext.decode() + + +def generate_session_key(password: str, salt: Optional[bytes] = None) -> bytes: + """ + Generate a session key from a password. + + Args: + password: Password to derive key from + salt: Optional salt (will generate if not provided) + + Returns: + 256-bit key suitable for AES-256 + """ + if salt is None: + salt = os.urandom(16) + + kdf = PBKDF2( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + return kdf.derive(password.encode()) + + +def verify_signature( + data: bytes, + signature: bytes, + public_key_pem: bytes +) -> bool: + """ + Verify a digital signature. 
+ + Args: + data: Data that was signed + signature: Digital signature + public_key_pem: Public key in PEM format + + Returns: + True if signature is valid + """ + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding + + try: + public_key = serialization.load_pem_public_key(public_key_pem) + + if isinstance(public_key, rsa.RSAPublicKey): + public_key.verify( + signature, + data, + padding.PSS( + mgf=padding.MGF1(hashes.SHA256()), + salt_length=padding.PSS.MAX_LENGTH + ), + hashes.SHA256() + ) + elif isinstance(public_key, ec.EllipticCurvePublicKey): + public_key.verify(signature, data, ec.ECDSA(hashes.SHA256())) + else: + return False + + return True + except Exception: + return False \ No newline at end of file diff --git a/src/mcp/mcp/server/fastmcp/server.py b/src/mcp/mcp/server/fastmcp/server.py new file mode 100644 index 000000000..924baaa9b --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/server.py @@ -0,0 +1,1188 @@ +"""FastMCP - A more ergonomic interface for MCP servers.""" + +from __future__ import annotations as _annotations + +import inspect +import re +from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from typing import Any, Generic, Literal + +import anyio +import pydantic_core +from pydantic import BaseModel +from pydantic.networks import AnyUrl +from pydantic_settings import BaseSettings, SettingsConfigDict +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import Request +from starlette.responses import Response +from starlette.routing import Mount, Route +from starlette.types import Receive, Scope, Send + +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from 
mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware +from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier +from mcp.server.auth.settings import AuthSettings +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation +from mcp.server.fastmcp.exceptions import ResourceError +from mcp.server.fastmcp.prompts import Prompt, PromptManager +from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager +from mcp.server.fastmcp.tools import Tool, ToolManager +from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger +from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.lowlevel.server import LifespanResultT +from mcp.server.lowlevel.server import Server as MCPServer +from mcp.server.lowlevel.server import lifespan as default_lifespan +from mcp.server.session import ServerSession, ServerSessionT +from mcp.server.sse import SseServerTransport +from mcp.server.stdio import stdio_server +from mcp.server.streamable_http import EventStore +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared.context import LifespanContextT, RequestContext, RequestT +from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations +from mcp.types import Prompt as MCPPrompt +from mcp.types import PromptArgument as MCPPromptArgument +from mcp.types import Resource as MCPResource +from mcp.types import ResourceTemplate as MCPResourceTemplate +from mcp.types import Tool as MCPTool + +logger = get_logger(__name__) + + +class Settings(BaseSettings, Generic[LifespanResultT]): + """FastMCP server settings. + + All settings can be configured via environment variables with the prefix FASTMCP_. + For example, FASTMCP_DEBUG=true will set debug=True. 
+ """ + + model_config = SettingsConfigDict( + env_prefix="FASTMCP_", + env_file=".env", + env_nested_delimiter="__", + nested_model_default_partial_update=True, + extra="ignore", + ) + + # Server settings + debug: bool + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + + # HTTP settings + host: str + port: int + mount_path: str + sse_path: str + message_path: str + streamable_http_path: str + + # StreamableHTTP settings + json_response: bool + stateless_http: bool + """Define if the server should create a new transport per request.""" + + # resource settings + warn_on_duplicate_resources: bool + + # tool settings + warn_on_duplicate_tools: bool + + # prompt settings + warn_on_duplicate_prompts: bool + + # TODO(Marcelo): Investigate if this is used. If it is, it's probably a good idea to remove it. + dependencies: list[str] + """A list of dependencies to install in the server environment.""" + + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None + """A async context manager that will be called when the server is started.""" + + auth: AuthSettings | None + + # Transport security settings (DNS rebinding protection) + transport_security: TransportSecuritySettings | None + + +def lifespan_wrapper( + app: FastMCP[LifespanResultT], + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], +) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: + @asynccontextmanager + async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: + async with lifespan(app) as context: + yield context + + return wrap + + +class FastMCP(Generic[LifespanResultT]): + def __init__( + self, + name: str | None = None, + instructions: str | None = None, + auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, + token_verifier: TokenVerifier | None = None, + event_store: EventStore | None = 
None, + *, + tools: list[Tool] | None = None, + debug: bool = False, + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", + host: str = "127.0.0.1", + port: int = 8000, + mount_path: str = "/", + sse_path: str = "/sse", + message_path: str = "/messages/", + streamable_http_path: str = "/mcp", + json_response: bool = False, + stateless_http: bool = False, + warn_on_duplicate_resources: bool = True, + warn_on_duplicate_tools: bool = True, + warn_on_duplicate_prompts: bool = True, + dependencies: Collection[str] = (), + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, + auth: AuthSettings | None = None, + transport_security: TransportSecuritySettings | None = None, + ): + self.settings = Settings( + debug=debug, + log_level=log_level, + host=host, + port=port, + mount_path=mount_path, + sse_path=sse_path, + message_path=message_path, + streamable_http_path=streamable_http_path, + json_response=json_response, + stateless_http=stateless_http, + warn_on_duplicate_resources=warn_on_duplicate_resources, + warn_on_duplicate_tools=warn_on_duplicate_tools, + warn_on_duplicate_prompts=warn_on_duplicate_prompts, + dependencies=list(dependencies), + lifespan=lifespan, + auth=auth, + transport_security=transport_security, + ) + + self._mcp_server = MCPServer( + name=name or "FastMCP", + instructions=instructions, + # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. + # We need to create a Lifespan type that is a generic on the server type, like Starlette does. 
+ lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore + ) + self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) + self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) + self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) + # Validate auth configuration + if self.settings.auth is not None: + if auth_server_provider and token_verifier: + raise ValueError("Cannot specify both auth_server_provider and token_verifier") + if not auth_server_provider and not token_verifier: + raise ValueError("Must specify either auth_server_provider or token_verifier when auth is enabled") + else: + if auth_server_provider or token_verifier: + raise ValueError("Cannot specify auth_server_provider or token_verifier without auth settings") + + self._auth_server_provider = auth_server_provider + self._token_verifier = token_verifier + + # Create token verifier from provider if needed (backwards compatibility) + if auth_server_provider and not token_verifier: + self._token_verifier = ProviderTokenVerifier(auth_server_provider) + self._event_store = event_store + self._custom_starlette_routes: list[Route] = [] + self.dependencies = self.settings.dependencies + self._session_manager: StreamableHTTPSessionManager | None = None + + # Set up MCP protocol handlers + self._setup_handlers() + + # Configure logging + configure_logging(self.settings.log_level) + + @property + def name(self) -> str: + return self._mcp_server.name + + @property + def instructions(self) -> str | None: + return self._mcp_server.instructions + + @property + def session_manager(self) -> StreamableHTTPSessionManager: + """Get the StreamableHTTP session manager. + + This is exposed to enable advanced use cases like mounting multiple + FastMCP servers in a single FastAPI application. 
+ + Raises: + RuntimeError: If called before streamable_http_app() has been called. + """ + if self._session_manager is None: + raise RuntimeError( + "Session manager can only be accessed after" + "calling streamable_http_app()." + "The session manager is created lazily" + "to avoid unnecessary initialization." + ) + return self._session_manager + + def run( + self, + transport: Literal["stdio", "sse", "streamable-http"] = "stdio", + mount_path: str | None = None, + ) -> None: + """Run the FastMCP server. Note this is a synchronous function. + + Args: + transport: Transport protocol to use ("stdio", "sse", or "streamable-http") + mount_path: Optional mount path for SSE transport + """ + TRANSPORTS = Literal["stdio", "sse", "streamable-http"] + if transport not in TRANSPORTS.__args__: # type: ignore + raise ValueError(f"Unknown transport: {transport}") + + match transport: + case "stdio": + anyio.run(self.run_stdio_async) + case "sse": + anyio.run(lambda: self.run_sse_async(mount_path)) + case "streamable-http": + anyio.run(self.run_streamable_http_async) + + def _setup_handlers(self) -> None: + """Set up core MCP protocol handlers.""" + self._mcp_server.list_tools()(self.list_tools) + # Note: we disable the lowlevel server's input validation. + # FastMCP does ad hoc conversion of incoming data before validating - + # for now we preserve this for backwards compatibility. 
+ self._mcp_server.call_tool(validate_input=False)(self.call_tool) + self._mcp_server.list_resources()(self.list_resources) + self._mcp_server.read_resource()(self.read_resource) + self._mcp_server.list_prompts()(self.list_prompts) + self._mcp_server.get_prompt()(self.get_prompt) + self._mcp_server.list_resource_templates()(self.list_resource_templates) + + async def list_tools(self) -> list[MCPTool]: + """List all available tools.""" + tools = self._tool_manager.list_tools() + return [ + MCPTool( + name=info.name, + title=info.title, + description=info.description, + inputSchema=info.parameters, + outputSchema=info.output_schema, + annotations=info.annotations, + ) + for info in tools + ] + + def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: + """ + Returns a Context object. Note that the context will only be valid + during a request; outside a request, most methods will error. + """ + try: + request_context = self._mcp_server.request_context + except LookupError: + request_context = None + return Context(request_context=request_context, fastmcp=self) + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: + """Call a tool by name with arguments.""" + context = self.get_context() + return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) + + async def list_resources(self) -> list[MCPResource]: + """List all available resources.""" + + resources = self._resource_manager.list_resources() + return [ + MCPResource( + uri=resource.uri, + name=resource.name or "", + title=resource.title, + description=resource.description, + mimeType=resource.mime_type, + ) + for resource in resources + ] + + async def list_resource_templates(self) -> list[MCPResourceTemplate]: + templates = self._resource_manager.list_templates() + return [ + MCPResourceTemplate( + uriTemplate=template.uri_template, + name=template.name, + title=template.title, + 
def add_tool(
    self,
    fn: AnyFunction,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    annotations: ToolAnnotations | None = None,
    structured_output: bool | None = None,
) -> None:
    """Register ``fn`` as a tool on this server.

    The tool function can optionally request a Context object by adding a
    parameter with the Context type annotation. See the @tool decorator for
    examples.

    Args:
        fn: The function to register as a tool
        name: Optional name for the tool (defaults to function name)
        title: Optional human-readable title for the tool
        description: Optional description of what the tool does
        annotations: Optional ToolAnnotations providing additional tool information
        structured_output: Controls whether the tool's output is structured or unstructured
            - If None, auto-detects based on the function's return type annotation
            - If True, unconditionally creates a structured tool (return type annotation permitting)
            - If False, unconditionally creates an unstructured tool
    """
    # Thin delegation: all validation and schema generation happens in the
    # tool manager.
    registration = {
        "name": name,
        "title": title,
        "description": description,
        "annotations": annotations,
        "structured_output": structured_output,
    }
    self._tool_manager.add_tool(fn, **registration)
def tool(
    self,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    annotations: ToolAnnotations | None = None,
    structured_output: bool | None = None,
) -> Callable[[AnyFunction], AnyFunction]:
    """Decorator form of :meth:`add_tool`.

    The decorated function is registered as a tool and returned unchanged.
    Tools may request a Context object by annotating a parameter with the
    Context type; the context provides MCP capabilities such as logging,
    progress reporting, and resource access.

    Args:
        name: Optional name for the tool (defaults to function name)
        title: Optional human-readable title for the tool
        description: Optional description of what the tool does
        annotations: Optional ToolAnnotations providing additional tool information
        structured_output: Controls whether the tool's output is structured or unstructured
            - If None, auto-detects based on the function's return type annotation
            - If True, unconditionally creates a structured tool (return type annotation permitting)
            - If False, unconditionally creates an unstructured tool

    Example:
        @server.tool()
        def my_tool(x: int) -> str:
            return str(x)

        @server.tool()
        def tool_with_context(x: int, ctx: Context) -> str:
            ctx.info(f"Processing {x}")
            return str(x)

        @server.tool()
        async def async_tool(x: int, context: Context) -> str:
            await context.report_progress(50, 100)
            return str(x)
    """
    # Guard against `@tool` (no parentheses): `name` would then be the
    # decorated function itself.
    if callable(name):
        raise TypeError(
            "The @tool decorator was used incorrectly. Did you forget to call it? Use @tool() instead of @tool"
        )

    def register(fn: AnyFunction) -> AnyFunction:
        self.add_tool(
            fn,
            name=name,
            title=title,
            description=description,
            annotations=annotations,
            structured_output=structured_output,
        )
        return fn

    return register
def resource(
    self,
    uri: str,
    *,
    name: str | None = None,
    title: str | None = None,
    description: str | None = None,
    mime_type: str | None = None,
) -> Callable[[AnyFunction], AnyFunction]:
    """Decorator to register a function as a resource.

    The function will be called when the resource is read to generate its content.
    The function can return:
    - str for text content
    - bytes for binary content
    - other types will be converted to JSON

    If the URI contains parameters (e.g. "resource://{param}") or the function
    has parameters, it will be registered as a template resource.

    Args:
        uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}")
        name: Optional name for the resource
        title: Optional human-readable title for the resource
        description: Optional description of the resource
        mime_type: Optional MIME type for the resource

    Raises:
        TypeError: If used as ``@resource`` without calling it.
        ValueError: If URI template parameters do not match the function parameters.

    Example:
        @server.resource("resource://my-resource")
        def get_data() -> str:
            return "Hello, world!"

        @server.resource("resource://my-resource")
        async def get_data() -> str:
            data = await fetch_data()
            return f"Hello, world! {data}"

        @server.resource("resource://{city}/weather")
        def get_weather(city: str) -> str:
            return f"Weather for {city}"

        @server.resource("resource://{city}/weather")
        async def get_weather(city: str) -> str:
            data = await fetch_weather(city)
            return f"Weather for {city}: {data}"
    """
    # Check if user passed function directly instead of calling decorator
    if callable(uri):
        raise TypeError(
            "The @resource decorator was used incorrectly. "
            "Did you forget to call it? Use @resource('uri') instead of @resource"
        )

    def decorator(fn: AnyFunction) -> AnyFunction:
        # Check if this should be a template
        has_uri_params = "{" in uri and "}" in uri
        # Inspect the signature once; the original computed it twice.
        signature = inspect.signature(fn)
        has_func_params = bool(signature.parameters)

        if has_uri_params or has_func_params:
            # Validate that URI params match function params
            uri_params = set(re.findall(r"{(\w+)}", uri))
            func_params = set(signature.parameters.keys())

            if uri_params != func_params:
                raise ValueError(
                    f"Mismatch between URI parameters {uri_params} and function parameters {func_params}"
                )

            # Register as template
            self._resource_manager.add_template(
                fn=fn,
                uri_template=uri,
                name=name,
                title=title,
                description=description,
                mime_type=mime_type,
            )
        else:
            # Register as regular resource
            resource = FunctionResource.from_function(
                fn=fn,
                uri=uri,
                name=name,
                title=title,
                description=description,
                mime_type=mime_type,
            )
            self.add_resource(resource)
        return fn

    return decorator
def prompt(
    self, name: str | None = None, title: str | None = None, description: str | None = None
) -> Callable[[AnyFunction], AnyFunction]:
    """Decorator to register a prompt.

    The decorated function is wrapped in a Prompt and registered with the
    prompt manager; the function itself is returned unchanged.

    Args:
        name: Optional name for the prompt (defaults to function name)
        title: Optional human-readable title for the prompt
        description: Optional description of what the prompt does

    Example:
        @server.prompt()
        def analyze_table(table_name: str) -> list[Message]:
            schema = read_table_schema(table_name)
            return [{"role": "user", "content": f"Analyze this schema:\\n{schema}"}]

        @server.prompt()
        async def analyze_file(path: str) -> list[Message]:
            content = await read_file(path)
            return [
                {
                    "role": "user",
                    "content": {
                        "type": "resource",
                        "resource": {"uri": f"file://{path}", "text": content},
                    },
                }
            ]
    """
    # Guard against `@prompt` (no parentheses): `name` would then be the
    # decorated function itself.
    if callable(name):
        raise TypeError(
            "The @prompt decorator was used incorrectly. "
            "Did you forget to call it? Use @prompt() instead of @prompt"
        )

    def register(func: AnyFunction) -> AnyFunction:
        self.add_prompt(Prompt.from_function(func, name=name, title=title, description=description))
        return func

    return register
def custom_route(
    self,
    path: str,
    methods: list[str],
    name: str | None = None,
    include_in_schema: bool = True,
):
    """Decorator registering an extra HTTP endpoint on the FastMCP server.

    Allows adding arbitrary HTTP endpoints outside the standard MCP protocol,
    which can be useful for OAuth callbacks, health checks, or admin APIs.
    The handler must be an async function taking a Starlette Request and
    returning a Response.

    Args:
        path: URL path for the route (e.g., "/oauth/callback")
        methods: List of HTTP methods to support (e.g., ["GET", "POST"])
        name: Optional name for the route (to reference this route with
            Starlette's reverse URL lookup feature)
        include_in_schema: Whether to include in OpenAPI schema, defaults to True

    Example:
        @server.custom_route("/health", methods=["GET"])
        async def health_check(request: Request) -> Response:
            return JSONResponse({"status": "ok"})
    """

    def register(
        func: Callable[[Request], Awaitable[Response]],
    ) -> Callable[[Request], Awaitable[Response]]:
        # Build the Starlette route up front, then queue it; custom routes
        # are mounted last so they have the lowest matching precedence.
        route = Route(
            path,
            endpoint=func,
            methods=methods,
            name=name,
            include_in_schema=include_in_schema,
        )
        self._custom_starlette_routes.append(route)
        return func

    return register
+ + Args: + mount_path: The mount path (e.g. "/github" or "/") + endpoint: The endpoint path (e.g. "/messages/") + + Returns: + Normalized path (e.g. "/github/messages/") + """ + # Special case: root path + if mount_path == "/": + return endpoint + + # Remove trailing slash from mount path + if mount_path.endswith("/"): + mount_path = mount_path[:-1] + + # Ensure endpoint starts with slash + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + # Combine paths + return mount_path + endpoint + + def sse_app(self, mount_path: str | None = None) -> Starlette: + """Return an instance of the SSE server app.""" + from starlette.middleware import Middleware + from starlette.routing import Mount, Route + + # Update mount_path in settings if provided + if mount_path is not None: + self.settings.mount_path = mount_path + + # Create normalized endpoint considering the mount path + normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path) + + # Set up auth context and dependencies + + sse = SseServerTransport( + normalized_message_endpoint, + security_settings=self.settings.transport_security, + ) + + async def handle_sse(scope: Scope, receive: Receive, send: Send): + # Add client ID from auth context into request context if available + + async with sse.connect_sse( + scope, + receive, + send, + ) as streams: + await self._mcp_server.run( + streams[0], + streams[1], + self._mcp_server.create_initialization_options(), + ) + return Response() + + # Create routes + routes: list[Route | Mount] = [] + middleware: list[Middleware] = [] + required_scopes = [] + + # Set up auth if configured + if self.settings.auth: + required_scopes = self.settings.auth.required_scopes or [] + + # Add auth middleware if token verifier is available + if self._token_verifier: + middleware = [ + # extract auth info from request (but do not require it) + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend(self._token_verifier), + 
def sse_app(self, mount_path: str | None = None) -> Starlette:
    """Return an instance of the SSE server app.

    Builds a Starlette application serving MCP over Server-Sent Events:
    a GET endpoint for the event stream, a POST endpoint for
    client-to-server messages, plus (when configured) OAuth middleware,
    auth routes, protected-resource metadata, and any custom routes.

    Args:
        mount_path: Optional mount path; when provided it overrides
            ``self.settings.mount_path``.
    """
    from starlette.middleware import Middleware
    from starlette.routing import Mount, Route

    # Update mount_path in settings if provided
    if mount_path is not None:
        self.settings.mount_path = mount_path

    # Create normalized endpoint considering the mount path
    normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path)

    # Set up auth context and dependencies

    sse = SseServerTransport(
        normalized_message_endpoint,
        security_settings=self.settings.transport_security,
    )

    async def handle_sse(scope: Scope, receive: Receive, send: Send):
        # Add client ID from auth context into request context if available

        async with sse.connect_sse(
            scope,
            receive,
            send,
        ) as streams:
            # Run the MCP server loop over the paired read/write streams
            # for the lifetime of this SSE connection.
            await self._mcp_server.run(
                streams[0],
                streams[1],
                self._mcp_server.create_initialization_options(),
            )
        return Response()

    # Create routes
    routes: list[Route | Mount] = []
    middleware: list[Middleware] = []
    required_scopes = []

    # Set up auth if configured
    if self.settings.auth:
        required_scopes = self.settings.auth.required_scopes or []

        # Add auth middleware if token verifier is available
        if self._token_verifier:
            middleware = [
                # extract auth info from request (but do not require it)
                Middleware(
                    AuthenticationMiddleware,
                    backend=BearerAuthBackend(self._token_verifier),
                ),
                # Add the auth context middleware to store
                # authenticated user in a contextvar
                Middleware(AuthContextMiddleware),
            ]

        # Add auth endpoints if auth server provider is configured
        if self._auth_server_provider:
            from mcp.server.auth.routes import create_auth_routes

            routes.extend(
                create_auth_routes(
                    provider=self._auth_server_provider,
                    issuer_url=self.settings.auth.issuer_url,
                    service_documentation_url=self.settings.auth.service_documentation_url,
                    client_registration_options=self.settings.auth.client_registration_options,
                    revocation_options=self.settings.auth.revocation_options,
                )
            )

    # When auth is configured, require authentication
    if self._token_verifier:
        # Determine resource metadata URL
        resource_metadata_url = None
        if self.settings.auth and self.settings.auth.resource_server_url:
            from pydantic import AnyHttpUrl

            resource_metadata_url = AnyHttpUrl(
                str(self.settings.auth.resource_server_url).rstrip("/") + "/.well-known/oauth-protected-resource"
            )

        # Auth is enabled, wrap the endpoints with RequireAuthMiddleware
        routes.append(
            Route(
                self.settings.sse_path,
                endpoint=RequireAuthMiddleware(handle_sse, required_scopes, resource_metadata_url),
                methods=["GET"],
            )
        )
        routes.append(
            Mount(
                self.settings.message_path,
                app=RequireAuthMiddleware(sse.handle_post_message, required_scopes, resource_metadata_url),
            )
        )
    else:
        # Auth is disabled, no need for RequireAuthMiddleware
        # Since handle_sse is an ASGI app, we need to create a compatible endpoint
        async def sse_endpoint(request: Request) -> Response:
            # Convert the Starlette request to ASGI parameters
            return await handle_sse(request.scope, request.receive, request._send)  # type: ignore[reportPrivateUsage]

        routes.append(
            Route(
                self.settings.sse_path,
                endpoint=sse_endpoint,
                methods=["GET"],
            )
        )
        routes.append(
            Mount(
                self.settings.message_path,
                app=sse.handle_post_message,
            )
        )
    # Add protected resource metadata endpoint if configured as RS
    if self.settings.auth and self.settings.auth.resource_server_url:
        from mcp.server.auth.routes import create_protected_resource_routes

        routes.extend(
            create_protected_resource_routes(
                resource_url=self.settings.auth.resource_server_url,
                authorization_servers=[self.settings.auth.issuer_url],
                scopes_supported=self.settings.auth.required_scopes,
            )
        )

    # mount these routes last, so they have the lowest route matching precedence
    routes.extend(self._custom_starlette_routes)

    # Create Starlette app with routes and middleware
    return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware)
def streamable_http_app(self) -> Starlette:
    """Return an instance of the StreamableHTTP server app.

    Lazily creates the StreamableHTTPSessionManager on first call, then
    assembles a Starlette app exposing the streamable-HTTP MCP endpoint,
    optional OAuth middleware/routes, protected-resource metadata, and
    any custom routes.
    """
    from starlette.middleware import Middleware

    # Create session manager on first call (lazy initialization)
    if self._session_manager is None:
        self._session_manager = StreamableHTTPSessionManager(
            app=self._mcp_server,
            event_store=self._event_store,
            json_response=self.settings.json_response,
            stateless=self.settings.stateless_http,  # Use the stateless setting
            security_settings=self.settings.transport_security,
        )

    # Create the ASGI handler
    streamable_http_app = StreamableHTTPASGIApp(self._session_manager)

    # Create routes
    routes: list[Route | Mount] = []
    middleware: list[Middleware] = []
    required_scopes = []

    # Set up auth if configured
    if self.settings.auth:
        required_scopes = self.settings.auth.required_scopes or []

        # Add auth middleware if token verifier is available
        if self._token_verifier:
            # First middleware extracts bearer auth info; second stores the
            # authenticated user in a contextvar for handlers to read.
            middleware = [
                Middleware(
                    AuthenticationMiddleware,
                    backend=BearerAuthBackend(self._token_verifier),
                ),
                Middleware(AuthContextMiddleware),
            ]

        # Add auth endpoints if auth server provider is configured
        if self._auth_server_provider:
            from mcp.server.auth.routes import create_auth_routes

            routes.extend(
                create_auth_routes(
                    provider=self._auth_server_provider,
                    issuer_url=self.settings.auth.issuer_url,
                    service_documentation_url=self.settings.auth.service_documentation_url,
                    client_registration_options=self.settings.auth.client_registration_options,
                    revocation_options=self.settings.auth.revocation_options,
                )
            )

    # Set up routes with or without auth
    if self._token_verifier:
        # Determine resource metadata URL
        resource_metadata_url = None
        if self.settings.auth and self.settings.auth.resource_server_url:
            from pydantic import AnyHttpUrl

            resource_metadata_url = AnyHttpUrl(
                str(self.settings.auth.resource_server_url).rstrip("/") + "/.well-known/oauth-protected-resource"
            )

        routes.append(
            Route(
                self.settings.streamable_http_path,
                endpoint=RequireAuthMiddleware(streamable_http_app, required_scopes, resource_metadata_url),
            )
        )
    else:
        # Auth is disabled, no wrapper needed
        routes.append(
            Route(
                self.settings.streamable_http_path,
                endpoint=streamable_http_app,
            )
        )

    # Add protected resource metadata endpoint if configured as RS
    if self.settings.auth and self.settings.auth.resource_server_url:
        from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler
        from mcp.server.auth.routes import cors_middleware
        from mcp.shared.auth import ProtectedResourceMetadata

        protected_resource_metadata = ProtectedResourceMetadata(
            resource=self.settings.auth.resource_server_url,
            authorization_servers=[self.settings.auth.issuer_url],
            scopes_supported=self.settings.auth.required_scopes,
        )
        routes.append(
            Route(
                "/.well-known/oauth-protected-resource",
                endpoint=cors_middleware(
                    ProtectedResourceMetadataHandler(protected_resource_metadata).handle,
                    ["GET", "OPTIONS"],
                ),
                methods=["GET", "OPTIONS"],
            )
        )

    # Custom routes mount last, so they have the lowest matching precedence.
    routes.extend(self._custom_starlette_routes)

    # The lifespan ties the session manager's run loop to the app's lifetime.
    return Starlette(
        debug=self.settings.debug,
        routes=routes,
        middleware=middleware,
        lifespan=lambda app: self.session_manager.run(),
    )
class StreamableHTTPASGIApp:
    """Thin ASGI adapter over a StreamableHTTPSessionManager.

    Every incoming request is delegated verbatim to the session manager,
    which implements the actual Streamable HTTP transport logic.
    """

    def __init__(self, session_manager: StreamableHTTPSessionManager):
        # Keep a reference so __call__ can delegate on each request.
        self.session_manager = session_manager

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Pure delegation; no per-request state lives on this object.
        await self.session_manager.handle_request(scope, receive, send)
@property
def fastmcp(self) -> FastMCP:
    """The FastMCP server that created this context.

    Raises:
        ValueError: If accessed outside of a request (no server attached).
    """
    server = self._fastmcp
    if server is None:
        raise ValueError("Context is not available outside of a request")
    return server
async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]:
    """Read a resource by URI.

    Args:
        uri: Resource URI to read

    Returns:
        The resource content as either text or bytes

    Raises:
        ValueError: If called outside of a request (no FastMCP server attached).
    """
    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, and the sibling accessors (`fastmcp`,
    # `request_context`) already raise ValueError for this condition.
    if self._fastmcp is None:
        raise ValueError("Context is not available outside of a request")
    return await self._fastmcp.read_resource(uri)
async def log(
    self,
    level: Literal["debug", "info", "warning", "error"],
    message: str,
    *,
    logger_name: str | None = None,
) -> None:
    """Send a log message to the client.

    The message is delivered through the current request's session and is
    tagged with this request's ID.

    Args:
        level: Log level (debug, info, warning, error)
        message: Log message
        logger_name: Optional logger name

    Note:
        The docstring previously advertised a ``**extra`` parameter that the
        signature does not accept; extra keyword arguments are not supported.
    """
    await self.request_context.session.send_log_message(
        level=level,
        data=message,
        logger=logger_name,
        related_request_id=self.request_id,
    )
class Tool(BaseModel):
    """Internal tool registration info.

    Wraps a user function together with its derived metadata: JSON schemas
    for arguments and output, async-ness, and the name of the parameter (if
    any) that should receive the request Context.
    """

    fn: Callable[..., Any] = Field(exclude=True)
    name: str = Field(description="Name of the tool")
    title: str | None = Field(None, description="Human-readable title of the tool")
    description: str = Field(description="Description of what the tool does")
    parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
    fn_metadata: FuncMetadata = Field(
        description="Metadata about the function including a pydantic model for tool arguments"
    )
    is_async: bool = Field(description="Whether the tool is async")
    context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
    annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool")

    @cached_property
    def output_schema(self) -> dict[str, Any] | None:
        """JSON schema for the tool's structured output, or None if unstructured."""
        return self.fn_metadata.output_schema

    @classmethod
    def from_function(
        cls,
        fn: Callable[..., Any],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        context_kwarg: str | None = None,
        annotations: ToolAnnotations | None = None,
        structured_output: bool | None = None,
    ) -> Tool:
        """Create a Tool from a function.

        Args:
            fn: The function to wrap.
            name: Optional tool name; defaults to ``fn.__name__``.
            title: Optional human-readable title.
            description: Optional description; defaults to the function docstring.
            context_kwarg: Name of the parameter receiving the Context; when
                None, it is auto-detected from parameter annotations.
            annotations: Optional ToolAnnotations.
            structured_output: Forces structured/unstructured output; None
                auto-detects from the return annotation.

        Raises:
            ValueError: If ``fn`` is a lambda and no explicit name is given.
        """
        # Deferred import to avoid a circular dependency with server.py.
        from mcp.server.fastmcp.server import Context

        func_name = name or fn.__name__

        # Fix: a lambda's __name__ is "<lambda>"; comparing against "" could
        # never detect one, so unnamed lambdas slipped through.
        if func_name == "<lambda>":
            raise ValueError("You must provide a name for lambda functions")

        func_doc = description or fn.__doc__ or ""
        is_async = _is_async_callable(fn)

        if context_kwarg is None:
            sig = inspect.signature(fn)
            for param_name, param in sig.parameters.items():
                if get_origin(param.annotation) is not None:
                    continue
                # Guard: issubclass() raises TypeError for non-class
                # annotations (e.g. `None`, or string annotations produced by
                # `from __future__ import annotations`), so only test genuine
                # classes.
                if isinstance(param.annotation, type) and issubclass(param.annotation, Context):
                    context_kwarg = param_name
                    break

        func_arg_metadata = func_metadata(
            fn,
            skip_names=[context_kwarg] if context_kwarg is not None else [],
            structured_output=structured_output,
        )
        parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True)

        return cls(
            fn=fn,
            name=func_name,
            title=title,
            description=func_doc,
            parameters=parameters,
            fn_metadata=func_arg_metadata,
            is_async=is_async,
            context_kwarg=context_kwarg,
            annotations=annotations,
        )

    async def run(
        self,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
        convert_result: bool = False,
    ) -> Any:
        """Run the tool with arguments.

        Args:
            arguments: Raw argument dict; validated against the arg model.
            context: Optional request Context, injected under context_kwarg.
            convert_result: Whether to convert the result to MCP content.

        Raises:
            ToolError: Wrapping any exception raised during execution.
        """
        try:
            result = await self.fn_metadata.call_fn_with_arg_validation(
                self.fn,
                self.is_async,
                arguments,
                {self.context_kwarg: context} if self.context_kwarg is not None else None,
            )

            if convert_result:
                result = self.fn_metadata.convert_result(result)

            return result
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ToolError(f"Error executing tool {self.name}: {e}") from e
class ToolManager:
    """Manages FastMCP tools.

    Maintains a name-keyed registry of Tool objects and dispatches calls to
    them, optionally warning when a registration would overwrite an existing
    name.
    """

    def __init__(
        self,
        warn_on_duplicate_tools: bool = True,
        *,
        tools: list[Tool] | None = None,
    ):
        self._tools: dict[str, Tool] = {}
        for tool in tools or []:
            if warn_on_duplicate_tools and tool.name in self._tools:
                logger.warning(f"Tool already exists: {tool.name}")
            # Later entries win, matching plain dict insertion semantics.
            self._tools[tool.name] = tool

        self.warn_on_duplicate_tools = warn_on_duplicate_tools

    def get_tool(self, name: str) -> Tool | None:
        """Get tool by name, or None if not registered."""
        return self._tools.get(name)

    def list_tools(self) -> list[Tool]:
        """List all registered tools."""
        return [*self._tools.values()]

    def add_tool(
        self,
        fn: Callable[..., Any],
        name: str | None = None,
        title: str | None = None,
        description: str | None = None,
        annotations: ToolAnnotations | None = None,
        structured_output: bool | None = None,
    ) -> Tool:
        """Add a tool to the server.

        If a tool with the same name already exists, the existing tool is
        returned unchanged (optionally logging a warning).
        """
        candidate = Tool.from_function(
            fn,
            name=name,
            title=title,
            description=description,
            annotations=annotations,
            structured_output=structured_output,
        )
        existing = self._tools.get(candidate.name)
        if existing is not None:
            if self.warn_on_duplicate_tools:
                logger.warning(f"Tool already exists: {candidate.name}")
            return existing
        self._tools[candidate.name] = candidate
        return candidate

    async def call_tool(
        self,
        name: str,
        arguments: dict[str, Any],
        context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
        convert_result: bool = False,
    ) -> Any:
        """Call a tool by name with arguments.

        Raises:
            ToolError: If no tool with ``name`` is registered.
        """
        tool = self.get_tool(name)
        if tool is None:
            raise ToolError(f"Unknown tool: {name}")

        return await tool.run(arguments, context=context, convert_result=convert_result)
raises exceptions instead of emitting warnings. + + This is used to detect non-serializable types during schema generation. + """ + + def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None: + # Raise an exception instead of emitting a warning + raise ValueError(f"JSON schema warning: {kind} - {detail}") + + +class ArgModelBase(BaseModel): + """A model representing the arguments to a function.""" + + def model_dump_one_level(self) -> dict[str, Any]: + """Return a dict of the model's fields, one level deep. + + That is, sub-models etc are not dumped - they are kept as pydantic models. + """ + kwargs: dict[str, Any] = {} + for field_name, field_info in self.__class__.model_fields.items(): + value = getattr(self, field_name) + # Use the alias if it exists, otherwise use the field name + output_name = field_info.alias if field_info.alias else field_name + kwargs[output_name] = value + return kwargs + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + +class FuncMetadata(BaseModel): + arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)] + output_schema: dict[str, Any] | None = None + output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None + wrap_output: bool = False + + async def call_fn_with_arg_validation( + self, + fn: Callable[..., Any | Awaitable[Any]], + fn_is_async: bool, + arguments_to_validate: dict[str, Any], + arguments_to_pass_directly: dict[str, Any] | None, + ) -> Any: + """Call the given function with arguments validated and injected. + + Arguments are first attempted to be parsed from JSON, then validated against + the argument model, before being passed to the function. 
+ """ + arguments_pre_parsed = self.pre_parse_json(arguments_to_validate) + arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed) + arguments_parsed_dict = arguments_parsed_model.model_dump_one_level() + + arguments_parsed_dict |= arguments_to_pass_directly or {} + + if fn_is_async: + return await fn(**arguments_parsed_dict) + else: + return fn(**arguments_parsed_dict) + + def convert_result(self, result: Any) -> Any: + """ + Convert the result of a function call to the appropriate format for + the lowlevel server tool call handler: + + - If output_model is None, return the unstructured content directly. + - If output_model is not None, convert the result to structured output format + (dict[str, Any]) and return both unstructured and structured content. + + Note: we return unstructured content here **even though the lowlevel server + tool call handler provides generic backwards compatibility serialization of + structured content**. This is for FastMCP backwards compatibility: we need to + retain FastMCP's ad hoc conversion logic for constructing unstructured output + from function return values, whereas the lowlevel server simply serializes + the structured output. + """ + unstructured_content = _convert_to_content(result) + + if self.output_schema is None: + return unstructured_content + else: + if self.wrap_output: + result = {"result": result} + + assert self.output_model is not None, "Output model must be set if output schema is defined" + validated = self.output_model.model_validate(result) + structured_content = validated.model_dump(mode="json", by_alias=True) + + return (unstructured_content, structured_content) + + def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: + """Pre-parse data from JSON. + + Return a dict with same keys as input but with values parsed from JSON + if appropriate. + + This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside + a string rather than an actual list. 
Claude desktop is prone to this - in fact + it seems incapable of NOT doing this. For sub-models, it tends to pass + dicts (JSON objects) as JSON strings, which can be pre-parsed here. + """ + new_data = data.copy() # Shallow copy + + # Build a mapping from input keys (including aliases) to field info + key_to_field_info: dict[str, FieldInfo] = {} + for field_name, field_info in self.arg_model.model_fields.items(): + # Map both the field name and its alias (if any) to the field info + key_to_field_info[field_name] = field_info + if field_info.alias: + key_to_field_info[field_info.alias] = field_info + + for data_key in data.keys(): + if data_key not in key_to_field_info: + continue + + field_info = key_to_field_info[data_key] + if isinstance(data[data_key], str) and field_info.annotation is not str: + try: + pre_parsed = json.loads(data[data_key]) + except json.JSONDecodeError: + continue # Not JSON - skip + if isinstance(pre_parsed, str | int | float): + # This is likely that the raw value is e.g. `"hello"` which we + # Should really be parsed as '"hello"' in Python - but if we parse + # it as JSON it'll turn into just 'hello'. So we skip it. + continue + new_data[data_key] = pre_parsed + assert new_data.keys() == data.keys() + return new_data + + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) + + +def func_metadata( + func: Callable[..., Any], + skip_names: Sequence[str] = (), + structured_output: bool | None = None, +) -> FuncMetadata: + """Given a function, return metadata including a pydantic model representing its + signature. + + The use case for this is + ``` + meta = func_metadata(func) + validated_args = meta.arg_model.model_validate(some_raw_data_dict) + return func(**validated_args.model_dump_one_level()) + ``` + + **critically** it also provides pre-parse helper to attempt to parse things from + JSON. + + Args: + func: The function to convert to a pydantic model + skip_names: A list of parameter names to skip. 
def func_metadata(
    func: Callable[..., Any],
    skip_names: Sequence[str] = (),
    structured_output: bool | None = None,
) -> FuncMetadata:
    """Given a function, return metadata including a pydantic model representing its
    signature.

    The use case for this is
    ```
    meta = func_metadata(func)
    validated_args = meta.arg_model.model_validate(some_raw_data_dict)
    return func(**validated_args.model_dump_one_level())
    ```

    **critically** it also provides pre-parse helper to attempt to parse things from
    JSON.

    Args:
        func: The function to convert to a pydantic model
        skip_names: A list of parameter names to skip. These will not be included in
            the model.
        structured_output: Controls whether the tool's output is structured or unstructured
            - If None, auto-detects based on the function's return type annotation
            - If True, unconditionally creates a structured tool (return type annotation permitting)
            - If False, unconditionally creates an unstructured tool

            If structured, creates a Pydantic model for the function's result based on its annotation.
            Supports various return types:
            - BaseModel subclasses (used directly)
            - Primitive types (str, int, float, bool, bytes, None) - wrapped in a
                model with a 'result' field
            - TypedDict - converted to a Pydantic model with same fields
            - Dataclasses and other annotated classes - converted to Pydantic models
            - Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field

    Raises:
        InvalidSignature: if a parameter name starts with '_', or if
            structured_output=True but no usable return annotation exists.

    Returns:
        A FuncMetadata object containing:
        - arg_model: A pydantic model representing the function's arguments
        - output_schema / output_model / wrap_output: structured-output support,
          populated when the return annotation is serializable.
    """
    sig = _get_typed_signature(func)
    params = sig.parameters
    dynamic_pydantic_model_params: dict[str, Any] = {}
    globalns = getattr(func, "__globals__", {})
    for param in params.values():
        if param.name.startswith("_"):
            raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
        if param.name in skip_names:
            continue
        annotation = param.annotation

        # `x: None` / `x: None = None`
        if annotation is None:
            annotation = Annotated[
                None,
                Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined),
            ]

        # Untyped field: accept anything, but advertise "string" in the schema
        # so clients have *some* type to render.
        if annotation is inspect.Parameter.empty:
            annotation = Annotated[
                Any,
                Field(),
                WithJsonSchema({"title": param.name, "type": "string"}),
            ]

        field_info = FieldInfo.from_annotated_attribute(
            _get_typed_annotation(annotation, globalns),
            param.default if param.default is not inspect.Parameter.empty else PydanticUndefined,
        )

        # Check if the parameter name conflicts with BaseModel attributes.
        # This is necessary because Pydantic warns about shadowing parent attributes.
        if hasattr(BaseModel, param.name) and callable(getattr(BaseModel, param.name)):
            # Use an alias to avoid the shadowing warning while keeping the
            # external (wire) name identical to the declared parameter name.
            field_info.alias = param.name
            field_info.validation_alias = param.name
            field_info.serialization_alias = param.name
            # Use a prefixed internal name
            internal_name = f"field_{param.name}"
            dynamic_pydantic_model_params[internal_name] = (field_info.annotation, field_info)
        else:
            dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
        # (a redundant trailing `continue` was removed here — it was the last
        # statement of the loop body and had no effect)

    arguments_model = create_model(
        f"{func.__name__}Arguments",
        **dynamic_pydantic_model_params,
        __base__=ArgModelBase,
    )

    if structured_output is False:
        return FuncMetadata(arg_model=arguments_model)

    # set up structured output support based on return type annotation

    if sig.return_annotation is inspect.Parameter.empty and structured_output is True:
        raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output")

    output_info = FieldInfo.from_annotation(_get_typed_annotation(sig.return_annotation, globalns))
    annotation = output_info.annotation

    output_model, output_schema, wrap_output = _try_create_model_and_schema(annotation, func.__name__, output_info)

    if output_model is None and structured_output is True:
        # Model creation failed or produced warnings - no structured output
        raise InvalidSignature(
            f"Function {func.__name__}: return type {annotation} is not serializable for structured output"
        )

    return FuncMetadata(
        arg_model=arguments_model,
        output_schema=output_schema,
        output_model=output_model,
        wrap_output=wrap_output,
    )
+ model = _create_wrapped_model(func_name, annotation, field_info) + wrap_output = True + + # Handle regular type objects + elif isinstance(annotation, type): + type_annotation: type[Any] = cast(type[Any], annotation) + + # Case 1: BaseModel subclasses (can be used directly) + if issubclass(annotation, BaseModel): + model = annotation + + # Case 2: TypedDict (special dict subclass with __annotations__) + elif hasattr(type_annotation, "__annotations__") and issubclass(annotation, dict): + model = _create_model_from_typeddict(type_annotation) + + # Case 3: Primitive types that need wrapping + elif annotation in (str, int, float, bool, bytes, type(None)): + model = _create_wrapped_model(func_name, annotation, field_info) + wrap_output = True + + # Case 4: Other class types (dataclasses, regular classes with annotations) + else: + type_hints = get_type_hints(type_annotation) + if type_hints: + # Classes with type hints can be converted to Pydantic models + model = _create_model_from_class(type_annotation) + # Classes without type hints are not serializable - model remains None + + # Handle any other types not covered above + else: + # This includes typing constructs that aren't GenericAlias in Python 3.10 + # (e.g., Union, Optional in some Python versions) + model = _create_wrapped_model(func_name, annotation, field_info) + wrap_output = True + + if model: + # If we successfully created a model, try to get its schema + # Use StrictJsonSchema to raise exceptions instead of warnings + try: + schema = model.model_json_schema(schema_generator=StrictJsonSchema) + except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e: + # These are expected errors when a type can't be converted to a Pydantic schema + # TypeError: When Pydantic can't handle the type + # ValueError: When there are issues with the type definition (including our custom warnings) + # SchemaError: When Pydantic can't build a schema + # ValidationError: When validation fails 
+ logger.info(f"Cannot create schema for type {annotation} in {func_name}: {type(e).__name__}: {e}") + return None, None, False + + return model, schema, wrap_output + + return None, None, False + + +def _create_model_from_class(cls: type[Any]) -> type[BaseModel]: + """Create a Pydantic model from an ordinary class. + + The created model will: + - Have the same name as the class + - Have fields with the same names and types as the class's fields + - Include all fields whose type does not include None in the set of required fields + + Precondition: cls must have type hints (i.e., get_type_hints(cls) is non-empty) + """ + type_hints = get_type_hints(cls) + + model_fields: dict[str, Any] = {} + for field_name, field_type in type_hints.items(): + if field_name.startswith("_"): + continue + + default = getattr(cls, field_name, PydanticUndefined) + field_info = FieldInfo.from_annotated_attribute(field_type, default) + model_fields[field_name] = (field_info.annotation, field_info) + + # Create a base class with the config + class BaseWithConfig(BaseModel): + model_config = ConfigDict(from_attributes=True) + + return create_model(cls.__name__, **model_fields, __base__=BaseWithConfig) + + +def _create_model_from_typeddict(td_type: type[Any]) -> type[BaseModel]: + """Create a Pydantic model from a TypedDict. + + The created model will have the same name and fields as the TypedDict. 
+ """ + type_hints = get_type_hints(td_type) + required_keys = getattr(td_type, "__required_keys__", set(type_hints.keys())) + + model_fields: dict[str, Any] = {} + for field_name, field_type in type_hints.items(): + field_info = FieldInfo.from_annotation(field_type) + + if field_name not in required_keys: + # For optional TypedDict fields, set default=None + # This makes them not required in the Pydantic model + # The model should use exclude_unset=True when dumping to get TypedDict semantics + field_info.default = None + + model_fields[field_name] = (field_info.annotation, field_info) + + return create_model(td_type.__name__, **model_fields, __base__=BaseModel) + + +def _create_wrapped_model(func_name: str, annotation: Any, field_info: FieldInfo) -> type[BaseModel]: + """Create a model that wraps a type in a 'result' field. + + This is used for primitive types, generic types like list/dict, etc. + """ + model_name = f"{func_name}Output" + + # Pydantic needs type(None) instead of None for the type annotation + if annotation is None: + annotation = type(None) + + return create_model(model_name, result=(annotation, field_info), __base__=BaseModel) + + +def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: + """Create a RootModel for dict[str, T] types.""" + + class DictModel(RootModel[dict_annotation]): + pass + + # Give it a meaningful name + DictModel.__name__ = f"{func_name}DictOutput" + DictModel.__qualname__ = f"{func_name}DictOutput" + + return DictModel + + +def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: + def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]: + try: + return eval_type_backport(value, globalns, localns), True + except NameError: + return value, False + + if isinstance(annotation, str): + annotation = ForwardRef(annotation) + annotation, status = try_eval_type(annotation, globalns, globalns) + + # This check and raise could perhaps be 
skipped, and we (FastMCP) just call + # model_rebuild right before using it 🤷 + if status is False: + raise InvalidSignature(f"Unable to evaluate type annotation {annotation}") + + return annotation + + +def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: + """Get function signature while evaluating forward references""" + signature = inspect.signature(call) + globalns = getattr(call, "__globals__", {}) + typed_params = [ + inspect.Parameter( + name=param.name, + kind=param.kind, + default=param.default, + annotation=_get_typed_annotation(param.annotation, globalns), + ) + for param in signature.parameters.values() + ] + typed_return = _get_typed_annotation(signature.return_annotation, globalns) + typed_signature = inspect.Signature(typed_params, return_annotation=typed_return) + return typed_signature + + +def _convert_to_content( + result: Any, +) -> Sequence[ContentBlock]: + """ + Convert a result to a sequence of content objects. + + Note: This conversion logic comes from previous versions of FastMCP and is being + retained for purposes of backwards compatibility. It produces different unstructured + output than the lowlevel server tool call handler, which just serializes structured + content verbatim. 
+ """ + if result is None: + return [] + + if isinstance(result, ContentBlock): + return [result] + + if isinstance(result, Image): + return [result.to_image_content()] + + if isinstance(result, Audio): + return [result.to_audio_content()] + + if isinstance(result, list | tuple): + return list( + chain.from_iterable( + _convert_to_content(item) + for item in result # type: ignore + ) + ) + + if not isinstance(result, str): + result = pydantic_core.to_json(result, fallback=str, indent=2).decode() + + return [TextContent(type="text", text=result)] diff --git a/src/mcp/mcp/server/fastmcp/utilities/logging.py b/src/mcp/mcp/server/fastmcp/utilities/logging.py new file mode 100644 index 000000000..091d57e69 --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/utilities/logging.py @@ -0,0 +1,43 @@ +"""Logging utilities for FastMCP.""" + +import logging +from typing import Literal + + +def get_logger(name: str) -> logging.Logger: + """Get a logger nested under MCPnamespace. + + Args: + name: the name of the logger, which will be prefixed with 'FastMCP.' + + Returns: + a configured logger instance + """ + return logging.getLogger(name) + + +def configure_logging( + level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", +) -> None: + """Configure logging for MCP. 
+ + Args: + level: the log level to use + """ + handlers: list[logging.Handler] = [] + try: + from rich.console import Console + from rich.logging import RichHandler + + handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) + except ImportError: + pass + + if not handlers: + handlers.append(logging.StreamHandler()) + + logging.basicConfig( + level=level, + format="%(message)s", + handlers=handlers, + ) diff --git a/src/mcp/mcp/server/fastmcp/utilities/types.py b/src/mcp/mcp/server/fastmcp/utilities/types.py new file mode 100644 index 000000000..1be6f8274 --- /dev/null +++ b/src/mcp/mcp/server/fastmcp/utilities/types.py @@ -0,0 +1,101 @@ +"""Common types used across FastMCP.""" + +import base64 +from pathlib import Path + +from mcp.types import AudioContent, ImageContent + + +class Image: + """Helper class for returning images from tools.""" + + def __init__( + self, + path: str | Path | None = None, + data: bytes | None = None, + format: str | None = None, + ): + if path is None and data is None: + raise ValueError("Either path or data must be provided") + if path is not None and data is not None: + raise ValueError("Only one of path or data can be provided") + + self.path = Path(path) if path else None + self.data = data + self._format = format + self._mime_type = self._get_mime_type() + + def _get_mime_type(self) -> str: + """Get MIME type from format or guess from file extension.""" + if self._format: + return f"image/{self._format.lower()}" + + if self.path: + suffix = self.path.suffix.lower() + return { + ".png": "image/png", + ".jpg": "image/jpeg", + ".jpeg": "image/jpeg", + ".gif": "image/gif", + ".webp": "image/webp", + }.get(suffix, "application/octet-stream") + return "image/png" # default for raw binary data + + def to_image_content(self) -> ImageContent: + """Convert to MCP ImageContent.""" + if self.path: + with open(self.path, "rb") as f: + data = base64.b64encode(f.read()).decode() + elif self.data is not None: + data = 
base64.b64encode(self.data).decode() + else: + raise ValueError("No image data available") + + return ImageContent(type="image", data=data, mimeType=self._mime_type) + + +class Audio: + """Helper class for returning audio from tools.""" + + def __init__( + self, + path: str | Path | None = None, + data: bytes | None = None, + format: str | None = None, + ): + if not bool(path) ^ bool(data): + raise ValueError("Either path or data can be provided") + + self.path = Path(path) if path else None + self.data = data + self._format = format + self._mime_type = self._get_mime_type() + + def _get_mime_type(self) -> str: + """Get MIME type from format or guess from file extension.""" + if self._format: + return f"audio/{self._format.lower()}" + + if self.path: + suffix = self.path.suffix.lower() + return { + ".wav": "audio/wav", + ".mp3": "audio/mpeg", + ".ogg": "audio/ogg", + ".flac": "audio/flac", + ".aac": "audio/aac", + ".m4a": "audio/mp4", + }.get(suffix, "application/octet-stream") + return "audio/wav" # default for raw binary data + + def to_audio_content(self) -> AudioContent: + """Convert to MCP AudioContent.""" + if self.path: + with open(self.path, "rb") as f: + data = base64.b64encode(f.read()).decode() + elif self.data is not None: + data = base64.b64encode(self.data).decode() + else: + raise ValueError("No audio data available") + + return AudioContent(type="audio", data=data, mimeType=self._mime_type) diff --git a/src/mcp/mcp/server/lowlevel/__init__.py b/src/mcp/mcp/server/lowlevel/__init__.py new file mode 100644 index 000000000..66df38991 --- /dev/null +++ b/src/mcp/mcp/server/lowlevel/__init__.py @@ -0,0 +1,3 @@ +from .server import NotificationOptions, Server + +__all__ = ["Server", "NotificationOptions"] diff --git a/src/mcp/mcp/server/lowlevel/helper_types.py b/src/mcp/mcp/server/lowlevel/helper_types.py new file mode 100644 index 000000000..3d09b2505 --- /dev/null +++ b/src/mcp/mcp/server/lowlevel/helper_types.py @@ -0,0 +1,9 @@ +from dataclasses 
import dataclass + + +@dataclass +class ReadResourceContents: + """Contents returned from a read_resource call.""" + + content: str | bytes + mime_type: str | None = None diff --git a/src/mcp/mcp/server/lowlevel/server.py b/src/mcp/mcp/server/lowlevel/server.py new file mode 100644 index 000000000..8c459383c --- /dev/null +++ b/src/mcp/mcp/server/lowlevel/server.py @@ -0,0 +1,686 @@ +""" +MCP Server Module + +This module provides a framework for creating an MCP (Model Context Protocol) server. +It allows you to easily define and handle various types of requests and notifications +in an asynchronous manner. + +Usage: +1. Create a Server instance: + server = Server("your_server_name") + +2. Define request handlers using decorators: + @server.list_prompts() + async def handle_list_prompts() -> list[types.Prompt]: + # Implementation + + @server.get_prompt() + async def handle_get_prompt( + name: str, arguments: dict[str, str] | None + ) -> types.GetPromptResult: + # Implementation + + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + # Implementation + + @server.call_tool() + async def handle_call_tool( + name: str, arguments: dict | None + ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]: + # Implementation + + @server.list_resource_templates() + async def handle_list_resource_templates() -> list[types.ResourceTemplate]: + # Implementation + +3. Define notification handlers if needed: + @server.progress_notification() + async def handle_progress( + progress_token: str | int, progress: float, total: float | None, + message: str | None + ) -> None: + # Implementation + +4. 
Run the server: + async def main(): + async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + InitializationOptions( + server_name="your_server_name", + server_version="your_version", + capabilities=server.get_capabilities( + notification_options=NotificationOptions(), + experimental_capabilities={}, + ), + ), + ) + + asyncio.run(main()) + +The Server class provides methods to register handlers for various MCP requests and +notifications. It automatically manages the request context and handles incoming +messages from the client. +""" + +from __future__ import annotations as _annotations + +import contextvars +import json +import logging +import warnings +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable +from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from typing import Any, Generic, TypeAlias, cast + +import anyio +import jsonschema +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import AnyUrl +from typing_extensions import TypeVar + +import mcp.types as types +from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.server.stdio import stdio_server as stdio_server +from mcp.shared.context import RequestContext +from mcp.shared.exceptions import McpError +from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.session import RequestResponder + +logger = logging.getLogger(__name__) + +LifespanResultT = TypeVar("LifespanResultT", default=Any) +RequestT = TypeVar("RequestT", default=Any) + +# type aliases for tool call results +StructuredContent: TypeAlias = dict[str, Any] +UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] +CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] + +# This will be 
# properly typed in each Server instance's context
# Global per-request context. Handlers retrieve the active RequestContext via
# Server.request_context; it is set around each dispatched request.
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx")


class NotificationOptions:
    """Flags declaring which list-changed notifications this server emits."""

    def __init__(
        self,
        prompts_changed: bool = False,
        resources_changed: bool = False,
        tools_changed: bool = False,
    ):
        self.prompts_changed = prompts_changed
        self.resources_changed = resources_changed
        self.tools_changed = tools_changed


@asynccontextmanager
async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]:
    """Default lifespan context manager that does nothing.

    Args:
        _: The server instance this lifespan is managing (unused).

    Yields:
        An empty context object
    """
    yield {}


class Server(Generic[LifespanResultT, RequestT]):
    """Low-level MCP server.

    Handlers are registered with the decorator factories below
    (``list_prompts``, ``call_tool``, ...); ``run`` then dispatches incoming
    typed requests/notifications from a ServerSession to those handlers.
    """

    def __init__(
        self,
        name: str,
        version: str | None = None,
        instructions: str | None = None,
        lifespan: Callable[
            [Server[LifespanResultT, RequestT]],
            AbstractAsyncContextManager[LifespanResultT],
        ] = lifespan,
    ):
        self.name = name
        self.version = version
        self.instructions = instructions
        self.lifespan = lifespan
        # Ping is always supported; other handlers are added via decorators.
        self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = {
            types.PingRequest: _ping_handler,
        }
        self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
        # Tool definitions by name, refreshed by the list_tools handler and
        # consulted by call_tool for input/output schema validation.
        self._tool_cache: dict[str, types.Tool] = {}
        logger.debug("Initializing server %r", name)

    def create_initialization_options(
        self,
        notification_options: NotificationOptions | None = None,
        experimental_capabilities: dict[str, dict[str, Any]] | None = None,
    ) -> InitializationOptions:
        """Create initialization options from this server instance."""

        def pkg_version(package: str) -> str:
            # Best effort: fall back to "unknown" when package metadata is
            # unavailable (e.g. running from a source checkout).
            try:
                from importlib.metadata import version

                return version(package)
            except Exception:
                pass

            return "unknown"

        return InitializationOptions(
            server_name=self.name,
            server_version=self.version if self.version else pkg_version("mcp"),
            capabilities=self.get_capabilities(
                notification_options or NotificationOptions(),
                experimental_capabilities or {},
            ),
            instructions=self.instructions,
        )

    def get_capabilities(
        self,
        notification_options: NotificationOptions,
        experimental_capabilities: dict[str, dict[str, Any]],
    ) -> types.ServerCapabilities:
        """Convert existing handlers to a ServerCapabilities object."""
        prompts_capability = None
        resources_capability = None
        tools_capability = None
        logging_capability = None
        completions_capability = None

        # Set prompt capabilities if handler exists
        if types.ListPromptsRequest in self.request_handlers:
            prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed)

        # Set resource capabilities if handler exists
        if types.ListResourcesRequest in self.request_handlers:
            resources_capability = types.ResourcesCapability(
                subscribe=False, listChanged=notification_options.resources_changed
            )

        # Set tool capabilities if handler exists
        if types.ListToolsRequest in self.request_handlers:
            tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed)

        # Set logging capabilities if handler exists
        if types.SetLevelRequest in self.request_handlers:
            logging_capability = types.LoggingCapability()

        # Set completions capabilities if handler exists
        if types.CompleteRequest in self.request_handlers:
            completions_capability = types.CompletionsCapability()

        return types.ServerCapabilities(
            prompts=prompts_capability,
            resources=resources_capability,
            tools=tools_capability,
            logging=logging_capability,
            experimental=experimental_capabilities,
            completions=completions_capability,
        )

    @property
    def request_context(
        self,
    ) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
        """If called outside of a request context, this will raise a LookupError."""
        return request_ctx.get()

    def list_prompts(self):
        """Decorator factory registering the prompts/list handler."""

        def decorator(func: Callable[[], Awaitable[list[types.Prompt]]]):
            logger.debug("Registering handler for PromptListRequest")

            async def handler(_: Any):
                prompts = await func()
                return types.ServerResult(types.ListPromptsResult(prompts=prompts))

            self.request_handlers[types.ListPromptsRequest] = handler
            return func

        return decorator

    def get_prompt(self):
        """Decorator factory registering the prompts/get handler."""

        def decorator(
            func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]],
        ):
            logger.debug("Registering handler for GetPromptRequest")

            async def handler(req: types.GetPromptRequest):
                prompt_get = await func(req.params.name, req.params.arguments)
                return types.ServerResult(prompt_get)

            self.request_handlers[types.GetPromptRequest] = handler
            return func

        return decorator

    def list_resources(self):
        """Decorator factory registering the resources/list handler."""

        def decorator(func: Callable[[], Awaitable[list[types.Resource]]]):
            logger.debug("Registering handler for ListResourcesRequest")

            async def handler(_: Any):
                resources = await func()
                return types.ServerResult(types.ListResourcesResult(resources=resources))

            self.request_handlers[types.ListResourcesRequest] = handler
            return func

        return decorator

    def list_resource_templates(self):
        """Decorator factory registering the resources/templates/list handler."""

        def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]):
            logger.debug("Registering handler for ListResourceTemplatesRequest")

            async def handler(_: Any):
                templates = await func()
                return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates))

            self.request_handlers[types.ListResourceTemplatesRequest] = handler
            return func

        return decorator

    def read_resource(self):
        """Decorator factory registering the resources/read handler.

        The wrapped function may return str, bytes (both deprecated), or an
        iterable of ReadResourceContents.
        """

        def decorator(
            func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]],
        ):
            logger.debug("Registering handler for ReadResourceRequest")

            async def handler(req: types.ReadResourceRequest):
                result = await func(req.params.uri)

                def create_content(data: str | bytes, mime_type: str | None):
                    # str -> text contents; bytes -> base64-encoded blob contents.
                    match data:
                        case str() as data:
                            return types.TextResourceContents(
                                uri=req.params.uri,
                                text=data,
                                mimeType=mime_type or "text/plain",
                            )
                        case bytes() as data:
                            import base64

                            return types.BlobResourceContents(
                                uri=req.params.uri,
                                blob=base64.b64encode(data).decode(),
                                mimeType=mime_type or "application/octet-stream",
                            )

                match result:
                    case str() | bytes() as data:
                        warnings.warn(
                            "Returning str or bytes from read_resource is deprecated. "
                            "Use Iterable[ReadResourceContents] instead.",
                            DeprecationWarning,
                            stacklevel=2,
                        )
                        content = create_content(data, None)
                    case Iterable() as contents:
                        contents_list = [
                            create_content(content_item.content, content_item.mime_type) for content_item in contents
                        ]
                        return types.ServerResult(
                            types.ReadResourceResult(
                                contents=contents_list,
                            )
                        )
                    case _:
                        raise ValueError(f"Unexpected return type from read_resource: {type(result)}")

                # Reached only via the deprecated str/bytes branch above.
                return types.ServerResult(
                    types.ReadResourceResult(
                        contents=[content],
                    )
                )

            self.request_handlers[types.ReadResourceRequest] = handler
            return func

        return decorator

    def set_logging_level(self):
        """Decorator factory registering the logging/setLevel handler."""

        def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]):
            logger.debug("Registering handler for SetLevelRequest")

            async def handler(req: types.SetLevelRequest):
                await func(req.params.level)
                return types.ServerResult(types.EmptyResult())

            self.request_handlers[types.SetLevelRequest] = handler
            return func

        return decorator

    def subscribe_resource(self):
        """Decorator factory registering the resources/subscribe handler."""

        def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
            logger.debug("Registering handler for SubscribeRequest")

            async def handler(req: types.SubscribeRequest):
                await func(req.params.uri)
                return types.ServerResult(types.EmptyResult())

            self.request_handlers[types.SubscribeRequest] = handler
            return func

        return decorator

    def unsubscribe_resource(self):
        """Decorator factory registering the resources/unsubscribe handler."""

        def decorator(func: Callable[[AnyUrl], Awaitable[None]]):
            logger.debug("Registering handler for UnsubscribeRequest")

            async def handler(req: types.UnsubscribeRequest):
                await func(req.params.uri)
                return types.ServerResult(types.EmptyResult())

            self.request_handlers[types.UnsubscribeRequest] = handler
            return func

        return decorator

    def list_tools(self):
        """Decorator factory registering the tools/list handler.

        The handler also refreshes the internal tool cache used by call_tool
        for schema validation.
        """

        def decorator(func: Callable[[], Awaitable[list[types.Tool]]]):
            logger.debug("Registering handler for ListToolsRequest")

            async def handler(_: Any):
                tools = await func()
                # Refresh the tool cache
                self._tool_cache.clear()
                for tool in tools:
                    self._tool_cache[tool.name] = tool
                return types.ServerResult(types.ListToolsResult(tools=tools))

            self.request_handlers[types.ListToolsRequest] = handler
            return func

        return decorator

    def _make_error_result(self, error_message: str) -> types.ServerResult:
        """Create a ServerResult with an error CallToolResult."""
        return types.ServerResult(
            types.CallToolResult(
                content=[types.TextContent(type="text", text=error_message)],
                isError=True,
            )
        )

    async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None:
        """Get tool definition from cache, refreshing if necessary.

        Returns the Tool object if found, None otherwise.
        """
        if tool_name not in self._tool_cache:
            if types.ListToolsRequest in self.request_handlers:
                logger.debug("Tool cache miss for %s, refreshing cache", tool_name)
                # Invoking the list handler repopulates self._tool_cache.
                await self.request_handlers[types.ListToolsRequest](None)

        tool = self._tool_cache.get(tool_name)
        if tool is None:
            logger.warning("Tool '%s' not listed, no validation will be performed", tool_name)

        return tool

    def call_tool(self, *, validate_input: bool = True):
        """Register a tool call handler.

        Args:
            validate_input: If True, validates input against inputSchema. Default is True.

        The handler validates input against inputSchema (if validate_input=True), calls the tool function,
        and builds a CallToolResult with the results:
        - Unstructured content (iterable of ContentBlock): returned in content
        - Structured content (dict): returned in structuredContent, serialized JSON text returned in content
        - Both: returned in content and structuredContent

        If outputSchema is defined, validates structuredContent or errors if missing.
        """

        def decorator(
            func: Callable[
                ...,
                Awaitable[UnstructuredContent | StructuredContent | CombinationContent],
            ],
        ):
            logger.debug("Registering handler for CallToolRequest")

            async def handler(req: types.CallToolRequest):
                try:
                    tool_name = req.params.name
                    arguments = req.params.arguments or {}
                    tool = await self._get_cached_tool_definition(tool_name)

                    # input validation
                    if validate_input and tool:
                        try:
                            jsonschema.validate(instance=arguments, schema=tool.inputSchema)
                        except jsonschema.ValidationError as e:
                            return self._make_error_result(f"Input validation error: {e.message}")

                    # tool call
                    results = await func(tool_name, arguments)

                    # output normalization
                    unstructured_content: UnstructuredContent
                    maybe_structured_content: StructuredContent | None
                    if isinstance(results, tuple) and len(results) == 2:
                        # tool returned both structured and unstructured content
                        unstructured_content, maybe_structured_content = cast(CombinationContent, results)
                    elif isinstance(results, dict):
                        # tool returned structured content only
                        maybe_structured_content = cast(StructuredContent, results)
                        unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))]
                    elif hasattr(results, "__iter__"):
                        # tool returned unstructured content only
                        unstructured_content = cast(UnstructuredContent, results)
                        maybe_structured_content = None
                    else:
                        return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}")

                    # output validation
                    if tool and tool.outputSchema is not None:
                        if maybe_structured_content is None:
                            return self._make_error_result(
                                "Output validation error: outputSchema defined but no structured output returned"
                            )
                        else:
                            try:
                                jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema)
                            except jsonschema.ValidationError as e:
                                return self._make_error_result(f"Output validation error: {e.message}")

                    # result
                    return types.ServerResult(
                        types.CallToolResult(
                            content=list(unstructured_content),
                            structuredContent=maybe_structured_content,
                            isError=False,
                        )
                    )
                except Exception as e:
                    # Tool errors are reported in-band, never raised to the session.
                    return self._make_error_result(str(e))

            self.request_handlers[types.CallToolRequest] = handler
            return func

        return decorator

    def progress_notification(self):
        """Decorator factory registering the notifications/progress handler."""

        def decorator(
            func: Callable[[str | int, float, float | None, str | None], Awaitable[None]],
        ):
            logger.debug("Registering handler for ProgressNotification")

            async def handler(req: types.ProgressNotification):
                await func(
                    req.params.progressToken,
                    req.params.progress,
                    req.params.total,
                    req.params.message,
                )

            self.notification_handlers[types.ProgressNotification] = handler
            return func

        return decorator

    def completion(self):
        """Provides completions for prompts and resource templates"""

        def decorator(
            func: Callable[
                [
                    types.PromptReference | types.ResourceTemplateReference,
                    types.CompletionArgument,
                    types.CompletionContext | None,
                ],
                Awaitable[types.Completion | None],
            ],
        ):
            logger.debug("Registering handler for CompleteRequest")

            async def handler(req: types.CompleteRequest):
                completion = await func(req.params.ref, req.params.argument, req.params.context)
                # None from the handler means "no completions available".
                return types.ServerResult(
                    types.CompleteResult(
                        completion=completion
                        if completion is not None
                        else types.Completion(values=[], total=None, hasMore=None),
                    )
                )

            self.request_handlers[types.CompleteRequest] = handler
            return func

        return decorator

    async def run(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        initialization_options: InitializationOptions,
        # When False, exceptions are returned as messages to the client.
        # When True, exceptions are raised, which will cause the server to shut down
        # but also make tracing exceptions much easier during testing and when using
        # in-process servers.
        raise_exceptions: bool = False,
        # When True, the server is stateless and
        # clients can perform initialization with any node. The client must still follow
        # the initialization lifecycle, but can do so with any available node
        # rather than requiring initialization for each connection.
        stateless: bool = False,
    ):
        """Run the server loop over the given streams until they close."""
        async with AsyncExitStack() as stack:
            lifespan_context = await stack.enter_async_context(self.lifespan(self))
            session = await stack.enter_async_context(
                ServerSession(
                    read_stream,
                    write_stream,
                    initialization_options,
                    stateless=stateless,
                )
            )

            async with anyio.create_task_group() as tg:
                # Each incoming message is handled concurrently in its own task.
                async for message in session.incoming_messages:
                    logger.debug("Received message: %s", message)

                    tg.start_soon(
                        self._handle_message,
                        message,
                        session,
                        lifespan_context,
                        raise_exceptions,
                    )

    async def _handle_message(
        self,
        message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception,
        session: ServerSession,
        lifespan_context: LifespanResultT,
        raise_exceptions: bool = False,
    ):
        """Dispatch one incoming message to the request or notification path."""
        with warnings.catch_warnings(record=True) as w:
            # TODO(Marcelo): We should be checking if message is Exception here.
            match message:  # type: ignore[reportMatchNotExhaustive]
                case RequestResponder(request=types.ClientRequest(root=req)) as responder:
                    with responder:
                        await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
                case types.ClientNotification(root=notify):
                    await self._handle_notification(notify)

            for warning in w:
                logger.info("Warning: %s: %s", warning.category.__name__, warning.message)

    async def _handle_request(
        self,
        message: RequestResponder[types.ClientRequest, types.ServerResult],
        req: Any,
        session: ServerSession,
        lifespan_context: LifespanResultT,
        raise_exceptions: bool,
    ):
        """Look up the handler for *req*, run it inside a request context, and respond."""
        logger.info("Processing request of type %s", type(req).__name__)
        if handler := self.request_handlers.get(type(req)):  # type: ignore
            logger.debug("Dispatching request of type %s", type(req).__name__)

            token = None
            try:
                # Extract request context from message metadata
                request_data = None
                if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata):
                    request_data = message.message_metadata.request_context

                # Set our global state that can be retrieved via
                # app.get_request_context()
                token = request_ctx.set(
                    RequestContext(
                        message.request_id,
                        message.request_meta,
                        session,
                        lifespan_context,
                        request=request_data,
                    )
                )
                response = await handler(req)
            except McpError as err:
                response = err.error
            except anyio.get_cancelled_exc_class():
                logger.info(
                    "Request %s cancelled - duplicate response suppressed",
                    message.request_id,
                )
                return
            except Exception as err:
                if raise_exceptions:
                    raise err
                response = types.ErrorData(code=0, message=str(err), data=None)
            finally:
                # Reset the global state after we are done
                if token is not None:
                    request_ctx.reset(token)

            await message.respond(response)
        else:
            await message.respond(
                types.ErrorData(
                    code=types.METHOD_NOT_FOUND,
                    message="Method not found",
                )
            )

        logger.debug("Response sent")

    async def _handle_notification(self, notify: Any):
        """Dispatch a notification; handler exceptions are logged, never raised."""
        if handler := self.notification_handlers.get(type(notify)):  # type: ignore
            logger.debug("Dispatching notification of type %s", type(notify).__name__)

            try:
                await handler(notify)
            except Exception:
                logger.exception("Uncaught exception in notification handler")


async def _ping_handler(request: types.PingRequest) -> types.ServerResult:
    # Built-in default: every server answers ping with an empty result.
    return types.ServerResult(types.EmptyResult())


# --- original patch file boundary: src/mcp/mcp/server/models.py ---
"""
This module provides simpler types to use with the server for managing prompts
and tools.
"""

from pydantic import BaseModel

from mcp.types import (
    ServerCapabilities,
)


class InitializationOptions(BaseModel):
    # Values advertised to the client in the initialize response.
    server_name: str
    server_version: str
    capabilities: ServerCapabilities
    instructions: str | None = None


# --- original patch file boundary: src/mcp/mcp/server/session.py ---
"""
ServerSession Module

This module provides the ServerSession class, which manages communication between the
server and client in the MCP (Model Context Protocol) framework. It is most commonly
used in MCP servers to interact with the client.
+ +Common usage pattern: +``` + server = Server(name) + + @server.call_tool() + async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: + # Check client capabilities before proceeding + if ctx.session.check_client_capability( + types.ClientCapabilities(experimental={"advanced_tools": dict()}) + ): + # Perform advanced tool operations + result = await perform_advanced_tool_operation(arguments) + else: + # Fall back to basic tool operations + result = await perform_basic_tool_operation(arguments) + + return result + + @server.list_prompts() + async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: + # Access session for any necessary checks or operations + if ctx.session.client_params: + # Customize prompts based on client initialization parameters + return generate_custom_prompts(ctx.session.client_params) + else: + return default_prompts +``` + +The ServerSession class is typically used internally by the Server class and should not +be instantiated directly by users of the MCP framework. 
"""

from enum import Enum
from typing import Any, TypeVar

import anyio
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl

import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
    BaseSession,
    RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS


class InitializationState(Enum):
    # Lifecycle of the MCP initialize handshake.
    NotInitialized = 1
    Initializing = 2
    Initialized = 3


ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")

# Union of everything a server can receive from the incoming-message stream.
ServerRequestResponder = (
    RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
)


class ServerSession(
    BaseSession[
        types.ServerRequest,
        types.ServerNotification,
        types.ServerResult,
        types.ClientRequest,
        types.ClientNotification,
    ]
):
    """Server side of an MCP session: handles the initialize handshake and
    exposes typed helpers for server-to-client requests and notifications."""

    # Handshake progress; class-level defaults, overwritten per instance.
    _initialized: InitializationState = InitializationState.NotInitialized
    # Parameters the client sent in its initialize request, once received.
    _client_params: types.InitializeRequestParams | None = None

    def __init__(
        self,
        read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
        write_stream: MemoryObjectSendStream[SessionMessage],
        init_options: InitializationOptions,
        stateless: bool = False,
    ) -> None:
        super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
        # Stateless sessions skip the handshake gate entirely.
        self._initialization_state = (
            InitializationState.Initialized if stateless else InitializationState.NotInitialized
        )

        self._init_options = init_options
        # Zero-capacity stream: incoming messages are handed directly to the
        # consumer iterating `incoming_messages`.
        self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
            ServerRequestResponder
        ](0)
        self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())

    @property
    def client_params(self) -> types.InitializeRequestParams | None:
        return self._client_params

    def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
        """Check if the client supports a specific capability."""
        if self._client_params is None:
            # No initialize request seen yet; nothing can be assumed.
            return False

        # Get client capabilities from initialization params
        client_caps = self._client_params.capabilities

        # Check each specified capability in the passed in capability object
        if capability.roots is not None:
            if client_caps.roots is None:
                return False
            if capability.roots.listChanged and not client_caps.roots.listChanged:
                return False

        if capability.sampling is not None:
            if client_caps.sampling is None:
                return False

        if capability.elicitation is not None:
            if client_caps.elicitation is None:
                return False

        if capability.experimental is not None:
            if client_caps.experimental is None:
                return False
            # Check each experimental capability
            for exp_key, exp_value in capability.experimental.items():
                if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
                    return False

        return True

    async def _receive_loop(self) -> None:
        # Own the writer for the lifetime of the loop so the reader sees EOF
        # when the loop exits.
        async with self._incoming_message_stream_writer:
            await super()._receive_loop()

    async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
        match responder.request.root:
            case types.InitializeRequest(params=params):
                requested_version = params.protocolVersion
                self._initialization_state = InitializationState.Initializing
                self._client_params = params
                with responder:
                    # Echo the requested protocol version if we support it,
                    # otherwise offer our latest.
                    await responder.respond(
                        types.ServerResult(
                            types.InitializeResult(
                                protocolVersion=requested_version
                                if requested_version in SUPPORTED_PROTOCOL_VERSIONS
                                else types.LATEST_PROTOCOL_VERSION,
                                capabilities=self._init_options.capabilities,
                                serverInfo=types.Implementation(
                                    name=self._init_options.server_name,
                                    version=self._init_options.server_version,
                                ),
                                instructions=self._init_options.instructions,
                            )
                        )
                    )
            case _:
                if self._initialization_state != InitializationState.Initialized:
                    raise RuntimeError("Received request before initialization was complete")

    async def _received_notification(self, notification: types.ClientNotification) -> None:
        # Need this to avoid ASYNC910
        await anyio.lowlevel.checkpoint()
        match notification.root:
            case types.InitializedNotification():
                self._initialization_state = InitializationState.Initialized
            case _:
                if self._initialization_state != InitializationState.Initialized:
                    raise RuntimeError("Received notification before initialization was complete")

    async def send_log_message(
        self,
        level: types.LoggingLevel,
        data: Any,
        logger: str | None = None,
        related_request_id: types.RequestId | None = None,
    ) -> None:
        """Send a log message notification."""
        await self.send_notification(
            types.ServerNotification(
                types.LoggingMessageNotification(
                    method="notifications/message",
                    params=types.LoggingMessageNotificationParams(
                        level=level,
                        data=data,
                        logger=logger,
                    ),
                )
            ),
            related_request_id,
        )

    async def send_resource_updated(self, uri: AnyUrl) -> None:
        """Send a resource updated notification."""
        await self.send_notification(
            types.ServerNotification(
                types.ResourceUpdatedNotification(
                    method="notifications/resources/updated",
                    params=types.ResourceUpdatedNotificationParams(uri=uri),
                )
            )
        )

    async def create_message(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        related_request_id: types.RequestId | None = None,
    ) -> types.CreateMessageResult:
        """Send a sampling/create_message request."""
        return await self.send_request(
            request=types.ServerRequest(
                types.CreateMessageRequest(
                    method="sampling/createMessage",
                    params=types.CreateMessageRequestParams(
                        messages=messages,
                        systemPrompt=system_prompt,
                        includeContext=include_context,
                        temperature=temperature,
                        maxTokens=max_tokens,
                        stopSequences=stop_sequences,
                        metadata=metadata,
                        modelPreferences=model_preferences,
                    ),
                )
            ),
            result_type=types.CreateMessageResult,
            metadata=ServerMessageMetadata(
                related_request_id=related_request_id,
            ),
        )

    async def list_roots(self) -> types.ListRootsResult:
        """Send a roots/list request."""
        return await self.send_request(
            types.ServerRequest(
                types.ListRootsRequest(
                    method="roots/list",
                )
            ),
            types.ListRootsResult,
        )

    async def elicit(
        self,
        message: str,
        requestedSchema: types.ElicitRequestedSchema,
        related_request_id: types.RequestId | None = None,
    ) -> types.ElicitResult:
        """Send an elicitation/create request.

        Args:
            message: The message to present to the user
            requestedSchema: Schema defining the expected response structure

        Returns:
            The client's response
        """
        return await self.send_request(
            types.ServerRequest(
                types.ElicitRequest(
                    method="elicitation/create",
                    params=types.ElicitRequestParams(
                        message=message,
                        requestedSchema=requestedSchema,
                    ),
                )
            ),
            types.ElicitResult,
            metadata=ServerMessageMetadata(related_request_id=related_request_id),
        )

    async def send_ping(self) -> types.EmptyResult:
        """Send a ping request."""
        return await self.send_request(
            types.ServerRequest(
                types.PingRequest(
                    method="ping",
                )
            ),
            types.EmptyResult,
        )

    async def send_progress_notification(
        self,
        progress_token: str | int,
        progress: float,
        total: float | None = None,
        message: str | None = None,
        related_request_id: str | None = None,
    ) -> None:
        """Send a progress notification."""
        await self.send_notification(
            types.ServerNotification(
                types.ProgressNotification(
                    method="notifications/progress",
                    params=types.ProgressNotificationParams(
                        progressToken=progress_token,
                        progress=progress,
                        total=total,
                        message=message,
                    ),
                )
            ),
            related_request_id,
        )

    async def send_resource_list_changed(self) -> None:
        """Send a resource list changed notification."""
        await self.send_notification(
            types.ServerNotification(
                types.ResourceListChangedNotification(
                    method="notifications/resources/list_changed",
                )
            )
        )

    async def send_tool_list_changed(self) -> None:
        """Send a tool list changed notification."""
        await self.send_notification(
            types.ServerNotification(
                types.ToolListChangedNotification(
                    method="notifications/tools/list_changed",
                )
            )
        )

    async def send_prompt_list_changed(self) -> None:
        """Send a prompt list changed notification."""
        await self.send_notification(
            types.ServerNotification(
                types.PromptListChangedNotification(
                    method="notifications/prompts/list_changed",
                )
            )
        )

    async def _handle_incoming(self, req: ServerRequestResponder) -> None:
        # Forward to the consumer-facing stream; blocks until it is taken.
        await self._incoming_message_stream_writer.send(req)

    @property
    def incoming_messages(
        self,
    ) -> MemoryObjectReceiveStream[ServerRequestResponder]:
        return self._incoming_message_stream_reader


# --- original patch file boundary: src/mcp/mcp/server/sse.py ---
"""
SSE Server Transport Module

This module implements a Server-Sent Events (SSE) transport layer for MCP servers.
+ +Example usage: +``` + # Create an SSE transport at an endpoint + sse = SseServerTransport("/messages/") + + # Create Starlette routes for SSE and message handling + routes = [ + Route("/sse", endpoint=handle_sse, methods=["GET"]), + Mount("/messages/", app=sse.handle_post_message), + ] + + # Define handler functions + async def handle_sse(request): + async with sse.connect_sse( + request.scope, request.receive, request._send + ) as streams: + await app.run( + streams[0], streams[1], app.create_initialization_options() + ) + # Return empty response to avoid NoneType error + return Response() + + # Create and run Starlette app + starlette_app = Starlette(routes=routes) + uvicorn.run(starlette_app, host="127.0.0.1", port=port) +``` + +Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' +object is not callable" error when client disconnects. The example above returns +an empty Response() after the SSE connection ends to fix this. + +See SseServerTransport class documentation for more details. +""" + +import logging +from contextlib import asynccontextmanager +from typing import Any +from urllib.parse import quote +from uuid import UUID, uuid4 + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +import mcp.types as types +from mcp.server.transport_security import ( + TransportSecurityMiddleware, + TransportSecuritySettings, +) +from mcp.shared.message import ServerMessageMetadata, SessionMessage + +logger = logging.getLogger(__name__) + + +class SseServerTransport: + """ + SSE server transport for MCP. This class provides _two_ ASGI applications, + suitable to be used with a framework like Starlette and a server like Hypercorn: + + 1. 
connect_sse() is an ASGI application which receives incoming GET requests, + and sets up a new SSE stream to send server messages to the client. + 2. handle_post_message() is an ASGI application which receives incoming POST + requests, which should contain client messages that link to a + previously-established SSE session. + """ + + _endpoint: str + _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] + _security: TransportSecurityMiddleware + + def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: + """ + Creates a new SSE server transport, which will direct the client to POST + messages to the relative path given. + + Args: + endpoint: A relative path where messages should be posted + (e.g., "/messages/"). + security_settings: Optional security settings for DNS rebinding protection. + + Note: + We use relative paths instead of full URLs for several reasons: + 1. Security: Prevents cross-origin requests by ensuring clients only connect + to the same origin they established the SSE connection with + 2. Flexibility: The server can be mounted at any path without needing to + know its full URL + 3. Portability: The same endpoint configuration works across different + environments (development, staging, production) + + Raises: + ValueError: If the endpoint is a full URL instead of a relative path + """ + + super().__init__() + + # Validate that endpoint is a relative path and not a full URL + if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint: + raise ValueError( + f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), " + "expecting a relative path (e.g., '/messages/')." 
+ ) + + # Ensure endpoint starts with a forward slash + if not endpoint.startswith("/"): + endpoint = "/" + endpoint + + self._endpoint = endpoint + self._read_stream_writers = {} + self._security = TransportSecurityMiddleware(security_settings) + logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") + + @asynccontextmanager + async def connect_sse(self, scope: Scope, receive: Receive, send: Send): + if scope["type"] != "http": + logger.error("connect_sse received non-HTTP request") + raise ValueError("connect_sse can only handle HTTP requests") + + # Validate request headers for DNS rebinding protection + request = Request(scope, receive) + error_response = await self._security.validate_request(request, is_post=False) + if error_response: + await error_response(scope, receive, send) + raise ValueError("Request validation failed") + + logger.debug("Setting up SSE connection") + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + session_id = uuid4() + self._read_stream_writers[session_id] = read_stream_writer + logger.debug(f"Created new session with ID: {session_id}") + + # Determine the full path for the message endpoint to be sent to the client. + # scope['root_path'] is the prefix where the current Starlette app + # instance is mounted. + # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". + root_path = scope.get("root_path", "") + + # self._endpoint is the path *within* this app, e.g., "/messages". + # Concatenating them gives the full absolute path from the server root. 
+ # e.g., "" + "/messages" -> "/messages" + # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" + full_message_path_for_client = root_path.rstrip("/") + self._endpoint + + # This is the URI (path + query) the client will use to POST messages. + client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" + + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) + + async def sse_writer(): + logger.debug("Starting SSE writer") + async with sse_stream_writer, write_stream_reader: + await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) + logger.debug(f"Sent endpoint event: {client_post_uri_data}") + + async for session_message in write_stream_reader: + logger.debug(f"Sending message via SSE: {session_message}") + await sse_stream_writer.send( + { + "event": "message", + "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), + } + ) + + async with anyio.create_task_group() as tg: + + async def response_wrapper(scope: Scope, receive: Receive, send: Send): + """ + The EventSourceResponse returning signals a client close / disconnect. + In this case we close our side of the streams to signal the client that + the connection has been closed. 
+ """ + await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( + scope, receive, send + ) + await read_stream_writer.aclose() + await write_stream_reader.aclose() + logging.debug(f"Client session disconnected {session_id}") + + logger.debug("Starting SSE response task") + tg.start_soon(response_wrapper, scope, receive, send) + + logger.debug("Yielding read and write streams") + yield (read_stream, write_stream) + + async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: + logger.debug("Handling POST message") + request = Request(scope, receive) + + # Validate request headers for DNS rebinding protection + error_response = await self._security.validate_request(request, is_post=True) + if error_response: + return await error_response(scope, receive, send) + + session_id_param = request.query_params.get("session_id") + if session_id_param is None: + logger.warning("Received request without session_id") + response = Response("session_id is required", status_code=400) + return await response(scope, receive, send) + + try: + session_id = UUID(hex=session_id_param) + logger.debug(f"Parsed session ID: {session_id}") + except ValueError: + logger.warning(f"Received invalid session ID: {session_id_param}") + response = Response("Invalid session ID", status_code=400) + return await response(scope, receive, send) + + writer = self._read_stream_writers.get(session_id) + if not writer: + logger.warning(f"Could not find session for ID: {session_id}") + response = Response("Could not find session", status_code=404) + return await response(scope, receive, send) + + body = await request.body() + logger.debug(f"Received JSON: {body}") + + try: + message = types.JSONRPCMessage.model_validate_json(body) + logger.debug(f"Validated client message: {message}") + except ValidationError as err: + logger.exception("Failed to parse message") + response = Response("Could not parse message", status_code=400) + await 
response(scope, receive, send) + await writer.send(err) + return + + # Pass the ASGI scope for framework-agnostic access to request data + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) + logger.debug(f"Sending session message to writer: {session_message}") + response = Response("Accepted", status_code=202) + await response(scope, receive, send) + await writer.send(session_message) diff --git a/src/mcp/mcp/server/stdio.py b/src/mcp/mcp/server/stdio.py new file mode 100644 index 000000000..d1618a371 --- /dev/null +++ b/src/mcp/mcp/server/stdio.py @@ -0,0 +1,88 @@ +""" +Stdio Server Transport Module + +This module provides functionality for creating an stdio-based transport layer +that can be used to communicate with an MCP client through standard input/output +streams. + +Example usage: +``` + async def run_server(): + async with stdio_server() as (read_stream, write_stream): + # read_stream contains incoming JSONRPCMessages from stdin + # write_stream allows sending JSONRPCMessages to stdout + server = await create_my_server() + await server.run(read_stream, write_stream, init_options) + + anyio.run(run_server) +``` +""" + +import sys +from contextlib import asynccontextmanager +from io import TextIOWrapper + +import anyio +import anyio.lowlevel +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +import mcp.types as types +from mcp.shared.message import SessionMessage + + +@asynccontextmanager +async def stdio_server( + stdin: anyio.AsyncFile[str] | None = None, + stdout: anyio.AsyncFile[str] | None = None, +): + """ + Server transport for stdio: this communicates with an MCP client by reading + from the current process' stdin and writing to stdout. + """ + # Purposely not using context managers for these, as we don't want to close + # standard process handles. 
Encoding of stdin/stdout as text streams on + # python is platform-dependent (Windows is particularly problematic), so we + # re-wrap the underlying binary stream to ensure UTF-8. + if not stdin: + stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8")) + if not stdout: + stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) + + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def stdin_reader(): + try: + async with read_stream_writer: + async for line in stdin: + try: + message = types.JSONRPCMessage.model_validate_json(line) + except Exception as exc: + await read_stream_writer.send(exc) + continue + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async def stdout_writer(): + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + await stdout.write(json + "\n") + await stdout.flush() + except anyio.ClosedResourceError: + await anyio.lowlevel.checkpoint() + + async with anyio.create_task_group() as tg: + tg.start_soon(stdin_reader) + tg.start_soon(stdout_writer) + yield read_stream, write_stream diff --git a/src/mcp/mcp/server/streamable_http.py b/src/mcp/mcp/server/streamable_http.py new file mode 100644 index 000000000..802cb8680 --- /dev/null +++ b/src/mcp/mcp/server/streamable_http.py @@ -0,0 +1,901 @@ +""" +StreamableHTTP Server Transport Module + +This module implements an HTTP transport layer with Streamable 
HTTP. + +The transport handles bidirectional communication using HTTP requests and +responses, with streaming support for long-running operations. +""" + +import json +import logging +import re +from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Awaitable, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass +from http import HTTPStatus + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import ValidationError +from sse_starlette import EventSourceResponse +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import Receive, Scope, Send + +from mcp.server.transport_security import ( + TransportSecurityMiddleware, + TransportSecuritySettings, +) +from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.types import ( + DEFAULT_NEGOTIATED_VERSION, + INTERNAL_ERROR, + INVALID_PARAMS, + INVALID_REQUEST, + PARSE_ERROR, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCRequest, + JSONRPCResponse, + RequestId, +) + +logger = logging.getLogger(__name__) + + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + +# Special key for the standalone GET stream +GET_STREAM_KEY = "_GET_stream" + +# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) +# Pattern ensures entire string contains only valid characters by using ^ and $ anchors +SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") + +# Type aliases +StreamId = str +EventId = str + + +@dataclass +class EventMessage: + """ + A JSONRPCMessage with an optional event ID for stream resumability. 
+ """ + + message: JSONRPCMessage + event_id: str | None = None + + +EventCallback = Callable[[EventMessage], Awaitable[None]] + + +class EventStore(ABC): + """ + Interface for resumability support via event storage. + """ + + @abstractmethod + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: + """ + Stores an event for later retrieval. + + Args: + stream_id: ID of the stream the event belongs to + message: The JSON-RPC message to store + + Returns: + The generated event ID for the stored event + """ + pass + + @abstractmethod + async def replay_events_after( + self, + last_event_id: EventId, + send_callback: EventCallback, + ) -> StreamId | None: + """ + Replays events that occurred after the specified event ID. + + Args: + last_event_id: The ID of the last event the client received + send_callback: A callback function to send events to the client + + Returns: + The stream ID of the replayed events + """ + pass + + +class StreamableHTTPServerTransport: + """ + HTTP server transport with event streaming support for MCP. + + Handles JSON-RPC messages in HTTP POST requests with SSE streaming. + Supports optional JSON responses and session management. + """ + + # Server notification streams for POST requests as well as standalone SSE stream + _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None + _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None + _write_stream: MemoryObjectSendStream[SessionMessage] | None = None + _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None + _security: TransportSecurityMiddleware + + def __init__( + self, + mcp_session_id: str | None, + is_json_response_enabled: bool = False, + event_store: EventStore | None = None, + security_settings: TransportSecuritySettings | None = None, + ) -> None: + """ + Initialize a new StreamableHTTP server transport. 
+ + Args: + mcp_session_id: Optional session identifier for this connection. + Must contain only visible ASCII characters (0x21-0x7E). + is_json_response_enabled: If True, return JSON responses for requests + instead of SSE streams. Default is False. + event_store: Event store for resumability support. If provided, + resumability will be enabled, allowing clients to + reconnect and resume messages. + security_settings: Optional security settings for DNS rebinding protection. + + Raises: + ValueError: If the session ID contains invalid characters. + """ + if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id): + raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)") + + self.mcp_session_id = mcp_session_id + self.is_json_response_enabled = is_json_response_enabled + self._event_store = event_store + self._security = TransportSecurityMiddleware(security_settings) + self._request_streams: dict[ + RequestId, + tuple[ + MemoryObjectSendStream[EventMessage], + MemoryObjectReceiveStream[EventMessage], + ], + ] = {} + self._terminated = False + + @property + def is_terminated(self) -> bool: + """Check if this transport has been explicitly terminated.""" + return self._terminated + + def _create_error_response( + self, + error_message: str, + status_code: HTTPStatus, + error_code: int = INVALID_REQUEST, + headers: dict[str, str] | None = None, + ) -> Response: + """Create an error response with a simple string message.""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Return a properly formatted JSON error response + error_response = JSONRPCError( + jsonrpc="2.0", + id="server-error", # We don't have a request ID for general errors + error=ErrorData( + code=error_code, + message=error_message, + ), + ) + + return Response( + 
error_response.model_dump_json(by_alias=True, exclude_none=True), + status_code=status_code, + headers=response_headers, + ) + + def _create_json_response( + self, + response_message: JSONRPCMessage | None, + status_code: HTTPStatus = HTTPStatus.OK, + headers: dict[str, str] | None = None, + ) -> Response: + """Create a JSON response from a JSONRPCMessage""" + response_headers = {"Content-Type": CONTENT_TYPE_JSON} + if headers: + response_headers.update(headers) + + if self.mcp_session_id: + response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + return Response( + response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, + status_code=status_code, + headers=response_headers, + ) + + def _get_session_id(self, request: Request) -> str | None: + """Extract the session ID from request headers.""" + return request.headers.get(MCP_SESSION_ID_HEADER) + + def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + """Create event data dictionary from an EventMessage.""" + event_data = { + "event": "message", + "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), + } + + # If an event ID was provided, include it + if event_message.event_id: + event_data["id"] = event_message.event_id + + return event_data + + async def _clean_up_memory_streams(self, request_id: RequestId) -> None: + """Clean up memory streams for a given request ID.""" + if request_id in self._request_streams: + try: + # Close the request stream + await self._request_streams[request_id][0].aclose() + await self._request_streams[request_id][1].aclose() + except Exception: + # During cleanup, we catch all exceptions since streams might be in various states + logger.debug("Error closing memory streams - may already be closed") + finally: + # Remove the request stream from the mapping + self._request_streams.pop(request_id, None) + + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> 
None: + """Application entry point that handles all HTTP requests""" + request = Request(scope, receive) + + # Validate request headers for DNS rebinding protection + is_post = request.method == "POST" + error_response = await self._security.validate_request(request, is_post=is_post) + if error_response: + await error_response(scope, receive, send) + return + + if self._terminated: + # If the session has been terminated, return 404 Not Found + response = self._create_error_response( + "Not Found: Session has been terminated", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + + if request.method == "POST": + await self._handle_post_request(scope, request, receive, send) + elif request.method == "GET": + await self._handle_get_request(request, send) + elif request.method == "DELETE": + await self._handle_delete_request(request, send) + else: + await self._handle_unsupported_request(request, send) + + def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: + """Check if the request accepts the required media types.""" + accept_header = request.headers.get("accept", "") + accept_types = [media_type.strip() for media_type in accept_header.split(",")] + + has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) + has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) + + return has_json, has_sse + + def _check_content_type(self, request: Request) -> bool: + """Check if the request has the correct Content-Type.""" + content_type = request.headers.get("content-type", "") + content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")] + + return any(part == CONTENT_TYPE_JSON for part in content_type_parts) + + async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: + """Handle POST requests containing JSON-RPC messages.""" + writer = self._read_stream_writer + if writer is None: + raise 
ValueError("No read stream writer available. Ensure connect() is called first.") + try: + # Check Accept headers + has_json, has_sse = self._check_accept_headers(request) + if not (has_json and has_sse): + response = self._create_error_response( + ("Not Acceptable: Client must accept both application/json and text/event-stream"), + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(scope, receive, send) + return + + # Validate Content-Type + if not self._check_content_type(request): + response = self._create_error_response( + "Unsupported Media Type: Content-Type must be application/json", + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + ) + await response(scope, receive, send) + return + + # Parse the body - only read it once + body = await request.body() + + try: + raw_message = json.loads(body) + except json.JSONDecodeError as e: + response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) + await response(scope, receive, send) + return + + try: + message = JSONRPCMessage.model_validate(raw_message) + except ValidationError as e: + response = self._create_error_response( + f"Validation error: {str(e)}", + HTTPStatus.BAD_REQUEST, + INVALID_PARAMS, + ) + await response(scope, receive, send) + return + + # Check if this is an initialization request + is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + + if is_initialization_request: + # Check if the server already has an established session + if self.mcp_session_id: + # Check if request has a session ID + request_session_id = self._get_session_id(request) + + # If request has a session ID but doesn't match, return 404 + if request_session_id and request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(scope, receive, send) + return + elif not await self._validate_request_headers(request, send): + return + + # 
For notifications and responses only, return 202 Accepted + if not isinstance(message.root, JSONRPCRequest): + # Create response object and send it + response = self._create_json_response( + None, + HTTPStatus.ACCEPTED, + ) + await response(scope, receive, send) + + # Process the message after sending the response + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) + await writer.send(session_message) + + return + + # Extract the request ID outside the try block for proper scope + request_id = str(message.root.id) + # Register this stream for the request ID + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) + request_stream_reader = self._request_streams[request_id][1] + + if self.is_json_response_enabled: + # Process the message + metadata = ServerMessageMetadata(request_context=request) + session_message = SessionMessage(message, metadata=metadata) + await writer.send(session_message) + try: + # Process messages from the request-specific stream + # We need to collect all messages until we get a response + response_message = None + + # Use similar approach to SSE writer for consistency + async for event_message in request_stream_reader: + # If it's a response, this is what we're waiting for + if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError): + response_message = event_message.message + break + # For notifications and request, keep waiting + else: + logger.debug(f"received: {event_message.message.root.method}") + + # At this point we should have a response + if response_message: + # Create JSON response + response = self._create_json_response(response_message) + await response(scope, receive, send) + else: + # This shouldn't happen in normal operation + logger.error("No response message received before stream closed") + response = self._create_error_response( + "Error processing request: No response received", + 
HTTPStatus.INTERNAL_SERVER_ERROR, + ) + await response(scope, receive, send) + except Exception: + logger.exception("Error processing JSON response") + response = self._create_error_response( + "Error processing request", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + finally: + await self._clean_up_memory_streams(request_id) + else: + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + + async def sse_writer(): + # Get the request ID from the incoming request message + try: + async with sse_stream_writer, request_stream_reader: + # Process messages from the request-specific stream + async for event_message in request_stream_reader: + # Build the event data + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + # If response, remove from pending streams and close + if isinstance( + event_message.message.root, + JSONRPCResponse | JSONRPCError, + ): + break + except Exception: + logger.exception("Error in SSE writer") + finally: + logger.debug("Closing SSE writer") + await self._clean_up_memory_streams(request_id) + + # Create and start EventSourceResponse + # SSE stream mode (original behavior) + # Set up headers + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), + } + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=sse_writer, + headers=headers, + ) + + # Start the SSE response (this will send headers immediately) + try: + # First send the response to establish the SSE connection + async with anyio.create_task_group() as tg: + tg.start_soon(response, scope, receive, send) + # Then send the message to be processed by the server + metadata = ServerMessageMetadata(request_context=request) + session_message = 
SessionMessage(message, metadata=metadata) + await writer.send(session_message) + except Exception: + logger.exception("SSE response error") + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(request_id) + + except Exception as err: + logger.exception("Error handling POST request") + response = self._create_error_response( + f"Error handling POST request: {err}", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(scope, receive, send) + if writer: + await writer.send(Exception(err)) + return + + async def _handle_get_request(self, request: Request, send: Send) -> None: + """ + Handle GET request to establish SSE. + + This allows the server to communicate to the client without the client + first sending data via HTTP POST. The server can send JSON-RPC requests + and notifications on this stream. + """ + writer = self._read_stream_writer + if writer is None: + raise ValueError("No read stream writer available. 
Ensure connect() is called first.") + + # Validate Accept header - must include text/event-stream + _, has_sse = self._check_accept_headers(request) + + if not has_sse: + response = self._create_error_response( + "Not Acceptable: Client must accept text/event-stream", + HTTPStatus.NOT_ACCEPTABLE, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_request_headers(request, send): + return + + # Handle resumability: check for Last-Event-ID header + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): + await self._replay_events(last_event_id, request, send) + return + + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Check if we already have an active GET stream + if GET_STREAM_KEY in self._request_streams: + response = self._create_error_response( + "Conflict: Only one SSE stream is allowed per session", + HTTPStatus.CONFLICT, + ) + await response(request.scope, request.receive, send) + return + + # Create SSE stream + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + + async def standalone_sse_writer(): + try: + # Create a standalone message stream for server-initiated messages + + self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0) + standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] + + async with sse_stream_writer, standalone_stream_reader: + # Process messages from the standalone stream + async for event_message in standalone_stream_reader: + # For the standalone stream, we handle: + # - JSONRPCNotification (server sends notifications to client) + # - JSONRPCRequest (server sends requests to client) + # We should NOT receive JSONRPCResponse + + # Send the message via SSE + event_data = self._create_event_data(event_message) + await 
sse_stream_writer.send(event_data) + except Exception: + logger.exception("Error in standalone SSE writer") + finally: + logger.debug("Closing standalone SSE writer") + await self._clean_up_memory_streams(GET_STREAM_KEY) + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=standalone_sse_writer, + headers=headers, + ) + + try: + # This will send headers immediately and establish the SSE connection + await response(request.scope, request.receive, send) + except Exception: + logger.exception("Error in standalone SSE response") + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + await self._clean_up_memory_streams(GET_STREAM_KEY) + + async def _handle_delete_request(self, request: Request, send: Send) -> None: + """Handle DELETE requests for explicit session termination.""" + # Validate session ID + if not self.mcp_session_id: + # If no session ID set, return Method Not Allowed + response = self._create_error_response( + "Method Not Allowed: Session termination not supported", + HTTPStatus.METHOD_NOT_ALLOWED, + ) + await response(request.scope, request.receive, send) + return + + if not await self._validate_request_headers(request, send): + return + + await self.terminate() + + response = self._create_json_response( + None, + HTTPStatus.OK, + ) + await response(request.scope, request.receive, send) + + async def terminate(self) -> None: + """Terminate the current session, closing all streams. + + Once terminated, all requests with this session ID will receive 404 Not Found. 
+ """ + + self._terminated = True + logger.info(f"Terminating session: {self.mcp_session_id}") + + # We need a copy of the keys to avoid modification during iteration + request_stream_keys = list(self._request_streams.keys()) + + # Close all request streams asynchronously + for key in request_stream_keys: + await self._clean_up_memory_streams(key) + + # Clear the request streams dictionary immediately + self._request_streams.clear() + try: + if self._read_stream_writer is not None: + await self._read_stream_writer.aclose() + if self._read_stream is not None: + await self._read_stream.aclose() + if self._write_stream_reader is not None: + await self._write_stream_reader.aclose() + if self._write_stream is not None: + await self._write_stream.aclose() + except Exception as e: + # During cleanup, we catch all exceptions since streams might be in various states + logger.debug(f"Error closing streams: {e}") + + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: + """Handle unsupported HTTP methods.""" + headers = { + "Content-Type": CONTENT_TYPE_JSON, + "Allow": "GET, POST, DELETE", + } + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + response = self._create_error_response( + "Method Not Allowed", + HTTPStatus.METHOD_NOT_ALLOWED, + headers=headers, + ) + await response(request.scope, request.receive, send) + + async def _validate_request_headers(self, request: Request, send: Send) -> bool: + if not await self._validate_session(request, send): + return False + if not await self._validate_protocol_version(request, send): + return False + return True + + async def _validate_session(self, request: Request, send: Send) -> bool: + """Validate the session ID in the request.""" + if not self.mcp_session_id: + # If we're not using session IDs, return True + return True + + # Get the session ID from the request headers + request_session_id = self._get_session_id(request) + + # If no session ID provided but 
required, return error + if not request_session_id: + response = self._create_error_response( + "Bad Request: Missing session ID", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + # If session ID doesn't match, return error + if request_session_id != self.mcp_session_id: + response = self._create_error_response( + "Not Found: Invalid or expired session ID", + HTTPStatus.NOT_FOUND, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _validate_protocol_version(self, request: Request, send: Send) -> bool: + """Validate the protocol version header in the request.""" + # Get the protocol version from the request headers + protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) + + # If no protocol version provided, assume default version + if protocol_version is None: + protocol_version = DEFAULT_NEGOTIATED_VERSION + + # Check if the protocol version is supported + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: + supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) + response = self._create_error_response( + f"Bad Request: Unsupported protocol version: {protocol_version}. " + + f"Supported versions: {supported_versions}", + HTTPStatus.BAD_REQUEST, + ) + await response(request.scope, request.receive, send) + return False + + return True + + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: + """ + Replays events that would have been sent after the specified event ID. + Only used when resumability is enabled. 
+ """ + event_store = self._event_store + if not event_store: + return + + try: + headers = { + "Cache-Control": "no-cache, no-transform", + "Connection": "keep-alive", + "Content-Type": CONTENT_TYPE_SSE, + } + + if self.mcp_session_id: + headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id + + # Create SSE stream for replay + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + + async def replay_sender(): + try: + async with sse_stream_writer: + # Define an async callback for sending events + async def send_event(event_message: EventMessage) -> None: + event_data = self._create_event_data(event_message) + await sse_stream_writer.send(event_data) + + # Replay past events and get the stream ID + stream_id = await event_store.replay_events_after(last_event_id, send_event) + + # If stream ID not in mapping, create it + if stream_id and stream_id not in self._request_streams: + self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) + msg_reader = self._request_streams[stream_id][1] + + # Forward messages to SSE + async with msg_reader: + async for event_message in msg_reader: + event_data = self._create_event_data(event_message) + + await sse_stream_writer.send(event_data) + except Exception: + logger.exception("Error in replay sender") + + # Create and start EventSourceResponse + response = EventSourceResponse( + content=sse_stream_reader, + data_sender_callable=replay_sender, + headers=headers, + ) + + try: + await response(request.scope, request.receive, send) + except Exception: + logger.exception("Error in replay response") + finally: + await sse_stream_writer.aclose() + await sse_stream_reader.aclose() + + except Exception: + logger.exception("Error replaying events") + response = self._create_error_response( + "Error replaying events", + HTTPStatus.INTERNAL_SERVER_ERROR, + INTERNAL_ERROR, + ) + await response(request.scope, request.receive, send) + + @asynccontextmanager + async def 
connect( + self, + ) -> AsyncGenerator[ + tuple[ + MemoryObjectReceiveStream[SessionMessage | Exception], + MemoryObjectSendStream[SessionMessage], + ], + None, + ]: + """Context manager that provides read and write streams for a connection. + + Yields: + Tuple of (read_stream, write_stream) for bidirectional communication + """ + + # Create the memory streams for this connection + + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + + # Store the streams + self._read_stream_writer = read_stream_writer + self._read_stream = read_stream + self._write_stream_reader = write_stream_reader + self._write_stream = write_stream + + # Start a task group for message routing + async with anyio.create_task_group() as tg: + # Create a message router that distributes messages to request streams + async def message_router(): + try: + async for session_message in write_stream_reader: + # Determine which request stream(s) should receive this message + message = session_message.message + target_request_id = None + # Check if this is a response + if isinstance(message.root, JSONRPCResponse | JSONRPCError): + response_id = str(message.root.id) + # If this response is for an existing request stream, + # send it there + target_request_id = response_id + else: + # Extract related_request_id from meta if it exists + if ( + session_message.metadata is not None + and isinstance( + session_message.metadata, + ServerMessageMetadata, + ) + and session_message.metadata.related_request_id is not None + ): + target_request_id = str(session_message.metadata.related_request_id) + + request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY + + # Store the event if we have an event store, + # regardless of whether a client is connected + # messages will be replayed on the re-connect + event_id = None + if self._event_store: + 
event_id = await self._event_store.store_event(request_stream_id, message) + logger.debug(f"Stored {event_id} from {request_stream_id}") + + if request_stream_id in self._request_streams: + try: + # Send both the message and the event ID + await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + ): + # Stream might be closed, remove from registry + self._request_streams.pop(request_stream_id, None) + else: + logger.debug( + f"""Request stream {request_stream_id} not found + for message. Still processing message as the client + might reconnect and replay.""" + ) + except Exception: + logger.exception("Error in message router") + + # Start the message router + tg.start_soon(message_router) + + try: + # Yield the streams for the caller to use + yield read_stream, write_stream + finally: + for stream_id in list(self._request_streams.keys()): + await self._clean_up_memory_streams(stream_id) + self._request_streams.clear() + + # Clean up the read and write streams + try: + await read_stream_writer.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() + await write_stream.aclose() + except Exception as e: + # During cleanup, we catch all exceptions since streams might be in various states + logger.debug(f"Error closing streams: {e}") diff --git a/src/mcp/mcp/server/streamable_http_manager.py b/src/mcp/mcp/server/streamable_http_manager.py new file mode 100644 index 000000000..53d542d21 --- /dev/null +++ b/src/mcp/mcp/server/streamable_http_manager.py @@ -0,0 +1,279 @@ +"""StreamableHTTP Session Manager for MCP servers.""" + +from __future__ import annotations + +import contextlib +import logging +from collections.abc import AsyncIterator +from http import HTTPStatus +from typing import Any +from uuid import uuid4 + +import anyio +from anyio.abc import TaskStatus +from starlette.requests import Request +from starlette.responses import Response +from
starlette.types import Receive, Scope, Send + +from mcp.server.lowlevel.server import Server as MCPServer +from mcp.server.streamable_http import ( + MCP_SESSION_ID_HEADER, + EventStore, + StreamableHTTPServerTransport, +) +from mcp.server.transport_security import TransportSecuritySettings + +logger = logging.getLogger(__name__) + + +class StreamableHTTPSessionManager: + """ + Manages StreamableHTTP sessions with optional resumability via event store. + + This class abstracts away the complexity of session management, event storage, + and request handling for StreamableHTTP transports. It handles: + + 1. Session tracking for clients + 2. Resumability via an optional event store + 3. Connection management and lifecycle + 4. Request handling and transport setup + + Important: Only one StreamableHTTPSessionManager instance should be created + per application. The instance cannot be reused after its run() context has + completed. If you need to restart the manager, create a new instance. + + Args: + app: The MCP server instance + event_store: Optional event store for resumability support. + If provided, enables resumable connections where clients + can reconnect and receive missed events. + If None, sessions are still tracked but not resumable. + json_response: Whether to use JSON responses instead of SSE streams + stateless: If True, creates a completely fresh transport for each request + with no session tracking or state persistence between requests. 
+ """ + + def __init__( + self, + app: MCPServer[Any, Any], + event_store: EventStore | None = None, + json_response: bool = False, + stateless: bool = False, + security_settings: TransportSecuritySettings | None = None, + ): + self.app = app + self.event_store = event_store + self.json_response = json_response + self.stateless = stateless + self.security_settings = security_settings + + # Session tracking (only used if not stateless) + self._session_creation_lock = anyio.Lock() + self._server_instances: dict[str, StreamableHTTPServerTransport] = {} + + # The task group will be set during lifespan + self._task_group = None + # Thread-safe tracking of run() calls + self._run_lock = anyio.Lock() + self._has_started = False + + @contextlib.asynccontextmanager + async def run(self) -> AsyncIterator[None]: + """ + Run the session manager with proper lifecycle management. + + This creates and manages the task group for all session operations. + + Important: This method can only be called once per instance. The same + StreamableHTTPSessionManager instance cannot be reused after this + context manager exits. Create a new instance if you need to restart. + + Use this in the lifespan context manager of your Starlette app: + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[None]: + async with session_manager.run(): + yield + """ + # Thread-safe check to ensure run() is only called once + async with self._run_lock: + if self._has_started: + raise RuntimeError( + "StreamableHTTPSessionManager .run() can only be called " + "once per instance. Create a new instance if you need to run again." 
+ ) + self._has_started = True + + async with anyio.create_task_group() as tg: + # Store the task group for later use + self._task_group = tg + logger.info("StreamableHTTP session manager started") + try: + yield # Let the application run + finally: + logger.info("StreamableHTTP session manager shutting down") + # Cancel task group to stop all spawned tasks + tg.cancel_scope.cancel() + self._task_group = None + # Clear any remaining server instances + self._server_instances.clear() + + async def handle_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process ASGI request with proper session handling and transport setup. + + Dispatches to the appropriate handler based on stateless mode. + + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + if self._task_group is None: + raise RuntimeError("Task group is not initialized. Make sure to use run().") + + # Dispatch to the appropriate handler + if self.stateless: + await self._handle_stateless_request(scope, receive, send) + else: + await self._handle_stateful_request(scope, receive, send) + + async def _handle_stateless_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process request in stateless mode - creating a new transport for each request. 
+ + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + logger.debug("Stateless mode: Creating new transport for this request") + # No session ID needed in stateless mode + http_transport = StreamableHTTPServerTransport( + mcp_session_id=None, # No session tracking in stateless mode + is_json_response_enabled=self.json_response, + event_store=None, # No event store in stateless mode + security_settings=self.security_settings, + ) + + # Start server in a new task + async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + try: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=True, + ) + except Exception: + logger.exception("Stateless session crashed") + + # Assert task group is not None for type checking + assert self._task_group is not None + # Start the server task + await self._task_group.start(run_stateless_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + + # Terminate the transport after the request is handled + await http_transport.terminate() + + async def _handle_stateful_request( + self, + scope: Scope, + receive: Receive, + send: Send, + ) -> None: + """ + Process request in stateful mode - maintaining session state between requests. 
+ + Args: + scope: ASGI scope + receive: ASGI receive function + send: ASGI send function + """ + request = Request(scope, receive) + request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) + + # Existing session case + if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: + transport = self._server_instances[request_mcp_session_id] + logger.debug("Session already exists, handling request directly") + await transport.handle_request(scope, receive, send) + return + + if request_mcp_session_id is None: + # New session case + logger.debug("Creating new transport") + async with self._session_creation_lock: + new_session_id = uuid4().hex + http_transport = StreamableHTTPServerTransport( + mcp_session_id=new_session_id, + is_json_response_enabled=self.json_response, + event_store=self.event_store, # May be None (no resumability) + security_settings=self.security_settings, + ) + + assert http_transport.mcp_session_id is not None + self._server_instances[http_transport.mcp_session_id] = http_transport + logger.info(f"Created new transport with session ID: {new_session_id}") + + # Define the server runner + async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + async with http_transport.connect() as streams: + read_stream, write_stream = streams + task_status.started() + try: + await self.app.run( + read_stream, + write_stream, + self.app.create_initialization_options(), + stateless=False, # Stateful mode + ) + except Exception as e: + logger.error( + f"Session {http_transport.mcp_session_id} crashed: {e}", + exc_info=True, + ) + finally: + # Only remove from instances if not terminated + if ( + http_transport.mcp_session_id + and http_transport.mcp_session_id in self._server_instances + and not http_transport.is_terminated + ): + logger.info( + "Cleaning up crashed session " + f"{http_transport.mcp_session_id} from " + "active instances." 
+ ) + del self._server_instances[http_transport.mcp_session_id] + + # Assert task group is not None for type checking + assert self._task_group is not None + # Start the server task + await self._task_group.start(run_server) + + # Handle the HTTP request and return the response + await http_transport.handle_request(scope, receive, send) + else: + # Invalid session ID + response = Response( + "Bad Request: No valid session ID provided", + status_code=HTTPStatus.BAD_REQUEST, + ) + await response(scope, receive, send) diff --git a/src/mcp/mcp/server/streaming_asgi_transport.py b/src/mcp/mcp/server/streaming_asgi_transport.py new file mode 100644 index 000000000..a74751312 --- /dev/null +++ b/src/mcp/mcp/server/streaming_asgi_transport.py @@ -0,0 +1,203 @@ +""" +A modified version of httpx.ASGITransport that supports streaming responses. + +This transport runs the ASGI app as a separate anyio task, allowing it to +handle streaming responses like SSE where the app doesn't terminate until +the connection is closed. + +This is only intended for writing tests for the SSE transport. +""" + +import typing +from typing import Any, cast + +import anyio +import anyio.abc +import anyio.streams.memory +from httpx._models import Request, Response +from httpx._transports.base import AsyncBaseTransport +from httpx._types import AsyncByteStream +from starlette.types import ASGIApp, Receive, Scope, Send + + +class StreamingASGITransport(AsyncBaseTransport): + """ + A custom AsyncTransport that handles sending requests directly to an ASGI app + and supports streaming responses like SSE. + + Unlike the standard ASGITransport, this transport runs the ASGI app in a + separate anyio task, allowing it to handle responses from apps that don't + terminate immediately (like SSE endpoints). + + Arguments: + + * `app` - The ASGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Default to `True`. 
Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `root_path` - The root path on which the ASGI application should be mounted. + * `client` - A two-tuple indicating the client IP and port of incoming requests. + * `response_timeout` - Timeout in seconds to wait for the initial response. + Default is 10 seconds. + + TODO: https://github.com/encode/httpx/pull/3059 is adding something similar to + upstream httpx. When that merges, we should delete this & switch back to the + upstream implementation. + """ + + def __init__( + self, + app: ASGIApp, + task_group: anyio.abc.TaskGroup, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.root_path = root_path + self.client = client + self.task_group = task_group + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. 
+ scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?")[0], + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": self.client, + "root_path": self.root_path, + } + + # Request body + request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response state + status_code = 499 + response_headers = None + response_started = False + response_complete = anyio.Event() + initial_response_ready = anyio.Event() + + # Synchronization for streaming response + asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100) + content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) + + # ASGI callables. + async def receive() -> dict[str, Any]: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: dict[str, Any]) -> None: + nonlocal status_code, response_headers, response_started + + await asgi_send_channel.send(message) + + # Start the ASGI application in a separate task + async def run_app() -> None: + try: + # Cast the receive and send functions to the ASGI types + await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send)) + except Exception: + if self.raise_app_exceptions: + raise + + if not response_started: + await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []}) + + await asgi_send_channel.send({"type": 
"http.response.body", "body": b"", "more_body": False}) + finally: + await asgi_send_channel.aclose() + + # Process messages from the ASGI app + async def process_messages() -> None: + nonlocal status_code, response_headers, response_started + + try: + async with asgi_receive_channel: + async for message in asgi_receive_channel: + if message["type"] == "http.response.start": + assert not response_started + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + # As soon as we have headers, we can return a response + initial_response_ready.set() + + elif message["type"] == "http.response.body": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + await content_send_channel.send(body) + + if not more_body: + response_complete.set() + await content_send_channel.aclose() + break + finally: + # Ensure events are set even if there's an error + initial_response_ready.set() + response_complete.set() + await content_send_channel.aclose() + + # Create tasks for running the app and processing messages + self.task_group.start_soon(run_app) + self.task_group.start_soon(process_messages) + + # Wait for the initial response or timeout + await initial_response_ready.wait() + + # Create a streaming response + return Response( + status_code, + headers=response_headers, + stream=StreamingASGIResponseStream(content_receive_channel), + ) + + +class StreamingASGIResponseStream(AsyncByteStream): + """ + A modified ASGIResponseStream that supports streaming responses. + + This class extends the standard ASGIResponseStream to handle cases where + the response body continues to be generated after the initial response + is returned. 
+ """ + + def __init__( + self, + receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], + ) -> None: + self.receive_channel = receive_channel + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + try: + async for chunk in self.receive_channel: + yield chunk + finally: + await self.receive_channel.aclose() diff --git a/src/mcp/mcp/server/transport_security.py b/src/mcp/mcp/server/transport_security.py new file mode 100644 index 000000000..3a884ee2b --- /dev/null +++ b/src/mcp/mcp/server/transport_security.py @@ -0,0 +1,127 @@ +"""DNS rebinding protection for MCP server transports.""" + +import logging + +from pydantic import BaseModel, Field +from starlette.requests import Request +from starlette.responses import Response + +logger = logging.getLogger(__name__) + + +class TransportSecuritySettings(BaseModel): + """Settings for MCP transport security features. + + These settings help protect against DNS rebinding attacks by validating + incoming request headers. + """ + + enable_dns_rebinding_protection: bool = Field( + default=True, + description="Enable DNS rebinding protection (recommended for production)", + ) + + allowed_hosts: list[str] = Field( + default=[], + description="List of allowed Host header values. Only applies when " + + "enable_dns_rebinding_protection is True.", + ) + + allowed_origins: list[str] = Field( + default=[], + description="List of allowed Origin header values. 
Only applies when " + + "enable_dns_rebinding_protection is True.", + ) + + +class TransportSecurityMiddleware: + """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" + + def __init__(self, settings: TransportSecuritySettings | None = None): + # If not specified, disable DNS rebinding protection by default + # for backwards compatibility + self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) + + def _validate_host(self, host: str | None) -> bool: + """Validate the Host header against allowed values.""" + if not host: + logger.warning("Missing Host header in request") + return False + + # Check exact match first + if host in self.settings.allowed_hosts: + return True + + # Check wildcard port patterns + for allowed in self.settings.allowed_hosts: + if allowed.endswith(":*"): + # Extract base host from pattern + base_host = allowed[:-2] + # Check if the actual host starts with base host and has a port + if host.startswith(base_host + ":"): + return True + + logger.warning(f"Invalid Host header: {host}") + return False + + def _validate_origin(self, origin: str | None) -> bool: + """Validate the Origin header against allowed values.""" + # Origin can be absent for same-origin requests + if not origin: + return True + + # Check exact match first + if origin in self.settings.allowed_origins: + return True + + # Check wildcard port patterns + for allowed in self.settings.allowed_origins: + if allowed.endswith(":*"): + # Extract base origin from pattern + base_origin = allowed[:-2] + # Check if the actual origin starts with base origin and has a port + if origin.startswith(base_origin + ":"): + return True + + logger.warning(f"Invalid Origin header: {origin}") + return False + + def _validate_content_type(self, content_type: str | None) -> bool: + """Validate the Content-Type header for POST requests.""" + if not content_type: + logger.warning("Missing Content-Type header in POST request") + return False 
+ + # Content-Type must start with application/json + if not content_type.lower().startswith("application/json"): + logger.warning(f"Invalid Content-Type header: {content_type}") + return False + + return True + + async def validate_request(self, request: Request, is_post: bool = False) -> Response | None: + """Validate request headers for DNS rebinding protection. + + Returns None if validation passes, or an error Response if validation fails. + """ + # Always validate Content-Type for POST requests + if is_post: + content_type = request.headers.get("content-type") + if not self._validate_content_type(content_type): + return Response("Invalid Content-Type header", status_code=400) + + # Skip remaining validation if DNS rebinding protection is disabled + if not self.settings.enable_dns_rebinding_protection: + return None + + # Validate Host header + host = request.headers.get("host") + if not self._validate_host(host): + return Response("Invalid Host header", status_code=421) + + # Validate Origin header + origin = request.headers.get("origin") + if not self._validate_origin(origin): + return Response("Invalid Origin header", status_code=400) + + return None diff --git a/src/mcp/mcp/server/websocket.py b/src/mcp/mcp/server/websocket.py new file mode 100644 index 000000000..7c0d8789c --- /dev/null +++ b/src/mcp/mcp/server/websocket.py @@ -0,0 +1,62 @@ +import logging +from contextlib import asynccontextmanager + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic_core import ValidationError +from starlette.types import Receive, Scope, Send +from starlette.websockets import WebSocket + +import mcp.types as types +from mcp.shared.message import SessionMessage + +logger = logging.getLogger(__name__) + + +@asynccontextmanager +async def websocket_server(scope: Scope, receive: Receive, send: Send): + """ + WebSocket server transport for MCP. 
This is an ASGI application, suitable to be + used with a framework like Starlette and a server like Hypercorn. + """ + + websocket = WebSocket(scope, receive, send) + await websocket.accept(subprotocol="mcp") + + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] + read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] + + write_stream: MemoryObjectSendStream[SessionMessage] + write_stream_reader: MemoryObjectReceiveStream[SessionMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def ws_reader(): + try: + async with read_stream_writer: + async for msg in websocket.iter_text(): + try: + client_message = types.JSONRPCMessage.model_validate_json(msg) + except ValidationError as exc: + await read_stream_writer.send(exc) + continue + + session_message = SessionMessage(client_message) + await read_stream_writer.send(session_message) + except anyio.ClosedResourceError: + await websocket.close() + + async def ws_writer(): + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) + await websocket.send_text(obj) + except anyio.ClosedResourceError: + await websocket.close() + + async with anyio.create_task_group() as tg: + tg.start_soon(ws_reader) + tg.start_soon(ws_writer) + yield (read_stream, write_stream) diff --git a/src/mcp/mcp/shared/__init__.py b/src/mcp/mcp/shared/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mcp/mcp/shared/_httpx_utils.py b/src/mcp/mcp/shared/_httpx_utils.py new file mode 100644 index 000000000..e0611ce73 --- /dev/null +++ b/src/mcp/mcp/shared/_httpx_utils.py @@ -0,0 +1,83 @@ +"""Utilities for creating standardized httpx AsyncClient instances.""" + +from typing import Any, Protocol + +import httpx + +__all__ = ["create_mcp_http_client"] + + +class 
McpHttpClientFactory(Protocol): + def __call__( + self, + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, + ) -> httpx.AsyncClient: ... + + +def create_mcp_http_client( + headers: dict[str, str] | None = None, + timeout: httpx.Timeout | None = None, + auth: httpx.Auth | None = None, +) -> httpx.AsyncClient: + """Create a standardized httpx AsyncClient with MCP defaults. + + This function provides common defaults used throughout the MCP codebase: + - follow_redirects=True (always enabled) + - Default timeout of 30 seconds if not specified + + Args: + headers: Optional headers to include with all requests. + timeout: Request timeout as httpx.Timeout object. + Defaults to 30 seconds if not specified. + auth: Optional authentication handler. + + Returns: + Configured httpx.AsyncClient instance with MCP defaults. + + Note: + The returned AsyncClient must be used as a context manager to ensure + proper cleanup of connections. + + Examples: + # Basic usage with MCP defaults + async with create_mcp_http_client() as client: + response = await client.get("https://api.example.com") + + # With custom headers + headers = {"Authorization": "Bearer token"} + async with create_mcp_http_client(headers) as client: + response = await client.get("/endpoint") + + # With both custom headers and timeout + timeout = httpx.Timeout(60.0, read=300.0) + async with create_mcp_http_client(headers, timeout) as client: + response = await client.get("/long-request") + + # With authentication + from httpx import BasicAuth + auth = BasicAuth(username="user", password="pass") + async with create_mcp_http_client(headers, timeout, auth) as client: + response = await client.get("/protected-endpoint") + """ + # Set MCP defaults + kwargs: dict[str, Any] = { + "follow_redirects": True, + } + + # Handle timeout + if timeout is None: + kwargs["timeout"] = httpx.Timeout(30.0) + else: + kwargs["timeout"] = timeout + + # Handle headers + if 
headers is not None: + kwargs["headers"] = headers + + # Handle authentication + if auth is not None: + kwargs["auth"] = auth + + return httpx.AsyncClient(**kwargs) diff --git a/src/mcp/mcp/shared/auth.py b/src/mcp/mcp/shared/auth.py new file mode 100644 index 000000000..6bf15b531 --- /dev/null +++ b/src/mcp/mcp/shared/auth.py @@ -0,0 +1,155 @@ +from typing import Any, Literal + +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator + + +class OAuthToken(BaseModel): + """ + See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 + """ + + access_token: str + token_type: Literal["Bearer"] = "Bearer" + expires_in: int | None = None + scope: str | None = None + refresh_token: str | None = None + + @field_validator("token_type", mode="before") + @classmethod + def normalize_token_type(cls, v: str | None) -> str | None: + if isinstance(v, str): + # Bearer is title-cased in the spec, so we normalize it + # https://datatracker.ietf.org/doc/html/rfc6750#section-4 + return v.title() + return v + + +class InvalidScopeError(Exception): + def __init__(self, message: str): + self.message = message + + +class InvalidRedirectUriError(Exception): + def __init__(self, message: str): + self.message = message + + +class OAuthClientMetadata(BaseModel): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. + See https://datatracker.ietf.org/doc/html/rfc7591#section-2 + for the full specification. 
+ """ + + redirect_uris: list[AnyUrl] = Field(..., min_length=1) + # token_endpoint_auth_method: this implementation only supports none & + # client_secret_post; + # ie: we do not support client_secret_basic + token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" + # grant_types: this implementation only supports authorization_code & refresh_token + grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + "authorization_code", + "refresh_token", + ] + # this implementation only supports code; ie: it does not support implicit grants + response_types: list[Literal["code"]] = ["code"] + scope: str | None = None + + # these fields are currently unused, but we support & store them for potential + # future use + client_name: str | None = None + client_uri: AnyHttpUrl | None = None + logo_uri: AnyHttpUrl | None = None + contacts: list[str] | None = None + tos_uri: AnyHttpUrl | None = None + policy_uri: AnyHttpUrl | None = None + jwks_uri: AnyHttpUrl | None = None + jwks: Any | None = None + software_id: str | None = None + software_version: str | None = None + + def validate_scope(self, requested_scope: str | None) -> list[str] | None: + if requested_scope is None: + return None + requested_scopes = requested_scope.split(" ") + allowed_scopes = [] if self.scope is None else self.scope.split(" ") + for scope in requested_scopes: + if scope not in allowed_scopes: + raise InvalidScopeError(f"Client was not registered with scope {scope}") + return requested_scopes + + def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: + if redirect_uri is not None: + # Validate redirect_uri against client's registered redirect URIs + if redirect_uri not in self.redirect_uris: + raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") + return redirect_uri + elif len(self.redirect_uris) == 1: + return self.redirect_uris[0] + else: + raise InvalidRedirectUriError("redirect_uri must be 
specified when client has multiple registered URIs") + + +class OAuthClientInformationFull(OAuthClientMetadata): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration full response + (client information plus metadata). + """ + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + +class OAuthMetadata(BaseModel): + """ + RFC 8414 OAuth 2.0 Authorization Server Metadata. + See https://datatracker.ietf.org/doc/html/rfc8414#section-2 + """ + + issuer: AnyHttpUrl + authorization_endpoint: AnyHttpUrl + token_endpoint: AnyHttpUrl + registration_endpoint: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + response_types_supported: list[str] = ["code"] + response_modes_supported: list[str] | None = None + grant_types_supported: list[str] | None = None + token_endpoint_auth_methods_supported: list[str] | None = None + token_endpoint_auth_signing_alg_values_supported: list[str] | None = None + service_documentation: AnyHttpUrl | None = None + ui_locales_supported: list[str] | None = None + op_policy_uri: AnyHttpUrl | None = None + op_tos_uri: AnyHttpUrl | None = None + revocation_endpoint: AnyHttpUrl | None = None + revocation_endpoint_auth_methods_supported: list[str] | None = None + revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None + introspection_endpoint: AnyHttpUrl | None = None + introspection_endpoint_auth_methods_supported: list[str] | None = None + introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None + + +class ProtectedResourceMetadata(BaseModel): + """ + RFC 9728 OAuth 2.0 Protected Resource Metadata. 
+ See https://datatracker.ietf.org/doc/html/rfc9728#section-2 + """ + + resource: AnyHttpUrl + authorization_servers: list[AnyHttpUrl] = Field(..., min_length=1) + jwks_uri: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + bearer_methods_supported: list[str] | None = Field(default=["header"]) # MCP only supports header method + resource_signing_alg_values_supported: list[str] | None = None + resource_name: str | None = None + resource_documentation: AnyHttpUrl | None = None + resource_policy_uri: AnyHttpUrl | None = None + resource_tos_uri: AnyHttpUrl | None = None + # tls_client_certificate_bound_access_tokens default is False, but omitted here for clarity + tls_client_certificate_bound_access_tokens: bool | None = None + authorization_details_types_supported: list[str] | None = None + dpop_signing_alg_values_supported: list[str] | None = None + # dpop_bound_access_tokens_required default is False, but omitted here for clarity + dpop_bound_access_tokens_required: bool | None = None diff --git a/src/mcp/mcp/shared/auth_utils.py b/src/mcp/mcp/shared/auth_utils.py new file mode 100644 index 000000000..6d6300c9c --- /dev/null +++ b/src/mcp/mcp/shared/auth_utils.py @@ -0,0 +1,69 @@ +"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707).""" + +from urllib.parse import urlparse, urlsplit, urlunsplit + +from pydantic import AnyUrl, HttpUrl + + +def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: + """Convert server URL to canonical resource URL per RFC 8707. + + RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". + Returns absolute URI with lowercase scheme/host for canonical form.
def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool:
    """Decide whether a token issued for *configured_resource* covers *requested_resource*.

    Matching is hierarchical: both URLs must share the same origin (scheme,
    host and port, compared case-insensitively), and the requested path must
    lie at or below the configured path. Prefix comparison is done on whole
    path segments, so "/api123" is never treated as a child of "/api".

    Args:
        requested_resource: The resource URL being requested.
        configured_resource: The resource URL that has been configured.

    Returns:
        True if the requested resource matches the configured resource.
    """
    requested = urlparse(requested_resource)
    configured = urlparse(configured_resource)

    # Origins must agree; scheme and authority are case-insensitive.
    same_scheme = requested.scheme.lower() == configured.scheme.lower()
    same_authority = requested.netloc.lower() == configured.netloc.lower()
    if not (same_scheme and same_authority):
        return False

    # A strictly shorter requested path can never be a descendant of the
    # configured path (this also rejects "/foo" against configured "/foo/").
    if len(requested.path) < len(configured.path):
        return False

    def ensure_trailing_slash(path: str) -> str:
        # Normalizing both sides to end in "/" makes the prefix test respect
        # segment boundaries ("/api123/" does not start with "/api/").
        return path if path.endswith("/") else path + "/"

    return ensure_trailing_slash(requested.path).startswith(ensure_trailing_slash(configured.path))
100644 index 000000000..f3006e7d5 --- /dev/null +++ b/src/mcp/mcp/shared/context.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Any, Generic + +from typing_extensions import TypeVar + +from mcp.shared.session import BaseSession +from mcp.types import RequestId, RequestParams + +SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +LifespanContextT = TypeVar("LifespanContextT") +RequestT = TypeVar("RequestT", default=Any) + + +@dataclass +class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): + request_id: RequestId + meta: RequestParams.Meta | None + session: SessionT + lifespan_context: LifespanContextT + request: RequestT | None = None diff --git a/src/mcp/mcp/shared/exceptions.py b/src/mcp/mcp/shared/exceptions.py new file mode 100644 index 000000000..97a1c09a9 --- /dev/null +++ b/src/mcp/mcp/shared/exceptions.py @@ -0,0 +1,14 @@ +from mcp.types import ErrorData + + +class McpError(Exception): + """ + Exception type raised when an error arrives over an MCP connection. 
+ """ + + error: ErrorData + + def __init__(self, error: ErrorData): + """Initialize McpError.""" + super().__init__(error.message) + self.error = error diff --git a/src/mcp/mcp/shared/memory.py b/src/mcp/mcp/shared/memory.py new file mode 100644 index 000000000..c94e5e6ac --- /dev/null +++ b/src/mcp/mcp/shared/memory.py @@ -0,0 +1,99 @@ +""" +In-memory transports +""" + +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import Any + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream + +import mcp.types as types +from mcp.client.session import ( + ClientSession, + ElicitationFnT, + ListRootsFnT, + LoggingFnT, + MessageHandlerFnT, + SamplingFnT, +) +from mcp.server import Server +from mcp.shared.message import SessionMessage + +MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] + + +@asynccontextmanager +async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: + """ + Creates a pair of bidirectional memory streams for client-server communication. 
+ + Returns: + A tuple of (client_streams, server_streams) where each is a tuple of + (read_stream, write_stream) + """ + # Create streams for both directions + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) + + client_streams = (server_to_client_receive, client_to_server_send) + server_streams = (client_to_server_receive, server_to_client_send) + + async with ( + server_to_client_receive, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + ): + yield client_streams, server_streams + + +@asynccontextmanager +async def create_connected_server_and_client_session( + server: Server[Any], + read_timeout_seconds: timedelta | None = None, + sampling_callback: SamplingFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, + raise_exceptions: bool = False, + elicitation_callback: ElicitationFnT | None = None, +) -> AsyncGenerator[ClientSession, None]: + """Creates a ClientSession that is connected to a running MCP server.""" + async with create_client_server_memory_streams() as ( + client_streams, + server_streams, + ): + client_read, client_write = client_streams + server_read, server_write = server_streams + + # Create a cancel scope for the server task + async with anyio.create_task_group() as tg: + tg.start_soon( + lambda: server.run( + server_read, + server_write, + server.create_initialization_options(), + raise_exceptions=raise_exceptions, + ) + ) + + try: + async with ClientSession( + read_stream=client_read, + write_stream=client_write, + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + list_roots_callback=list_roots_callback, + 
logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + elicitation_callback=elicitation_callback, + ) as client_session: + await client_session.initialize() + yield client_session + finally: + tg.cancel_scope.cancel() diff --git a/src/mcp/mcp/shared/message.py b/src/mcp/mcp/shared/message.py new file mode 100644 index 000000000..4b6df23eb --- /dev/null +++ b/src/mcp/mcp/shared/message.py @@ -0,0 +1,43 @@ +""" +Message wrapper with metadata support. + +This module defines a wrapper type that combines JSONRPCMessage with metadata +to support transport-specific features like resumability. +""" + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from mcp.types import JSONRPCMessage, RequestId + +ResumptionToken = str + +ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] + + +@dataclass +class ClientMessageMetadata: + """Metadata specific to client messages.""" + + resumption_token: ResumptionToken | None = None + on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None + + +@dataclass +class ServerMessageMetadata: + """Metadata specific to server messages.""" + + related_request_id: RequestId | None = None + # Request-specific context (e.g., headers, auth info) + request_context: object | None = None + + +MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None + + +@dataclass +class SessionMessage: + """A message with specific metadata for transport-specific features.""" + + message: JSONRPCMessage + metadata: MessageMetadata = None diff --git a/src/mcp/mcp/shared/metadata_utils.py b/src/mcp/mcp/shared/metadata_utils.py new file mode 100644 index 000000000..e3f49daf4 --- /dev/null +++ b/src/mcp/mcp/shared/metadata_utils.py @@ -0,0 +1,45 @@ +"""Utility functions for working with metadata in MCP types. 
def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implementation) -> str:
    """
    Get the display name for an MCP object with proper precedence.

    This is a client-side utility designed to help MCP clients show
    human-readable names in their user interfaces. When a server provides a
    'title' field it is preferred over the programmatic 'name' field.

    Precedence:
        Tools:         title > annotations.title > name
        Other objects: title > name

    Example:
        # In a client displaying available tools
        tools = await session.list_tools()
        for tool in tools.tools:
            display_name = get_display_name(tool)
            print(f"Available tool: {display_name}")

    Args:
        obj: An MCP object with a name and optional title fields.

    Returns:
        The display name to use for UI presentation.
    """
    # A non-None top-level title always wins, for every object kind.
    explicit_title = getattr(obj, "title", None)
    if explicit_title is not None:
        return explicit_title

    # Tools get a second chance via annotations.title before falling back.
    if isinstance(obj, Tool) and obj.annotations:
        annotation_title = getattr(obj.annotations, "title", None)
        if annotation_title is not None:
            return annotation_title

    return obj.name
BaseSession, + ReceiveNotificationT, + ReceiveRequestT, + SendNotificationT, + SendRequestT, + SendResultT, +) +from mcp.types import ProgressToken + + +class Progress(BaseModel): + progress: float + total: float | None + + +@dataclass +class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]): + session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] + progress_token: ProgressToken + total: float | None + current: float = field(default=0.0, init=False) + + async def progress(self, amount: float, message: str | None = None) -> None: + self.current += amount + + await self.session.send_progress_notification( + self.progress_token, self.current, total=self.total, message=message + ) + + +@contextmanager +def progress( + ctx: RequestContext[ + BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], + LifespanContextT, + ], + total: float | None = None, +) -> Generator[ + ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], + None, +]: + if ctx.meta is None or ctx.meta.progressToken is None: + raise ValueError("No progress token provided") + + progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total) + try: + yield progress_ctx + finally: + pass diff --git a/src/mcp/mcp/shared/session.py b/src/mcp/mcp/shared/session.py new file mode 100644 index 000000000..b2f49fc8b --- /dev/null +++ b/src/mcp/mcp/shared/session.py @@ -0,0 +1,470 @@ +import logging +from collections.abc import Callable +from contextlib import AsyncExitStack +from datetime import timedelta +from types import TracebackType +from typing import Any, Generic, Protocol, TypeVar + +import anyio +import httpx +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from pydantic import BaseModel +from typing_extensions import Self + +from mcp.shared.exceptions import 
McpError +from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + CancelledNotification, + ClientNotification, + ClientRequest, + ClientResult, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ProgressNotification, + RequestParams, + ServerNotification, + ServerRequest, + ServerResult, +) + +SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) +SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) +SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) +ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) +ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) + +RequestId = str | int + + +class ProgressFnT(Protocol): + """Protocol for progress notification callbacks.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class RequestResponder(Generic[ReceiveRequestT, SendResultT]): + """Handles responding to MCP requests and manages request lifecycle. + + This class MUST be used as a context manager to ensure proper cleanup and + cancellation handling: + + Example: + with request_responder as resp: + await resp.respond(result) + + The context manager ensures: + 1. Proper cancellation scope setup and cleanup + 2. Request completion tracking + 3. 
Cleanup of in-flight requests + """ + + def __init__( + self, + request_id: RequestId, + request_meta: RequestParams.Meta | None, + request: ReceiveRequestT, + session: """BaseSession[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT + ]""", + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], + message_metadata: MessageMetadata = None, + ) -> None: + self.request_id = request_id + self.request_meta = request_meta + self.request = request + self.message_metadata = message_metadata + self._session = session + self._completed = False + self._cancel_scope = anyio.CancelScope() + self._on_complete = on_complete + self._entered = False # Track if we're in a context manager + + def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": + """Enter the context manager, enabling request cancellation tracking.""" + self._entered = True + self._cancel_scope = anyio.CancelScope() + self._cancel_scope.__enter__() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the context manager, performing cleanup and notifying completion.""" + try: + if self._completed: + self._on_complete(self) + finally: + self._entered = False + if not self._cancel_scope: + raise RuntimeError("No active cancel scope") + self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) + + async def respond(self, response: SendResultT | ErrorData) -> None: + """Send a response for this request. + + Must be called within a context manager block. 
+ Raises: + RuntimeError: If not used within a context manager + AssertionError: If request was already responded to + """ + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") + assert not self._completed, "Request already responded to" + + if not self.cancelled: + self._completed = True + + await self._session._send_response( # type: ignore[reportPrivateUsage] + request_id=self.request_id, response=response + ) + + async def cancel(self) -> None: + """Cancel this request and mark it as completed.""" + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") + if not self._cancel_scope: + raise RuntimeError("No active cancel scope") + + self._cancel_scope.cancel() + self._completed = True # Mark as completed so it's removed from in_flight + # Send an error response to indicate cancellation + await self._session._send_response( # type: ignore[reportPrivateUsage] + request_id=self.request_id, + response=ErrorData(code=0, message="Request cancelled", data=None), + ) + + @property + def in_flight(self) -> bool: + return not self._completed and not self.cancelled + + @property + def cancelled(self) -> bool: + return self._cancel_scope.cancel_called + + +class BaseSession( + Generic[ + SendRequestT, + SendNotificationT, + SendResultT, + ReceiveRequestT, + ReceiveNotificationT, + ], +): + """ + Implements an MCP "session" on top of read/write streams, including features + like request/response linking, notifications, and progress. + + This class is an async context manager that automatically starts processing + messages when entered. 
+ """ + + _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] + _request_id: int + _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _progress_callbacks: dict[RequestId, ProgressFnT] + + def __init__( + self, + read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], + write_stream: MemoryObjectSendStream[SessionMessage], + receive_request_type: type[ReceiveRequestT], + receive_notification_type: type[ReceiveNotificationT], + # If none, reading will never time out + read_timeout_seconds: timedelta | None = None, + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + self._response_streams = {} + self._request_id = 0 + self._receive_request_type = receive_request_type + self._receive_notification_type = receive_notification_type + self._session_read_timeout_seconds = read_timeout_seconds + self._in_flight = {} + self._progress_callbacks = {} + self._exit_stack = AsyncExitStack() + + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + self._task_group.start_soon(self._receive_loop) + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + await self._exit_stack.aclose() + # Using BaseSession as a context manager should not block on exit (this + # would be very surprising behavior), so make sure to cancel the tasks + # in the task group. + self._task_group.cancel_scope.cancel() + return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def send_request( + self, + request: SendRequestT, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: timedelta | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ReceiveResultT: + """ + Sends a request and wait for a response. 
Raises an McpError if the + response contains an error. If a request read timeout is provided, it + will take precedence over the session read timeout. + + Do not use this method to emit notifications! Use send_notification() + instead. + """ + request_id = self._request_id + self._request_id = request_id + 1 + + response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) + self._response_streams[request_id] = response_stream + + # Set up progress token if progress callback is provided + request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + if progress_callback is not None: + # Use request_id as progress token + if "params" not in request_data: + request_data["params"] = {} + if "_meta" not in request_data["params"]: + request_data["params"]["_meta"] = {} + request_data["params"]["_meta"]["progressToken"] = request_id + # Store the callback for this request + self._progress_callbacks[request_id] = progress_callback + + try: + jsonrpc_request = JSONRPCRequest( + jsonrpc="2.0", + id=request_id, + **request_data, + ) + + await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + + # request read timeout takes precedence over session read timeout + timeout = None + if request_read_timeout_seconds is not None: + timeout = request_read_timeout_seconds.total_seconds() + elif self._session_read_timeout_seconds is not None: + timeout = self._session_read_timeout_seconds.total_seconds() + + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + raise McpError( + ErrorData( + code=httpx.codes.REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to " + f"{request.__class__.__name__}. Waited " + f"{timeout} seconds." 
+ ), + ) + ) + + if isinstance(response_or_error, JSONRPCError): + raise McpError(response_or_error.error) + else: + return result_type.model_validate(response_or_error.result) + + finally: + self._response_streams.pop(request_id, None) + self._progress_callbacks.pop(request_id, None) + await response_stream.aclose() + await response_stream_reader.aclose() + + async def send_notification( + self, + notification: SendNotificationT, + related_request_id: RequestId | None = None, + ) -> None: + """ + Emits a notification, which is a one-way message that does not expect + a response. + """ + # Some transport implementations may need to set the related_request_id + # to attribute to the notifications to the request that triggered them. + jsonrpc_notification = JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + session_message = SessionMessage( + message=JSONRPCMessage(jsonrpc_notification), + metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, + ) + await self._write_stream.send(session_message) + + async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: + if isinstance(response, ErrorData): + jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + await self._write_stream.send(session_message) + else: + jsonrpc_response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=response.model_dump(by_alias=True, mode="json", exclude_none=True), + ) + session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + await self._write_stream.send(session_message) + + async def _receive_loop(self) -> None: + async with ( + self._read_stream, + self._write_stream, + ): + try: + async for message in self._read_stream: + if isinstance(message, Exception): + await self._handle_incoming(message) + elif 
isinstance(message.message.root, JSONRPCRequest): + try: + validated_request = self._receive_request_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + responder = RequestResponder( + request_id=message.message.root.id, + request_meta=validated_request.root.params.meta + if validated_request.root.params + else None, + request=validated_request, + session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), + message_metadata=message.metadata, + ) + self._in_flight[responder.request_id] = responder + await self._received_request(responder) + + if not responder._completed: # type: ignore[reportPrivateUsage] + await self._handle_incoming(responder) + except Exception as e: + # For request validation errors, send a proper JSON-RPC error + # response instead of crashing the server + logging.warning(f"Failed to validate request: {e}") + logging.debug(f"Message that failed validation: {message.message.root}") + error_response = JSONRPCError( + jsonrpc="2.0", + id=message.message.root.id, + error=ErrorData( + code=INVALID_PARAMS, + message="Invalid request parameters", + data="", + ), + ) + session_message = SessionMessage(message=JSONRPCMessage(error_response)) + await self._write_stream.send(session_message) + + elif isinstance(message.message.root, JSONRPCNotification): + try: + notification = self._receive_notification_type.model_validate( + message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + # Handle progress notifications callback + if isinstance(notification.root, ProgressNotification): + progress_token = notification.root.params.progressToken + # If there is a progress callback for this token, + # call it with the 
progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + await callback( + notification.root.params.progress, + notification.root.params.total, + notification.root.params.message, + ) + await self._received_notification(notification) + await self._handle_incoming(notification) + except Exception as e: + # For other validation errors, log and continue + logging.warning( + f"Failed to validate notification: {e}. Message was: {message.message.root}" + ) + else: # Response or error + stream = self._response_streams.pop(message.message.root.id, None) + if stream: + await stream.send(message.message.root) + else: + await self._handle_incoming( + RuntimeError(f"Received response with an unknown request ID: {message}") + ) + + except anyio.ClosedResourceError: + # This is expected when the client disconnects abruptly. + # Without this handler, the exception would propagate up and + # crash the server's task group. + logging.debug("Read stream closed by client") + except Exception as e: + # Other exceptions are not expected and should be logged. We purposefully + # catch all exceptions here to avoid crashing the server. + logging.exception(f"Unhandled exception in receive loop: {e}") + finally: + # after the read stream is closed, we need to send errors + # to any pending requests + for id, stream in self._response_streams.items(): + error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + try: + await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) + await stream.aclose() + except Exception: + # Stream might already be closed + pass + self._response_streams.clear() + + async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: + """ + Can be overridden by subclasses to handle a request without needing to + listen on the message stream. 
+ + If the request is responded to within this method, it will not be + forwarded on to the message stream. + """ + + async def _received_notification(self, notification: ReceiveNotificationT) -> None: + """ + Can be overridden by subclasses to handle a notification without needing + to listen on the message stream. + """ + + async def send_progress_notification( + self, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + """ + Sends a progress notification for a request that is currently being + processed. + """ + + async def _handle_incoming( + self, + req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, + ) -> None: + """A generic handler for incoming messages. Overwritten by subclasses.""" + pass diff --git a/src/mcp/mcp/shared/version.py b/src/mcp/mcp/shared/version.py new file mode 100644 index 000000000..23c46d04b --- /dev/null +++ b/src/mcp/mcp/shared/version.py @@ -0,0 +1,3 @@ +from mcp.types import LATEST_PROTOCOL_VERSION + +SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] diff --git a/src/mcp/mcp/types.py b/src/mcp/mcp/types.py new file mode 100644 index 000000000..98fefa080 --- /dev/null +++ b/src/mcp/mcp/types.py @@ -0,0 +1,1320 @@ +from collections.abc import Callable +from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar + +from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic.networks import AnyUrl, UrlConstraints +from typing_extensions import deprecated + +""" +Model Context Protocol bindings for Python + +These bindings were generated from https://github.com/modelcontextprotocol/specification, +using Claude, with a prompt something like the following: + +Generate idiomatic Python bindings for this schema for MCP, or the "Model Context +Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version +for reference. 
+ +* For the bindings, let's use Pydantic V2 models. +* Each model should allow extra fields everywhere, by specifying `model_config = + ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class. +* Union types should be represented with a Pydantic `RootModel`. +* Define additional model classes instead of using dictionaries. Do this even if they're + not separate types in the schema. +""" + +LATEST_PROTOCOL_VERSION = "2025-06-18" + +""" +The default negotiated version of the Model Context Protocol when no version is specified. +We need this to satisfy the MCP specification, which requires the server to assume a +specific version if none is provided by the client. See section "Protocol Version Header" at +https://modelcontextprotocol.io/specification +""" +DEFAULT_NEGOTIATED_VERSION = "2025-03-26" + +ProgressToken = str | int +Cursor = str +Role = Literal["user", "assistant"] +RequestId = Annotated[int, Field(strict=True)] | str +AnyFunction: TypeAlias = Callable[..., Any] + + +class RequestParams(BaseModel): + class Meta(BaseModel): + progressToken: ProgressToken | None = None + """ + If specified, the caller requests out-of-band progress notifications for + this request (as represented by notifications/progress). The value of this + parameter is an opaque token that will be attached to any subsequent + notifications. The receiver is not obligated to provide these notifications. + """ + + model_config = ConfigDict(extra="allow") + + meta: Meta | None = Field(alias="_meta", default=None) + + +class PaginatedRequestParams(RequestParams): + cursor: Cursor | None = None + """ + An opaque token representing the current pagination position. + If provided, the server should return results starting after this cursor. 
+ """ + + +class NotificationParams(BaseModel): + class Meta(BaseModel): + model_config = ConfigDict(extra="allow") + + meta: Meta | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + + +RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) +NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None) +MethodT = TypeVar("MethodT", bound=str) + + +class Request(BaseModel, Generic[RequestParamsT, MethodT]): + """Base class for JSON-RPC requests.""" + + method: MethodT + params: RequestParamsT + model_config = ConfigDict(extra="allow") + + +class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): + """Base class for paginated requests, + matching the schema's PaginatedRequest interface.""" + + params: PaginatedRequestParams | None = None + + +class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): + """Base class for JSON-RPC notifications.""" + + method: MethodT + params: NotificationParamsT + model_config = ConfigDict(extra="allow") + + +class Result(BaseModel): + """Base class for JSON-RPC results.""" + + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class PaginatedResult(Result): + nextCursor: Cursor | None = None + """ + An opaque token representing the pagination position after the last returned result. + If present, there may be more results available. 
+ """ + + +class JSONRPCRequest(Request[dict[str, Any] | None, str]): + """A request that expects a response.""" + + jsonrpc: Literal["2.0"] + id: RequestId + method: str + params: dict[str, Any] | None = None + + +class JSONRPCNotification(Notification[dict[str, Any] | None, str]): + """A notification which does not expect a response.""" + + jsonrpc: Literal["2.0"] + params: dict[str, Any] | None = None + + +class JSONRPCResponse(BaseModel): + """A successful (non-error) response to a request.""" + + jsonrpc: Literal["2.0"] + id: RequestId + result: dict[str, Any] + model_config = ConfigDict(extra="allow") + + +# SDK error codes +CONNECTION_CLOSED = -32000 +# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this + +# Standard JSON-RPC error codes +PARSE_ERROR = -32700 +INVALID_REQUEST = -32600 +METHOD_NOT_FOUND = -32601 +INVALID_PARAMS = -32602 +INTERNAL_ERROR = -32603 + + +class ErrorData(BaseModel): + """Error information for JSON-RPC error responses.""" + + code: int + """The error type that occurred.""" + + message: str + """ + A short description of the error. The message SHOULD be limited to a concise single + sentence. + """ + + data: Any | None = None + """ + Additional information about the error. The value of this member is defined by the + sender (e.g. detailed error information, nested errors etc.). 
+ """ + + model_config = ConfigDict(extra="allow") + + +class JSONRPCError(BaseModel): + """A response to a request that indicates an error occurred.""" + + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + model_config = ConfigDict(extra="allow") + + +class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): + pass + + +class EmptyResult(Result): + """A response that indicates success but carries no data.""" + + +class BaseMetadata(BaseModel): + """Base class for entities with name and optional title fields.""" + + name: str + """The programmatic name of the entity.""" + + title: str | None = None + """ + Intended for UI and end-user contexts — optimized to be human-readable and easily understood, + even by those unfamiliar with domain-specific terminology. + + If not provided, the name should be used for display (except for Tool, + where `annotations.title` should be given precedence over using `name`, + if present). + """ + + +class Implementation(BaseMetadata): + """Describes the name and version of an MCP implementation.""" + + version: str + model_config = ConfigDict(extra="allow") + + +class RootsCapability(BaseModel): + """Capability for root operations.""" + + listChanged: bool | None = None + """Whether the client supports notifications for changes to the roots list.""" + model_config = ConfigDict(extra="allow") + + +class SamplingCapability(BaseModel): + """Capability for sampling operations.""" + + model_config = ConfigDict(extra="allow") + + +class ElicitationCapability(BaseModel): + """Capability for elicitation operations.""" + + model_config = ConfigDict(extra="allow") + + +class ClientCapabilities(BaseModel): + """Capabilities a client may support.""" + + experimental: dict[str, dict[str, Any]] | None = None + """Experimental, non-standard capabilities that the client supports.""" + sampling: SamplingCapability | None = None + """Present if the client supports sampling from an LLM.""" + 
elicitation: ElicitationCapability | None = None + """Present if the client supports elicitation from the user.""" + roots: RootsCapability | None = None + """Present if the client supports listing roots.""" + model_config = ConfigDict(extra="allow") + + +class PromptsCapability(BaseModel): + """Capability for prompts operations.""" + + listChanged: bool | None = None + """Whether this server supports notifications for changes to the prompt list.""" + model_config = ConfigDict(extra="allow") + + +class ResourcesCapability(BaseModel): + """Capability for resources operations.""" + + subscribe: bool | None = None + """Whether this server supports subscribing to resource updates.""" + listChanged: bool | None = None + """Whether this server supports notifications for changes to the resource list.""" + model_config = ConfigDict(extra="allow") + + +class ToolsCapability(BaseModel): + """Capability for tools operations.""" + + listChanged: bool | None = None + """Whether this server supports notifications for changes to the tool list.""" + model_config = ConfigDict(extra="allow") + + +class LoggingCapability(BaseModel): + """Capability for logging operations.""" + + model_config = ConfigDict(extra="allow") + + +class CompletionsCapability(BaseModel): + """Capability for completions operations.""" + + model_config = ConfigDict(extra="allow") + + +class ServerCapabilities(BaseModel): + """Capabilities that a server may support.""" + + experimental: dict[str, dict[str, Any]] | None = None + """Experimental, non-standard capabilities that the server supports.""" + logging: LoggingCapability | None = None + """Present if the server supports sending log messages to the client.""" + prompts: PromptsCapability | None = None + """Present if the server offers any prompt templates.""" + resources: ResourcesCapability | None = None + """Present if the server offers any resources to read.""" + tools: ToolsCapability | None = None + """Present if the server offers any tools to 
call.""" + completions: CompletionsCapability | None = None + """Present if the server offers autocompletion suggestions for prompts and resources.""" + model_config = ConfigDict(extra="allow") + + +class InitializeRequestParams(RequestParams): + """Parameters for the initialize request.""" + + protocolVersion: str | int + """The latest version of the Model Context Protocol that the client supports.""" + capabilities: ClientCapabilities + clientInfo: Implementation + model_config = ConfigDict(extra="allow") + + +class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]): + """ + This request is sent from the client to the server when it first connects, asking it + to begin initialization. + """ + + method: Literal["initialize"] + params: InitializeRequestParams + + +class InitializeResult(Result): + """After receiving an initialize request from the client, the server sends this.""" + + protocolVersion: str | int + """The version of the Model Context Protocol that the server wants to use.""" + capabilities: ServerCapabilities + serverInfo: Implementation + instructions: str | None = None + """Instructions describing how to use the server and its features.""" + + +class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]): + """ + This notification is sent from the client to the server after initialization has + finished. + """ + + method: Literal["notifications/initialized"] + params: NotificationParams | None = None + + +class PingRequest(Request[RequestParams | None, Literal["ping"]]): + """ + A ping, issued by either the server or the client, to check that the other party is + still alive. 
+ """ + + method: Literal["ping"] + params: RequestParams | None = None + + +class ProgressNotificationParams(NotificationParams): + """Parameters for progress notifications.""" + + progressToken: ProgressToken + """ + The progress token which was given in the initial request, used to associate this + notification with the request that is proceeding. + """ + progress: float + """ + The progress thus far. This should increase every time progress is made, even if the + total is unknown. + """ + total: float | None = None + """Total number of items to process (or total progress required), if known.""" + message: str | None = None + """ + Message related to progress. This should provide relevant human readable + progress information. + """ + model_config = ConfigDict(extra="allow") + + +class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]): + """ + An out-of-band notification used to inform the receiver of a progress update for a + long-running request. + """ + + method: Literal["notifications/progress"] + params: ProgressNotificationParams + + +class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): + """Sent from the client to request a list of resources the server has.""" + + method: Literal["resources/list"] + + +class Annotations(BaseModel): + audience: list[Role] | None = None + priority: Annotated[float, Field(ge=0.0, le=1.0)] | None = None + model_config = ConfigDict(extra="allow") + + +class Resource(BaseMetadata): + """A known resource that the server is capable of reading.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """The URI of this resource.""" + description: str | None = None + """A description of what this resource represents.""" + mimeType: str | None = None + """The MIME type of this resource, if known.""" + size: int | None = None + """ + The size of the raw resource content, in bytes (i.e., before base64 encoding + or any tokenization), if known. 
+ + This can be used by Hosts to display file sizes and estimate context window usage. + """ + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class ResourceTemplate(BaseMetadata): + """A template description for resources available on the server.""" + + uriTemplate: str + """ + A URI template (according to RFC 6570) that can be used to construct resource + URIs. + """ + description: str | None = None + """A human-readable description of what this template is for.""" + mimeType: str | None = None + """ + The MIME type for all resources that match this template. This should only be + included if all resources matching this template have the same type. + """ + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
+ """ + model_config = ConfigDict(extra="allow") + + +class ListResourcesResult(PaginatedResult): + """The server's response to a resources/list request from the client.""" + + resources: list[Resource] + + +class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): + """Sent from the client to request a list of resource templates the server has.""" + + method: Literal["resources/templates/list"] + + +class ListResourceTemplatesResult(PaginatedResult): + """The server's response to a resources/templates/list request from the client.""" + + resourceTemplates: list[ResourceTemplate] + + +class ReadResourceRequestParams(RequestParams): + """Parameters for reading a resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """ + The URI of the resource to read. The URI can use any protocol; it is up to the + server how to interpret it. + """ + model_config = ConfigDict(extra="allow") + + +class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): + """Sent from the client to the server, to read a specific resource URI.""" + + method: Literal["resources/read"] + params: ReadResourceRequestParams + + +class ResourceContents(BaseModel): + """The contents of a specific resource or sub-resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """The URI of this resource.""" + mimeType: str | None = None + """The MIME type of this resource, if known.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class TextResourceContents(ResourceContents): + """Text contents of a resource.""" + + text: str + """ + The text of the item. 
This must only be set if the item can actually be represented + as text (not binary data). + """ + + +class BlobResourceContents(ResourceContents): + """Binary contents of a resource.""" + + blob: str + """A base64-encoded string representing the binary data of the item.""" + + +class ReadResourceResult(Result): + """The server's response to a resources/read request from the client.""" + + contents: list[TextResourceContents | BlobResourceContents] + + +class ResourceListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]] +): + """ + An optional notification from the server to the client, informing it that the list + of resources it can read from has changed. + """ + + method: Literal["notifications/resources/list_changed"] + params: NotificationParams | None = None + + +class SubscribeRequestParams(RequestParams): + """Parameters for subscribing to a resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """ + The URI of the resource to subscribe to. The URI can use any protocol; it is up to + the server how to interpret it. + """ + model_config = ConfigDict(extra="allow") + + +class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]): + """ + Sent from the client to request resources/updated notifications from the server + whenever a particular resource changes. + """ + + method: Literal["resources/subscribe"] + params: SubscribeRequestParams + + +class UnsubscribeRequestParams(RequestParams): + """Parameters for unsubscribing from a resource.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """The URI of the resource to unsubscribe from.""" + model_config = ConfigDict(extra="allow") + + +class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]): + """ + Sent from the client to request cancellation of resources/updated notifications from + the server. 
+ """ + + method: Literal["resources/unsubscribe"] + params: UnsubscribeRequestParams + + +class ResourceUpdatedNotificationParams(NotificationParams): + """Parameters for resource update notifications.""" + + uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] + """ + The URI of the resource that has been updated. This might be a sub-resource of the + one that the client actually subscribed to. + """ + model_config = ConfigDict(extra="allow") + + +class ResourceUpdatedNotification( + Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]] +): + """ + A notification from the server to the client, informing it that a resource has + changed and may need to be read again. + """ + + method: Literal["notifications/resources/updated"] + params: ResourceUpdatedNotificationParams + + +class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): + """Sent from the client to request a list of prompts and prompt templates.""" + + method: Literal["prompts/list"] + + +class PromptArgument(BaseModel): + """An argument for a prompt template.""" + + name: str + """The name of the argument.""" + description: str | None = None + """A human-readable description of the argument.""" + required: bool | None = None + """Whether this argument must be provided.""" + model_config = ConfigDict(extra="allow") + + +class Prompt(BaseMetadata): + """A prompt or prompt template that the server offers.""" + + description: str | None = None + """An optional description of what this prompt provides.""" + arguments: list[PromptArgument] | None = None + """A list of arguments to use for templating the prompt.""" + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
+ """ + model_config = ConfigDict(extra="allow") + + +class ListPromptsResult(PaginatedResult): + """The server's response to a prompts/list request from the client.""" + + prompts: list[Prompt] + + +class GetPromptRequestParams(RequestParams): + """Parameters for getting a prompt.""" + + name: str + """The name of the prompt or prompt template.""" + arguments: dict[str, str] | None = None + """Arguments to use for templating the prompt.""" + model_config = ConfigDict(extra="allow") + + +class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): + """Used by the client to get a prompt provided by the server.""" + + method: Literal["prompts/get"] + params: GetPromptRequestParams + + +class TextContent(BaseModel): + """Text content for a message.""" + + type: Literal["text"] + text: str + """The text content of the message.""" + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class ImageContent(BaseModel): + """Image content for a message.""" + + type: Literal["image"] + data: str + """The base64-encoded image data.""" + mimeType: str + """ + The MIME type of the image. Different providers may support different + image types. + """ + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. 
+ """ + model_config = ConfigDict(extra="allow") + + +class AudioContent(BaseModel): + """Audio content for a message.""" + + type: Literal["audio"] + data: str + """The base64-encoded audio data.""" + mimeType: str + """ + The MIME type of the audio. Different providers may support different + audio types. + """ + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class SamplingMessage(BaseModel): + """Describes a message issued to or received from an LLM API.""" + + role: Role + content: TextContent | ImageContent | AudioContent + model_config = ConfigDict(extra="allow") + + +class EmbeddedResource(BaseModel): + """ + The contents of a resource, embedded into a prompt or tool call result. + + It is up to the client how best to render embedded resources for the benefit + of the LLM and/or the user. + """ + + type: Literal["resource"] + resource: TextResourceContents | BlobResourceContents + annotations: Annotations | None = None + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class ResourceLink(Resource): + """ + A resource that the server is capable of reading, included in a prompt or tool call result. + + Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. 
+ """ + + type: Literal["resource_link"] + + +ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource +"""A content block that can be used in prompts and tool results.""" + +Content: TypeAlias = ContentBlock +# """DEPRECATED: Content is deprecated, you should use ContentBlock directly.""" + + +class PromptMessage(BaseModel): + """Describes a message returned as part of a prompt.""" + + role: Role + content: ContentBlock + model_config = ConfigDict(extra="allow") + + +class GetPromptResult(Result): + """The server's response to a prompts/get request from the client.""" + + description: str | None = None + """An optional description for the prompt.""" + messages: list[PromptMessage] + + +class PromptListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]] +): + """ + An optional notification from the server to the client, informing it that the list + of prompts it offers has changed. + """ + + method: Literal["notifications/prompts/list_changed"] + params: NotificationParams | None = None + + +class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): + """Sent from the client to request a list of tools the server has.""" + + method: Literal["tools/list"] + + +class ToolAnnotations(BaseModel): + """ + Additional properties describing a Tool to clients. + + NOTE: all properties in ToolAnnotations are **hints**. + They are not guaranteed to provide a faithful description of + tool behavior (including descriptive properties like `title`). + + Clients should never make tool use decisions based on ToolAnnotations + received from untrusted servers. + """ + + title: str | None = None + """A human-readable title for the tool.""" + + readOnlyHint: bool | None = None + """ + If true, the tool does not modify its environment. + Default: false + """ + + destructiveHint: bool | None = None + """ + If true, the tool may perform destructive updates to its environment. 
+
+    If false, the tool performs only additive updates.
+    (This property is meaningful only when `readOnlyHint == false`)
+    Default: true
+    """
+
+    idempotentHint: bool | None = None
+    """
+    If true, calling the tool repeatedly with the same arguments
+    will have no additional effect on its environment.
+    (This property is meaningful only when `readOnlyHint == false`)
+    Default: false
+    """
+
+    openWorldHint: bool | None = None
+    """
+    If true, this tool may interact with an "open world" of external
+    entities. If false, the tool's domain of interaction is closed.
+    For example, the world of a web search tool is open, whereas that
+    of a memory tool is not.
+    Default: true
+    """
+    model_config = ConfigDict(extra="allow")
+
+
+class Tool(BaseMetadata):
+    """Definition for a tool the client can call."""
+
+    description: str | None = None
+    """A human-readable description of the tool."""
+    inputSchema: dict[str, Any]
+    """A JSON Schema object defining the expected parameters for the tool."""
+    outputSchema: dict[str, Any] | None = None
+    """
+    An optional JSON Schema object defining the structure of the tool's output
+    returned in the structuredContent field of a CallToolResult.
+    """
+    annotations: ToolAnnotations | None = None
+    """Optional additional tool information."""
+    meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+    """
+    See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields)
+    for notes on _meta usage.
+ """ + model_config = ConfigDict(extra="allow") + + +class ListToolsResult(PaginatedResult): + """The server's response to a tools/list request from the client.""" + + tools: list[Tool] + + +class CallToolRequestParams(RequestParams): + """Parameters for calling a tool.""" + + name: str + arguments: dict[str, Any] | None = None + model_config = ConfigDict(extra="allow") + + +class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): + """Used by the client to invoke a tool provided by the server.""" + + method: Literal["tools/call"] + params: CallToolRequestParams + + +class CallToolResult(Result): + """The server's response to a tool call.""" + + content: list[ContentBlock] + structuredContent: dict[str, Any] | None = None + """An optional JSON object that represents the structured result of the tool call.""" + isError: bool = False + + +class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]): + """ + An optional notification from the server to the client, informing it that the list + of tools it offers has changed. 
+ """ + + method: Literal["notifications/tools/list_changed"] + params: NotificationParams | None = None + + +LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] + + +class SetLevelRequestParams(RequestParams): + """Parameters for setting the logging level.""" + + level: LoggingLevel + """The level of logging that the client wants to receive from the server.""" + model_config = ConfigDict(extra="allow") + + +class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): + """A request from the client to the server, to enable or adjust logging.""" + + method: Literal["logging/setLevel"] + params: SetLevelRequestParams + + +class LoggingMessageNotificationParams(NotificationParams): + """Parameters for logging message notifications.""" + + level: LoggingLevel + """The severity of this log message.""" + logger: str | None = None + """An optional name of the logger issuing this message.""" + data: Any + """ + The data to be logged, such as a string message or an object. Any JSON serializable + type is allowed here. + """ + model_config = ConfigDict(extra="allow") + + +class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): + """Notification of a log message passed from server to client.""" + + method: Literal["notifications/message"] + params: LoggingMessageNotificationParams + + +IncludeContext = Literal["none", "thisServer", "allServers"] + + +class ModelHint(BaseModel): + """Hints to use for model selection.""" + + name: str | None = None + """A hint for a model name.""" + + model_config = ConfigDict(extra="allow") + + +class ModelPreferences(BaseModel): + """ + The server's preferences for model selection, requested by the client during + sampling. + + Because LLMs can vary along multiple dimensions, choosing the "best" model is + rarely straightforward. 
Different models excel in different areas—some are + faster but less capable, others are more capable but more expensive, and so + on. This interface allows servers to express their priorities across multiple + dimensions to help clients make an appropriate selection for their use case. + + These preferences are always advisory. The client MAY ignore them. It is also + up to the client to decide how to interpret these preferences and how to + balance them against other considerations. + """ + + hints: list[ModelHint] | None = None + """ + Optional hints to use for model selection. + + If multiple hints are specified, the client MUST evaluate them in order + (such that the first match is taken). + + The client SHOULD prioritize these hints over the numeric priorities, but + MAY still use the priorities to select from ambiguous matches. + """ + + costPriority: float | None = None + """ + How much to prioritize cost when selecting a model. A value of 0 means cost + is not important, while a value of 1 means cost is the most important + factor. + """ + + speedPriority: float | None = None + """ + How much to prioritize sampling speed (latency) when selecting a model. A + value of 0 means speed is not important, while a value of 1 means speed is + the most important factor. + """ + + intelligencePriority: float | None = None + """ + How much to prioritize intelligence and capabilities when selecting a + model. A value of 0 means intelligence is not important, while a value of 1 + means intelligence is the most important factor. + """ + + model_config = ConfigDict(extra="allow") + + +class CreateMessageRequestParams(RequestParams): + """Parameters for creating a message.""" + + messages: list[SamplingMessage] + modelPreferences: ModelPreferences | None = None + """ + The server's preferences for which model to select. The client MAY ignore + these preferences. 
+ """ + systemPrompt: str | None = None + """An optional system prompt the server wants to use for sampling.""" + includeContext: IncludeContext | None = None + """ + A request to include context from one or more MCP servers (including the caller), to + be attached to the prompt. + """ + temperature: float | None = None + maxTokens: int + """The maximum number of tokens to sample, as requested by the server.""" + stopSequences: list[str] | None = None + metadata: dict[str, Any] | None = None + """Optional metadata to pass through to the LLM provider.""" + model_config = ConfigDict(extra="allow") + + +class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): + """A request from the server to sample an LLM via the client.""" + + method: Literal["sampling/createMessage"] + params: CreateMessageRequestParams + + +StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str + + +class CreateMessageResult(Result): + """The client's response to a sampling/create_message request from the server.""" + + role: Role + content: TextContent | ImageContent | AudioContent + model: str + """The name of the model that generated the message.""" + stopReason: StopReason | None = None + """The reason why sampling stopped, if known.""" + + +class ResourceTemplateReference(BaseModel): + """A reference to a resource or resource template definition.""" + + type: Literal["ref/resource"] + uri: str + """The URI or URI template of the resource.""" + model_config = ConfigDict(extra="allow") + + +@deprecated("`ResourceReference` is deprecated, you should use `ResourceTemplateReference`.") +class ResourceReference(ResourceTemplateReference): + pass + + +class PromptReference(BaseModel): + """Identifies a prompt.""" + + type: Literal["ref/prompt"] + name: str + """The name of the prompt or prompt template""" + model_config = ConfigDict(extra="allow") + + +class CompletionArgument(BaseModel): + """The argument's information for completion 
requests.""" + + name: str + """The name of the argument""" + value: str + """The value of the argument to use for completion matching.""" + model_config = ConfigDict(extra="allow") + + +class CompletionContext(BaseModel): + """Additional, optional context for completions.""" + + arguments: dict[str, str] | None = None + """Previously-resolved variables in a URI template or prompt.""" + model_config = ConfigDict(extra="allow") + + +class CompleteRequestParams(RequestParams): + """Parameters for completion requests.""" + + ref: ResourceTemplateReference | PromptReference + argument: CompletionArgument + context: CompletionContext | None = None + """Additional, optional context for completions""" + model_config = ConfigDict(extra="allow") + + +class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): + """A request from the client to the server, to ask for completion options.""" + + method: Literal["completion/complete"] + params: CompleteRequestParams + + +class Completion(BaseModel): + """Completion information.""" + + values: list[str] + """An array of completion values. Must not exceed 100 items.""" + total: int | None = None + """ + The total number of completion options available. This can exceed the number of + values actually sent in the response. + """ + hasMore: bool | None = None + """ + Indicates whether there are additional completion options beyond those provided in + the current response, even if the exact total is unknown. + """ + model_config = ConfigDict(extra="allow") + + +class CompleteResult(Result): + """The server's response to a completion/complete request""" + + completion: Completion + + +class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): + """ + Sent from the server to request a list of root URIs from the client. Roots allow + servers to ask for specific directories or files to operate on. 
A common example + for roots is providing a set of repositories or directories a server should operate + on. + + This request is typically used when the server needs to understand the file system + structure or access specific locations that the client has permission to read from. + """ + + method: Literal["roots/list"] + params: RequestParams | None = None + + +class Root(BaseModel): + """Represents a root directory or file that the server can operate on.""" + + uri: FileUrl + """ + The URI identifying the root. This *must* start with file:// for now. + This restriction may be relaxed in future versions of the protocol to allow + other URI schemes. + """ + name: str | None = None + """ + An optional name for the root. This can be used to provide a human-readable + identifier for the root, which may be useful for display purposes or for + referencing the root in other parts of the application. + """ + meta: dict[str, Any] | None = Field(alias="_meta", default=None) + """ + See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + for notes on _meta usage. + """ + model_config = ConfigDict(extra="allow") + + +class ListRootsResult(Result): + """ + The client's response to a roots/list request from the server. + This result contains an array of Root objects, each representing a root directory + or file that the server can operate on. + """ + + roots: list[Root] + + +class RootsListChangedNotification( + Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] +): + """ + A notification from the client to the server, informing it that the list of + roots has changed. + + This notification should be sent whenever the client adds, removes, or + modifies any root. The server should then request an updated list of roots + using the ListRootsRequest. 
+ """ + + method: Literal["notifications/roots/list_changed"] + params: NotificationParams | None = None + + +class CancelledNotificationParams(NotificationParams): + """Parameters for cancellation notifications.""" + + requestId: RequestId + """The ID of the request to cancel.""" + reason: str | None = None + """An optional string describing the reason for the cancellation.""" + model_config = ConfigDict(extra="allow") + + +class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]): + """ + This notification can be sent by either side to indicate that it is canceling a + previously-issued request. + """ + + method: Literal["notifications/cancelled"] + params: CancelledNotificationParams + + +class ClientRequest( + RootModel[ + PingRequest + | InitializeRequest + | CompleteRequest + | SetLevelRequest + | GetPromptRequest + | ListPromptsRequest + | ListResourcesRequest + | ListResourceTemplatesRequest + | ReadResourceRequest + | SubscribeRequest + | UnsubscribeRequest + | CallToolRequest + | ListToolsRequest + ] +): + pass + + +class ClientNotification( + RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] +): + pass + + +# Type for elicitation schema - a JSON Schema dict +ElicitRequestedSchema: TypeAlias = dict[str, Any] +"""Schema for elicitation requests.""" + + +class ElicitRequestParams(RequestParams): + """Parameters for elicitation requests.""" + + message: str + requestedSchema: ElicitRequestedSchema + model_config = ConfigDict(extra="allow") + + +class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): + """A request from the server to elicit information from the client.""" + + method: Literal["elicitation/create"] + params: ElicitRequestParams + + +class ElicitResult(Result): + """The client's response to an elicitation request.""" + + action: Literal["accept", "decline", "cancel"] + """ + The user action in response to the 
elicitation. + - "accept": User submitted the form/confirmed the action + - "decline": User explicitly declined the action + - "cancel": User dismissed without making an explicit choice + """ + + content: dict[str, str | int | float | bool | None] | None = None + """ + The submitted form data, only present when action is "accept". + Contains values matching the requested schema. + """ + + +class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): + pass + + +class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): + pass + + +class ServerNotification( + RootModel[ + CancelledNotification + | ProgressNotification + | LoggingMessageNotification + | ResourceUpdatedNotification + | ResourceListChangedNotification + | ToolListChangedNotification + | PromptListChangedNotification + ] +): + pass + + +class ServerResult( + RootModel[ + EmptyResult + | InitializeResult + | CompleteResult + | GetPromptResult + | ListPromptsResult + | ListResourcesResult + | ListResourceTemplatesResult + | ReadResourceResult + | CallToolResult + | ListToolsResult + ] +): + pass From dcd0310b6304dc83c0e27c379a98312a2f0a8e2d Mon Sep 17 00:00:00 2001 From: "wenhui.zhang" Date: Tue, 26 Aug 2025 14:50:12 -0700 Subject: [PATCH 2/2] add secure annotation --- src/mcp/client/auth.py | 220 ++- src/mcp/client/stdio/__init__.py | 1 + src/mcp/client/streamable_http.py | 4 +- src/mcp/mcp/__init__.py | 116 -- src/mcp/mcp/cli/__init__.py | 6 - src/mcp/mcp/cli/claude.py | 148 -- src/mcp/mcp/cli/cli.py | 488 ------ src/mcp/mcp/client/__init__.py | 0 src/mcp/mcp/client/__main__.py | 85 -- src/mcp/mcp/client/auth.py | 551 ------- src/mcp/mcp/client/session.py | 458 ------ src/mcp/mcp/client/session_group.py | 366 ----- src/mcp/mcp/client/sse.py | 144 -- src/mcp/mcp/client/stdio/__init__.py | 277 ---- src/mcp/mcp/client/streamable_http.py | 511 ------- src/mcp/mcp/client/websocket.py | 86 -- src/mcp/mcp/os/__init__.py | 1 - 
src/mcp/mcp/os/posix/__init__.py | 1 - src/mcp/mcp/os/posix/utilities.py | 60 - src/mcp/mcp/os/win32/__init__.py | 1 - src/mcp/mcp/os/win32/utilities.py | 338 ----- src/mcp/mcp/py.typed | 0 src/mcp/mcp/server/__init__.py | 5 - src/mcp/mcp/server/__main__.py | 50 - src/mcp/mcp/server/auth/__init__.py | 3 - src/mcp/mcp/server/auth/errors.py | 5 - src/mcp/mcp/server/auth/handlers/__init__.py | 3 - src/mcp/mcp/server/auth/handlers/authorize.py | 224 --- src/mcp/mcp/server/auth/handlers/metadata.py | 29 - src/mcp/mcp/server/auth/handlers/register.py | 120 -- src/mcp/mcp/server/auth/handlers/revoke.py | 94 -- src/mcp/mcp/server/auth/handlers/token.py | 238 --- src/mcp/mcp/server/auth/json_response.py | 10 - .../mcp/server/auth/middleware/__init__.py | 3 - .../server/auth/middleware/auth_context.py | 48 - .../mcp/server/auth/middleware/bearer_auth.py | 128 -- .../mcp/server/auth/middleware/client_auth.py | 51 - src/mcp/mcp/server/auth/provider.py | 306 ---- src/mcp/mcp/server/auth/routes.py | 227 --- src/mcp/mcp/server/auth/settings.py | 30 - src/mcp/mcp/server/elicitation.py | 111 -- src/mcp/mcp/server/fastmcp/__init__.py | 9 - src/mcp/mcp/server/fastmcp/exceptions.py | 21 - .../mcp/server/fastmcp/prompts/__init__.py | 4 - src/mcp/mcp/server/fastmcp/prompts/base.py | 153 -- src/mcp/mcp/server/fastmcp/prompts/manager.py | 48 - .../mcp/server/fastmcp/resources/__init__.py | 23 - src/mcp/mcp/server/fastmcp/resources/base.py | 45 - .../fastmcp/resources/resource_manager.py | 97 -- .../mcp/server/fastmcp/resources/templates.py | 83 -- src/mcp/mcp/server/fastmcp/resources/types.py | 196 --- src/mcp/mcp/server/fastmcp/server.py | 1188 --------------- src/mcp/mcp/server/fastmcp/tools/__init__.py | 4 - src/mcp/mcp/server/fastmcp/tools/base.py | 119 -- .../mcp/server/fastmcp/tools/tool_manager.py | 83 -- .../mcp/server/fastmcp/utilities/__init__.py | 1 - .../server/fastmcp/utilities/func_metadata.py | 523 ------- .../mcp/server/fastmcp/utilities/logging.py | 43 - 
src/mcp/mcp/server/fastmcp/utilities/types.py | 101 -- src/mcp/mcp/server/lowlevel/__init__.py | 3 - src/mcp/mcp/server/lowlevel/helper_types.py | 9 - src/mcp/mcp/server/lowlevel/server.py | 686 --------- src/mcp/mcp/server/models.py | 17 - src/mcp/mcp/server/session.py | 359 ----- src/mcp/mcp/server/sse.py | 249 ---- src/mcp/mcp/server/stdio.py | 88 -- src/mcp/mcp/server/streamable_http.py | 901 ----------- src/mcp/mcp/server/streamable_http_manager.py | 279 ---- .../mcp/server/streaming_asgi_transport.py | 203 --- src/mcp/mcp/server/transport_security.py | 127 -- src/mcp/mcp/server/websocket.py | 62 - src/mcp/mcp/shared/__init__.py | 0 src/mcp/mcp/shared/_httpx_utils.py | 83 -- src/mcp/mcp/shared/auth.py | 155 -- src/mcp/mcp/shared/auth_utils.py | 69 - src/mcp/mcp/shared/context.py | 20 - src/mcp/mcp/shared/exceptions.py | 14 - src/mcp/mcp/shared/memory.py | 99 -- src/mcp/mcp/shared/message.py | 43 - src/mcp/mcp/shared/metadata_utils.py | 45 - src/mcp/mcp/shared/progress.py | 58 - src/mcp/mcp/shared/session.py | 470 ------ src/mcp/mcp/shared/version.py | 3 - src/mcp/mcp/types.py | 1320 ----------------- src/mcp/server/auth/provider.py | 2 +- src/mcp/server/auth/routes.py | 4 + src/mcp/server/fastmcp/__init__.py | 4 +- .../server/fastmcp/prompts/prompt_manager.py | 33 - .../server/fastmcp/secure/__init__.py | 0 .../server/fastmcp/secure/annotations.py | 0 .../server/fastmcp/secure/identity.py | 0 .../{mcp => }/server/fastmcp/secure/prompt.py | 0 .../server/fastmcp/secure/resource.py | 0 .../server/fastmcp/secure/session.py | 0 .../{mcp => }/server/fastmcp/secure/tool.py | 0 .../{mcp => }/server/fastmcp/secure/utils.py | 0 src/mcp/server/fastmcp/server.py | 118 +- src/mcp/server/fastmcp/tools/base.py | 2 +- .../server/fastmcp/utilities/func_metadata.py | 47 +- src/mcp/server/fastmcp/utilities/types.py | 49 +- src/mcp/server/lowlevel/server.py | 11 +- src/mcp/server/streamable_http.py | 8 +- src/mcp/server/streamable_http_manager.py | 8 +- src/mcp/shared/auth.py | 
19 +- src/mcp/types.py | 4 +- 105 files changed, 297 insertions(+), 13662 deletions(-) delete mode 100644 src/mcp/mcp/__init__.py delete mode 100644 src/mcp/mcp/cli/__init__.py delete mode 100644 src/mcp/mcp/cli/claude.py delete mode 100644 src/mcp/mcp/cli/cli.py delete mode 100644 src/mcp/mcp/client/__init__.py delete mode 100644 src/mcp/mcp/client/__main__.py delete mode 100644 src/mcp/mcp/client/auth.py delete mode 100644 src/mcp/mcp/client/session.py delete mode 100644 src/mcp/mcp/client/session_group.py delete mode 100644 src/mcp/mcp/client/sse.py delete mode 100644 src/mcp/mcp/client/stdio/__init__.py delete mode 100644 src/mcp/mcp/client/streamable_http.py delete mode 100644 src/mcp/mcp/client/websocket.py delete mode 100644 src/mcp/mcp/os/__init__.py delete mode 100644 src/mcp/mcp/os/posix/__init__.py delete mode 100644 src/mcp/mcp/os/posix/utilities.py delete mode 100644 src/mcp/mcp/os/win32/__init__.py delete mode 100644 src/mcp/mcp/os/win32/utilities.py delete mode 100644 src/mcp/mcp/py.typed delete mode 100644 src/mcp/mcp/server/__init__.py delete mode 100644 src/mcp/mcp/server/__main__.py delete mode 100644 src/mcp/mcp/server/auth/__init__.py delete mode 100644 src/mcp/mcp/server/auth/errors.py delete mode 100644 src/mcp/mcp/server/auth/handlers/__init__.py delete mode 100644 src/mcp/mcp/server/auth/handlers/authorize.py delete mode 100644 src/mcp/mcp/server/auth/handlers/metadata.py delete mode 100644 src/mcp/mcp/server/auth/handlers/register.py delete mode 100644 src/mcp/mcp/server/auth/handlers/revoke.py delete mode 100644 src/mcp/mcp/server/auth/handlers/token.py delete mode 100644 src/mcp/mcp/server/auth/json_response.py delete mode 100644 src/mcp/mcp/server/auth/middleware/__init__.py delete mode 100644 src/mcp/mcp/server/auth/middleware/auth_context.py delete mode 100644 src/mcp/mcp/server/auth/middleware/bearer_auth.py delete mode 100644 src/mcp/mcp/server/auth/middleware/client_auth.py delete mode 100644 src/mcp/mcp/server/auth/provider.py 
delete mode 100644 src/mcp/mcp/server/auth/routes.py delete mode 100644 src/mcp/mcp/server/auth/settings.py delete mode 100644 src/mcp/mcp/server/elicitation.py delete mode 100644 src/mcp/mcp/server/fastmcp/__init__.py delete mode 100644 src/mcp/mcp/server/fastmcp/exceptions.py delete mode 100644 src/mcp/mcp/server/fastmcp/prompts/__init__.py delete mode 100644 src/mcp/mcp/server/fastmcp/prompts/base.py delete mode 100644 src/mcp/mcp/server/fastmcp/prompts/manager.py delete mode 100644 src/mcp/mcp/server/fastmcp/resources/__init__.py delete mode 100644 src/mcp/mcp/server/fastmcp/resources/base.py delete mode 100644 src/mcp/mcp/server/fastmcp/resources/resource_manager.py delete mode 100644 src/mcp/mcp/server/fastmcp/resources/templates.py delete mode 100644 src/mcp/mcp/server/fastmcp/resources/types.py delete mode 100644 src/mcp/mcp/server/fastmcp/server.py delete mode 100644 src/mcp/mcp/server/fastmcp/tools/__init__.py delete mode 100644 src/mcp/mcp/server/fastmcp/tools/base.py delete mode 100644 src/mcp/mcp/server/fastmcp/tools/tool_manager.py delete mode 100644 src/mcp/mcp/server/fastmcp/utilities/__init__.py delete mode 100644 src/mcp/mcp/server/fastmcp/utilities/func_metadata.py delete mode 100644 src/mcp/mcp/server/fastmcp/utilities/logging.py delete mode 100644 src/mcp/mcp/server/fastmcp/utilities/types.py delete mode 100644 src/mcp/mcp/server/lowlevel/__init__.py delete mode 100644 src/mcp/mcp/server/lowlevel/helper_types.py delete mode 100644 src/mcp/mcp/server/lowlevel/server.py delete mode 100644 src/mcp/mcp/server/models.py delete mode 100644 src/mcp/mcp/server/session.py delete mode 100644 src/mcp/mcp/server/sse.py delete mode 100644 src/mcp/mcp/server/stdio.py delete mode 100644 src/mcp/mcp/server/streamable_http.py delete mode 100644 src/mcp/mcp/server/streamable_http_manager.py delete mode 100644 src/mcp/mcp/server/streaming_asgi_transport.py delete mode 100644 src/mcp/mcp/server/transport_security.py delete mode 100644 
src/mcp/mcp/server/websocket.py delete mode 100644 src/mcp/mcp/shared/__init__.py delete mode 100644 src/mcp/mcp/shared/_httpx_utils.py delete mode 100644 src/mcp/mcp/shared/auth.py delete mode 100644 src/mcp/mcp/shared/auth_utils.py delete mode 100644 src/mcp/mcp/shared/context.py delete mode 100644 src/mcp/mcp/shared/exceptions.py delete mode 100644 src/mcp/mcp/shared/memory.py delete mode 100644 src/mcp/mcp/shared/message.py delete mode 100644 src/mcp/mcp/shared/metadata_utils.py delete mode 100644 src/mcp/mcp/shared/progress.py delete mode 100644 src/mcp/mcp/shared/session.py delete mode 100644 src/mcp/mcp/shared/version.py delete mode 100644 src/mcp/mcp/types.py delete mode 100644 src/mcp/server/fastmcp/prompts/prompt_manager.py rename src/mcp/{mcp => }/server/fastmcp/secure/__init__.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/annotations.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/identity.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/prompt.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/resource.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/session.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/tool.py (100%) rename src/mcp/{mcp => }/server/fastmcp/secure/utils.py (100%) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 8bafe18eb..376036e8c 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -7,6 +7,7 @@ import base64 import hashlib import logging +import re import secrets import string import time @@ -203,10 +204,39 @@ def __init__( ) self._initialized = False - async def _discover_protected_resource(self) -> httpx.Request: - """Build discovery request for protected resource metadata.""" - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: + """ + Extract 
protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not init_response or init_response.status_code != 401: + return None + + www_auth_header = init_response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: + # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response + url = self._extract_resource_metadata_from_www_auth(init_response) + + if not url: + # Fallback to well-known discovery + auth_base_url = self.context.get_authorization_base_url(self.context.server_url) + url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) async def _handle_protected_resource_response(self, response: httpx.Response) -> None: @@ -221,72 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> except ValidationError: pass - def _build_well_known_path(self, pathname: str) -> str: - """Construct well-known path for OAuth metadata discovery.""" - well_known_path = f"/.well-known/oauth-authorization-server{pathname}" - if pathname.endswith("/"): - # Strip trailing slash from pathname to avoid double slashes - well_known_path = well_known_path[:-1] - return well_known_path - - def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool: - """Determine if fallback to root discovery should be attempted.""" - return response_status == 
404 and pathname != "/" - - async def _try_metadata_discovery(self, url: str) -> httpx.Request: - """Build metadata discovery request for a specific URL.""" - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _discover_oauth_metadata(self) -> httpx.Request: - """Build OAuth metadata discovery request with fallback support.""" - if self.context.auth_server_url: - auth_server_url = self.context.auth_server_url - else: - auth_server_url = self.context.server_url - - # Per RFC 8414, try path-aware discovery first + def _get_discovery_urls(self) -> list[str]: + """Generate ordered list of (url, type) tuples for discovery attempts.""" + urls: list[str] = [] + auth_server_url = self.context.auth_server_url or self.context.server_url parsed = urlparse(auth_server_url) - well_known_path = self._build_well_known_path(parsed.path) base_url = f"{parsed.scheme}://{parsed.netloc}" - url = urljoin(base_url, well_known_path) - - # Store fallback info for use in response handler - self.context.discovery_base_url = base_url - self.context.discovery_pathname = parsed.path - return await self._try_metadata_discovery(url) + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) - async def _discover_oauth_metadata_fallback(self) -> httpx.Request: - """Build fallback OAuth metadata discovery request for legacy servers.""" - base_url = getattr(self.context, "discovery_base_url", "") - if not base_url: - raise OAuthFlowError("No base URL available for fallback discovery") + # OAuth root fallback + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - # Fallback to root discovery for legacy servers - url = urljoin(base_url, "/.well-known/oauth-authorization-server") - return await self._try_metadata_discovery(url) - - async def _handle_oauth_metadata_response(self, 
response: httpx.Response, is_fallback: bool = False) -> bool: - """Handle OAuth metadata response. Returns True if handled successfully.""" - if response.status_code == 200: - try: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self.context.oauth_metadata = metadata - # Apply default scope if none specified - if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: - self.context.client_metadata.scope = " ".join(metadata.scopes_supported) - return True - except ValidationError: - pass + # RFC 8414 section 5: Path-aware OIDC discovery + # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + if parsed.path and parsed.path != "/": + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) - # Check if we should attempt fallback (404 on path-aware discovery) - if not is_fallback and self._should_attempt_fallback( - response.status_code, getattr(self.context, "discovery_pathname", "/") - ): - return False # Signal that fallback should be attempted + # OIDC 1.0 fallback (appends to full URL per OIDC spec) + oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" + urls.append(oidc_fallback) - return True # Signal no fallback needed (either success or non-404 error) + return urls async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" @@ -481,6 +471,17 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" + def _create_oauth_metadata_request(self, url: str) -> httpx.Request: + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: + content = await 
response.aread() + metadata = OAuthMetadata.model_validate_json(content) + self.context.oauth_metadata = metadata + # Apply default scope if needed + if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: + self.context.client_metadata.scope = " ".join(metadata.scopes_supported) + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -490,77 +491,43 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - # Perform OAuth flow if not authenticated - if not self.context.is_token_valid(): - try: - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await self._perform_authorization() 
- - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - except Exception: - logger.exception("OAuth flow error") - raise - - # Add authorization header and make request - self._add_auth_header(request) - response = yield request - - # Handle 401 responses - if response.status_code == 401 and self.context.can_refresh_token(): + if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token refresh_request = await self._refresh_token() refresh_response = yield refresh_request - if await self._handle_refresh_response(refresh_response): - # Retry original request with new token - self._add_auth_header(request) - yield request - else: + if not await self._handle_refresh_response(refresh_response): # Refresh failed, need full re-authentication self._initialized = False + if self.context.is_token_valid(): + self._add_auth_header(request) + + response = yield request + + if response.status_code == 401: + # Perform full OAuth flow + try: # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (spec revision 2025-06-18) - discovery_request = await self._discover_protected_resource() + # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) + discovery_request = await self._discover_protected_resource(response) discovery_response = yield discovery_request await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) - oauth_request = await self._discover_oauth_metadata() - oauth_response = yield oauth_request - handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False) - - # If path-aware discovery failed with 404, try fallback to root - if not handled: - fallback_request = await 
self._discover_oauth_metadata_fallback() - fallback_response = yield fallback_request - await self._handle_oauth_metadata_response(fallback_response, is_fallback=True) + discovery_urls = self._get_discovery_urls() + for url in discovery_urls: + oauth_metadata_request = self._create_oauth_metadata_request(url) + oauth_metadata_response = yield oauth_metadata_request + + if oauth_metadata_response.status_code == 200: + try: + await self._handle_oauth_metadata_response(oauth_metadata_response) + break + except ValidationError: + continue + elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: + break # Non-4XX error, stop trying # Step 3: Register client if needed registration_request = await self._register_client() @@ -575,7 +542,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. token_request = await self._exchange_token(auth_code, code_verifier) token_response = yield token_request await self._handle_token_response(token_response) + except Exception: + logger.exception("OAuth flow error") + raise - # Retry with new tokens - self._add_auth_header(request) - yield request + # Retry with new tokens + self._add_auth_header(request) + yield request diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 50bceddec..e3532e988 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -32,6 +32,7 @@ "HOMEPATH", "LOCALAPPDATA", "PATH", + "PATHEXT", "PROCESSOR_ARCHITECTURE", "SYSTEMDRIVE", "SYSTEMROOT", diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 63b09133f..b1ab2c079 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -43,7 +43,7 @@ MCP_PROTOCOL_VERSION = "mcp-protocol-version" LAST_EVENT_ID = "last-event-id" CONTENT_TYPE = "content-type" -ACCEPT = "Accept" +ACCEPT = "accept" JSON = "application/json" @@ -248,6 +248,7 @@ async def 
_handle_resumption_request(self, ctx: RequestContext) -> None: ctx.metadata.on_resumption_token_update if ctx.metadata else None, ) if is_complete: + await event_source.response.aclose() break async def _handle_post_request(self, ctx: RequestContext) -> None: @@ -330,6 +331,7 @@ async def _handle_sse_response( # If the SSE event indicates completion, like returning respose/error # break the loop if is_complete: + await response.aclose() break except Exception as e: logger.exception("Error reading SSE stream:") diff --git a/src/mcp/mcp/__init__.py b/src/mcp/mcp/__init__.py deleted file mode 100644 index e93b95c90..000000000 --- a/src/mcp/mcp/__init__.py +++ /dev/null @@ -1,116 +0,0 @@ -from .client.session import ClientSession -from .client.session_group import ClientSessionGroup -from .client.stdio import StdioServerParameters, stdio_client -from .server.session import ServerSession -from .server.stdio import stdio_server -from .shared.exceptions import McpError -from .types import ( - CallToolRequest, - ClientCapabilities, - ClientNotification, - ClientRequest, - ClientResult, - CompleteRequest, - CreateMessageRequest, - CreateMessageResult, - ErrorData, - GetPromptRequest, - GetPromptResult, - Implementation, - IncludeContext, - InitializedNotification, - InitializeRequest, - InitializeResult, - JSONRPCError, - JSONRPCRequest, - JSONRPCResponse, - ListPromptsRequest, - ListPromptsResult, - ListResourcesRequest, - ListResourcesResult, - ListToolsResult, - LoggingLevel, - LoggingMessageNotification, - Notification, - PingRequest, - ProgressNotification, - PromptsCapability, - ReadResourceRequest, - ReadResourceResult, - Resource, - ResourcesCapability, - ResourceUpdatedNotification, - RootsCapability, - SamplingMessage, - ServerCapabilities, - ServerNotification, - ServerRequest, - ServerResult, - SetLevelRequest, - StopReason, - SubscribeRequest, - Tool, - ToolsCapability, - UnsubscribeRequest, -) -from .types import ( - Role as SamplingRole, -) - -__all__ = [ - 
"CallToolRequest", - "ClientCapabilities", - "ClientNotification", - "ClientRequest", - "ClientResult", - "ClientSession", - "ClientSessionGroup", - "CreateMessageRequest", - "CreateMessageResult", - "ErrorData", - "GetPromptRequest", - "GetPromptResult", - "Implementation", - "IncludeContext", - "InitializeRequest", - "InitializeResult", - "InitializedNotification", - "JSONRPCError", - "JSONRPCRequest", - "ListPromptsRequest", - "ListPromptsResult", - "ListResourcesRequest", - "ListResourcesResult", - "ListToolsResult", - "LoggingLevel", - "LoggingMessageNotification", - "McpError", - "Notification", - "PingRequest", - "ProgressNotification", - "PromptsCapability", - "ReadResourceRequest", - "ReadResourceResult", - "ResourcesCapability", - "ResourceUpdatedNotification", - "Resource", - "RootsCapability", - "SamplingMessage", - "SamplingRole", - "ServerCapabilities", - "ServerNotification", - "ServerRequest", - "ServerResult", - "ServerSession", - "SetLevelRequest", - "StdioServerParameters", - "StopReason", - "SubscribeRequest", - "Tool", - "ToolsCapability", - "UnsubscribeRequest", - "stdio_client", - "stdio_server", - "CompleteRequest", - "JSONRPCResponse", -] diff --git a/src/mcp/mcp/cli/__init__.py b/src/mcp/mcp/cli/__init__.py deleted file mode 100644 index 3ef56d806..000000000 --- a/src/mcp/mcp/cli/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""FastMCP CLI package.""" - -from .cli import app - -if __name__ == "__main__": - app() diff --git a/src/mcp/mcp/cli/claude.py b/src/mcp/mcp/cli/claude.py deleted file mode 100644 index 6a2effa3b..000000000 --- a/src/mcp/mcp/cli/claude.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Claude app integration utilities.""" - -import json -import os -import shutil -import sys -from pathlib import Path -from typing import Any - -from mcp.server.fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - -MCP_PACKAGE = "mcp[cli]" - - -def get_claude_config_path() -> Path | None: - """Get the Claude config directory 
based on platform.""" - if sys.platform == "win32": - path = Path(Path.home(), "AppData", "Roaming", "Claude") - elif sys.platform == "darwin": - path = Path(Path.home(), "Library", "Application Support", "Claude") - elif sys.platform.startswith("linux"): - path = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"), "Claude") - else: - return None - - if path.exists(): - return path - return None - - -def get_uv_path() -> str: - """Get the full path to the uv executable.""" - uv_path = shutil.which("uv") - if not uv_path: - logger.error( - "uv executable not found in PATH, falling back to 'uv'. Please ensure uv is installed and in your PATH" - ) - return "uv" # Fall back to just "uv" if not found - return uv_path - - -def update_claude_config( - file_spec: str, - server_name: str, - *, - with_editable: Path | None = None, - with_packages: list[str] | None = None, - env_vars: dict[str, str] | None = None, -) -> bool: - """Add or update a FastMCP server in Claude's configuration. - - Args: - file_spec: Path to the server file, optionally with :object suffix - server_name: Name for the server in Claude's config - with_editable: Optional directory to install in editable mode - with_packages: Optional list of additional packages to install - env_vars: Optional dictionary of environment variables. These are merged with - any existing variables, with new values taking precedence. - - Raises: - RuntimeError: If Claude Desktop's config directory is not found, indicating - Claude Desktop may not be installed or properly set up. - """ - config_dir = get_claude_config_path() - uv_path = get_uv_path() - if not config_dir: - raise RuntimeError( - "Claude Desktop config directory not found. Please ensure Claude Desktop" - " is installed and has been run at least once to initialize its config." 
- ) - - config_file = config_dir / "claude_desktop_config.json" - if not config_file.exists(): - try: - config_file.write_text("{}") - except Exception: - logger.exception( - "Failed to create Claude config file", - extra={ - "config_file": str(config_file), - }, - ) - return False - - try: - config = json.loads(config_file.read_text()) - if "mcpServers" not in config: - config["mcpServers"] = {} - - # Always preserve existing env vars and merge with new ones - if server_name in config["mcpServers"] and "env" in config["mcpServers"][server_name]: - existing_env = config["mcpServers"][server_name]["env"] - if env_vars: - # New vars take precedence over existing ones - env_vars = {**existing_env, **env_vars} - else: - env_vars = existing_env - - # Build uv run command - args = ["run"] - - # Collect all packages in a set to deduplicate - packages = {MCP_PACKAGE} - if with_packages: - packages.update(pkg for pkg in with_packages if pkg) - - # Add all packages with --with - for pkg in sorted(packages): - args.extend(["--with", pkg]) - - if with_editable: - args.extend(["--with-editable", str(with_editable)]) - - # Convert file path to absolute before adding to command - # Split off any :object suffix first - if ":" in file_spec: - file_path, server_object = file_spec.rsplit(":", 1) - file_spec = f"{Path(file_path).resolve()}:{server_object}" - else: - file_spec = str(Path(file_spec).resolve()) - - # Add fastmcp run command - args.extend(["mcp", "run", file_spec]) - - server_config: dict[str, Any] = {"command": uv_path, "args": args} - - # Add environment variables if specified - if env_vars: - server_config["env"] = env_vars - - config["mcpServers"][server_name] = server_config - - config_file.write_text(json.dumps(config, indent=2)) - logger.info( - f"Added server '{server_name}' to Claude config", - extra={"config_file": str(config_file)}, - ) - return True - except Exception: - logger.exception( - "Failed to update Claude config", - extra={ - "config_file": 
str(config_file), - }, - ) - return False diff --git a/src/mcp/mcp/cli/cli.py b/src/mcp/mcp/cli/cli.py deleted file mode 100644 index 4a7257a11..000000000 --- a/src/mcp/mcp/cli/cli.py +++ /dev/null @@ -1,488 +0,0 @@ -"""MCP CLI tools.""" - -import importlib.metadata -import importlib.util -import os -import subprocess -import sys -from pathlib import Path -from typing import Annotated, Any - -from mcp.server import FastMCP -from mcp.server import Server as LowLevelServer - -try: - import typer -except ImportError: - print("Error: typer is required. Install with 'pip install mcp[cli]'") - sys.exit(1) - -try: - from mcp.cli import claude - from mcp.server.fastmcp.utilities.logging import get_logger -except ImportError: - print("Error: mcp.server.fastmcp is not installed or not in PYTHONPATH") - sys.exit(1) - -try: - import dotenv -except ImportError: - dotenv = None - -logger = get_logger("cli") - -app = typer.Typer( - name="mcp", - help="MCP development tools", - add_completion=False, - no_args_is_help=True, # Show help if no args provided -) - - -def _get_npx_command(): - """Get the correct npx command for the current platform.""" - if sys.platform == "win32": - # Try both npx.cmd and npx.exe on Windows - for cmd in ["npx.cmd", "npx.exe", "npx"]: - try: - subprocess.run([cmd, "--version"], check=True, capture_output=True, shell=True) - return cmd - except subprocess.CalledProcessError: - continue - return None - return "npx" # On Unix-like systems, just use npx - - -def _parse_env_var(env_var: str) -> tuple[str, str]: - """Parse environment variable string in format KEY=VALUE.""" - if "=" not in env_var: - logger.error(f"Invalid environment variable format: {env_var}. 
Must be KEY=VALUE") - sys.exit(1) - key, value = env_var.split("=", 1) - return key.strip(), value.strip() - - -def _build_uv_command( - file_spec: str, - with_editable: Path | None = None, - with_packages: list[str] | None = None, -) -> list[str]: - """Build the uv run command that runs a MCP server through mcp run.""" - cmd = ["uv"] - - cmd.extend(["run", "--with", "mcp"]) - - if with_editable: - cmd.extend(["--with-editable", str(with_editable)]) - - if with_packages: - for pkg in with_packages: - if pkg: - cmd.extend(["--with", pkg]) - - # Add mcp run command - cmd.extend(["mcp", "run", file_spec]) - return cmd - - -def _parse_file_path(file_spec: str) -> tuple[Path, str | None]: - """Parse a file path that may include a server object specification. - - Args: - file_spec: Path to file, optionally with :object suffix - - Returns: - Tuple of (file_path, server_object) - """ - # First check if we have a Windows path (e.g., C:\...) - has_windows_drive = len(file_spec) > 1 and file_spec[1] == ":" - - # Split on the last colon, but only if it's not part of the Windows drive letter - # and there's actually another colon in the string after the drive letter - if ":" in (file_spec[2:] if has_windows_drive else file_spec): - file_str, server_object = file_spec.rsplit(":", 1) - else: - file_str, server_object = file_spec, None - - # Resolve the file path - file_path = Path(file_str).expanduser().resolve() - if not file_path.exists(): - logger.error(f"File not found: {file_path}") - sys.exit(1) - if not file_path.is_file(): - logger.error(f"Not a file: {file_path}") - sys.exit(1) - - return file_path, server_object - - -def _import_server(file: Path, server_object: str | None = None): - """Import a MCP server from a file. 
- - Args: - file: Path to the file - server_object: Optional object name in format "module:object" or just "object" - - Returns: - The server object - """ - # Add parent directory to Python path so imports can be resolved - file_dir = str(file.parent) - if file_dir not in sys.path: - sys.path.insert(0, file_dir) - - # Import the module - spec = importlib.util.spec_from_file_location("server_module", file) - if not spec or not spec.loader: - logger.error("Could not load module", extra={"file": str(file)}) - sys.exit(1) - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - def _check_server_object(server_object: Any, object_name: str): - """Helper function to check that the server object is supported - - Args: - server_object: The server object to check. - - Returns: - True if it's supported. - """ - if not isinstance(server_object, FastMCP): - logger.error(f"The server object {object_name} is of type {type(server_object)} (expecting {FastMCP}).") - if isinstance(server_object, LowLevelServer): - logger.warning( - "Note that only FastMCP server is supported. Low level Server class is not yet supported." - ) - return False - return True - - # If no object specified, try common server names - if not server_object: - # Look for the most common server object names - for name in ["mcp", "server", "app"]: - if hasattr(module, name): - if not _check_server_object(getattr(module, name), f"{file}:{name}"): - logger.error(f"Ignoring object '{file}:{name}' as it's not a valid server object") - continue - return getattr(module, name) - - logger.error( - f"No server object found in {file}. Please either:\n" - "1. Use a standard variable name (mcp, server, or app)\n" - "2. Specify the object name with file:object syntax" - "3. 
If the server creates the FastMCP object within main() " - " or another function, refactor the FastMCP object to be a " - " global variable named mcp, server, or app.", - extra={"file": str(file)}, - ) - sys.exit(1) - - # Handle module:object syntax - if ":" in server_object: - module_name, object_name = server_object.split(":", 1) - try: - server_module = importlib.import_module(module_name) - server = getattr(server_module, object_name, None) - except ImportError: - logger.error( - f"Could not import module '{module_name}'", - extra={"file": str(file)}, - ) - sys.exit(1) - else: - # Just object name - server = getattr(module, server_object, None) - - if server is None: - logger.error( - f"Server object '{server_object}' not found", - extra={"file": str(file)}, - ) - sys.exit(1) - - if not _check_server_object(server, server_object): - sys.exit(1) - - return server - - -@app.command() -def version() -> None: - """Show the MCP version.""" - try: - version = importlib.metadata.version("mcp") - print(f"MCP version {version}") - except importlib.metadata.PackageNotFoundError: - print("MCP version unknown (package not installed)") - sys.exit(1) - - -@app.command() -def dev( - file_spec: str = typer.Argument( - ..., - help="Python file to run, optionally with :object suffix", - ), - with_editable: Annotated[ - Path | None, - typer.Option( - "--with-editable", - "-e", - help="Directory containing pyproject.toml to install in editable mode", - exists=True, - file_okay=False, - resolve_path=True, - ), - ] = None, - with_packages: Annotated[ - list[str], - typer.Option( - "--with", - help="Additional packages to install", - ), - ] = [], -) -> None: - """Run a MCP server with the MCP Inspector.""" - file, server_object = _parse_file_path(file_spec) - - logger.debug( - "Starting dev server", - extra={ - "file": str(file), - "server_object": server_object, - "with_editable": str(with_editable) if with_editable else None, - "with_packages": with_packages, - }, - ) - - try: - # 
Import server to get dependencies - server = _import_server(file, server_object) - if hasattr(server, "dependencies"): - with_packages = list(set(with_packages + server.dependencies)) - - uv_cmd = _build_uv_command(file_spec, with_editable, with_packages) - - # Get the correct npx command - npx_cmd = _get_npx_command() - if not npx_cmd: - logger.error( - "npx not found. Please ensure Node.js and npm are properly installed and added to your system PATH." - ) - sys.exit(1) - - # Run the MCP Inspector command with shell=True on Windows - shell = sys.platform == "win32" - process = subprocess.run( - [npx_cmd, "@modelcontextprotocol/inspector"] + uv_cmd, - check=True, - shell=shell, - env=dict(os.environ.items()), # Convert to list of tuples for env update - ) - sys.exit(process.returncode) - except subprocess.CalledProcessError as e: - logger.error( - "Dev server failed", - extra={ - "file": str(file), - "error": str(e), - "returncode": e.returncode, - }, - ) - sys.exit(e.returncode) - except FileNotFoundError: - logger.error( - "npx not found. Please ensure Node.js and npm are properly installed " - "and added to your system PATH. You may need to restart your terminal " - "after installation.", - extra={"file": str(file)}, - ) - sys.exit(1) - - -@app.command() -def run( - file_spec: str = typer.Argument( - ..., - help="Python file to run, optionally with :object suffix", - ), - transport: Annotated[ - str | None, - typer.Option( - "--transport", - "-t", - help="Transport protocol to use (stdio or sse)", - ), - ] = None, -) -> None: - """Run a MCP server. - - The server can be specified in two ways:\n - 1. Module approach: server.py - runs the module directly, expecting a server.run() call.\n - 2. Import approach: server.py:app - imports and runs the specified server object.\n\n - - Note: This command runs the server directly. You are responsible for ensuring - all dependencies are available.\n - For dependency management, use `mcp install` or `mcp dev` instead. 
- """ # noqa: E501 - file, server_object = _parse_file_path(file_spec) - - logger.debug( - "Running server", - extra={ - "file": str(file), - "server_object": server_object, - "transport": transport, - }, - ) - - try: - # Import and get server object - server = _import_server(file, server_object) - - # Run the server - kwargs = {} - if transport: - kwargs["transport"] = transport - - server.run(**kwargs) - - except Exception: - logger.exception( - "Failed to run server", - extra={ - "file": str(file), - }, - ) - sys.exit(1) - - -@app.command() -def install( - file_spec: str = typer.Argument( - ..., - help="Python file to run, optionally with :object suffix", - ), - server_name: Annotated[ - str | None, - typer.Option( - "--name", - "-n", - help="Custom name for the server (defaults to server's name attribute or file name)", - ), - ] = None, - with_editable: Annotated[ - Path | None, - typer.Option( - "--with-editable", - "-e", - help="Directory containing pyproject.toml to install in editable mode", - exists=True, - file_okay=False, - resolve_path=True, - ), - ] = None, - with_packages: Annotated[ - list[str], - typer.Option( - "--with", - help="Additional packages to install", - ), - ] = [], - env_vars: Annotated[ - list[str], - typer.Option( - "--env-var", - "-v", - help="Environment variables in KEY=VALUE format", - ), - ] = [], - env_file: Annotated[ - Path | None, - typer.Option( - "--env-file", - "-f", - help="Load environment variables from a .env file", - exists=True, - file_okay=True, - dir_okay=False, - resolve_path=True, - ), - ] = None, -) -> None: - """Install a MCP server in the Claude desktop app. - - Environment variables are preserved once added and only updated if new values - are explicitly provided. 
- """ - file, server_object = _parse_file_path(file_spec) - - logger.debug( - "Installing server", - extra={ - "file": str(file), - "server_name": server_name, - "server_object": server_object, - "with_editable": str(with_editable) if with_editable else None, - "with_packages": with_packages, - }, - ) - - if not claude.get_claude_config_path(): - logger.error("Claude app not found") - sys.exit(1) - - # Try to import server to get its name, but fall back to file name if dependencies - # missing - name = server_name - server = None - if not name: - try: - server = _import_server(file, server_object) - name = server.name - except (ImportError, ModuleNotFoundError) as e: - logger.debug( - "Could not import server (likely missing dependencies), using file name", - extra={"error": str(e)}, - ) - name = file.stem - - # Get server dependencies if available - server_dependencies = getattr(server, "dependencies", []) if server else [] - if server_dependencies: - with_packages = list(set(with_packages + server_dependencies)) - - # Process environment variables if provided - env_dict: dict[str, str] | None = None - if env_file or env_vars: - env_dict = {} - # Load from .env file if specified - if env_file: - if dotenv: - try: - env_dict |= {k: v for k, v in dotenv.dotenv_values(env_file).items() if v is not None} - except (OSError, ValueError): - logger.exception("Failed to load .env file") - sys.exit(1) - else: - logger.error("python-dotenv is not installed. 
Cannot load .env file.") - sys.exit(1) - - # Add command line environment variables - for env_var in env_vars: - key, value = _parse_env_var(env_var) - env_dict[key] = value - - if claude.update_claude_config( - file_spec, - name, - with_editable=with_editable, - with_packages=with_packages, - env_vars=env_dict, - ): - logger.info(f"Successfully installed {name} in Claude app") - else: - logger.error(f"Failed to install {name} in Claude app") - sys.exit(1) diff --git a/src/mcp/mcp/client/__init__.py b/src/mcp/mcp/client/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/mcp/mcp/client/__main__.py b/src/mcp/mcp/client/__main__.py deleted file mode 100644 index 2efe05d53..000000000 --- a/src/mcp/mcp/client/__main__.py +++ /dev/null @@ -1,85 +0,0 @@ -import argparse -import logging -import sys -from functools import partial -from urllib.parse import urlparse - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -import mcp.types as types -from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.client.stdio import StdioServerParameters, stdio_client -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder - -if not sys.warnoptions: - import warnings - - warnings.simplefilter("ignore") - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("client") - - -async def message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, -) -> None: - if isinstance(message, Exception): - logger.error("Error: %s", message) - return - - logger.info("Received message from server: %s", message) - - -async def run_session( - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], - client_info: types.Implementation | None = None, -): - async with ClientSession( - read_stream, - 
write_stream, - message_handler=message_handler, - client_info=client_info, - ) as session: - logger.info("Initializing session") - await session.initialize() - logger.info("Initialized") - - -async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]): - env_dict = dict(env) - - if urlparse(command_or_url).scheme in ("http", "https"): - # Use SSE client for HTTP(S) URLs - async with sse_client(command_or_url) as streams: - await run_session(*streams) - else: - # Use stdio client for commands - server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict) - async with stdio_client(server_parameters) as streams: - await run_session(*streams) - - -def cli(): - parser = argparse.ArgumentParser() - parser.add_argument("command_or_url", help="Command or URL to connect to") - parser.add_argument("args", nargs="*", help="Additional arguments") - parser.add_argument( - "-e", - "--env", - nargs=2, - action="append", - metavar=("KEY", "VALUE"), - help="Environment variables to set. Can be used multiple times.", - default=[], - ) - - args = parser.parse_args() - anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio") - - -if __name__ == "__main__": - cli() diff --git a/src/mcp/mcp/client/auth.py b/src/mcp/mcp/client/auth.py deleted file mode 100644 index 376036e8c..000000000 --- a/src/mcp/mcp/client/auth.py +++ /dev/null @@ -1,551 +0,0 @@ -""" -OAuth2 Authentication implementation for HTTPX. - -Implements authorization code flow with PKCE and automatic token refresh. 
-""" - -import base64 -import hashlib -import logging -import re -import secrets -import string -import time -from collections.abc import AsyncGenerator, Awaitable, Callable -from dataclasses import dataclass, field -from typing import Protocol -from urllib.parse import urlencode, urljoin, urlparse - -import anyio -import httpx -from pydantic import BaseModel, Field, ValidationError - -from mcp.client.streamable_http import MCP_PROTOCOL_VERSION -from mcp.shared.auth import ( - OAuthClientInformationFull, - OAuthClientMetadata, - OAuthMetadata, - OAuthToken, - ProtectedResourceMetadata, -) -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url -from mcp.types import LATEST_PROTOCOL_VERSION - -logger = logging.getLogger(__name__) - - -class OAuthFlowError(Exception): - """Base exception for OAuth flow errors.""" - - -class OAuthTokenError(OAuthFlowError): - """Raised when token operations fail.""" - - -class OAuthRegistrationError(OAuthFlowError): - """Raised when client registration fails.""" - - -class PKCEParameters(BaseModel): - """PKCE (Proof Key for Code Exchange) parameters.""" - - code_verifier: str = Field(..., min_length=43, max_length=128) - code_challenge: str = Field(..., min_length=43, max_length=128) - - @classmethod - def generate(cls) -> "PKCEParameters": - """Generate new PKCE parameters.""" - code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) - digest = hashlib.sha256(code_verifier.encode()).digest() - code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") - return cls(code_verifier=code_verifier, code_challenge=code_challenge) - - -class TokenStorage(Protocol): - """Protocol for token storage implementations.""" - - async def get_tokens(self) -> OAuthToken | None: - """Get stored tokens.""" - ... - - async def set_tokens(self, tokens: OAuthToken) -> None: - """Store tokens.""" - ... 
- - async def get_client_info(self) -> OAuthClientInformationFull | None: - """Get stored client information.""" - ... - - async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: - """Store client information.""" - ... - - -@dataclass -class OAuthContext: - """OAuth flow context.""" - - server_url: str - client_metadata: OAuthClientMetadata - storage: TokenStorage - redirect_handler: Callable[[str], Awaitable[None]] - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] - timeout: float = 300.0 - - # Discovered metadata - protected_resource_metadata: ProtectedResourceMetadata | None = None - oauth_metadata: OAuthMetadata | None = None - auth_server_url: str | None = None - protocol_version: str | None = None - - # Client registration - client_info: OAuthClientInformationFull | None = None - - # Token management - current_tokens: OAuthToken | None = None - token_expiry_time: float | None = None - - # State - lock: anyio.Lock = field(default_factory=anyio.Lock) - - # Discovery state for fallback support - discovery_base_url: str | None = None - discovery_pathname: str | None = None - - def get_authorization_base_url(self, server_url: str) -> str: - """Extract base URL by removing path component.""" - parsed = urlparse(server_url) - return f"{parsed.scheme}://{parsed.netloc}" - - def update_token_expiry(self, token: OAuthToken) -> None: - """Update token expiry time.""" - if token.expires_in: - self.token_expiry_time = time.time() + token.expires_in - else: - self.token_expiry_time = None - - def is_token_valid(self) -> bool: - """Check if current token is valid.""" - return bool( - self.current_tokens - and self.current_tokens.access_token - and (not self.token_expiry_time or time.time() <= self.token_expiry_time) - ) - - def can_refresh_token(self) -> bool: - """Check if token can be refreshed.""" - return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) - - def clear_tokens(self) -> None: - 
"""Clear current tokens.""" - self.current_tokens = None - self.token_expiry_time = None - - def get_resource_url(self) -> str: - """Get resource URL for RFC 8707. - - Uses PRM resource if it's a valid parent, otherwise uses canonical server URL. - """ - resource = resource_url_from_server_url(self.server_url) - - # If PRM provides a resource that's a valid parent, use it - if self.protected_resource_metadata and self.protected_resource_metadata.resource: - prm_resource = str(self.protected_resource_metadata.resource) - if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): - resource = prm_resource - - return resource - - def should_include_resource_param(self, protocol_version: str | None = None) -> bool: - """Determine if the resource parameter should be included in OAuth requests. - - Returns True if: - - Protected resource metadata is available, OR - - MCP-Protocol-Version header is 2025-06-18 or later - """ - # If we have protected resource metadata, include the resource param - if self.protected_resource_metadata is not None: - return True - - # If no protocol version provided, don't include resource param - if not protocol_version: - return False - - # Check if protocol version is 2025-06-18 or later - # Version format is YYYY-MM-DD, so string comparison works - return protocol_version >= "2025-06-18" - - -class OAuthClientProvider(httpx.Auth): - """ - OAuth2 authentication for httpx. - Handles OAuth flow with automatic client registration and token storage. 
- """ - - requires_response_body = True - - def __init__( - self, - server_url: str, - client_metadata: OAuthClientMetadata, - storage: TokenStorage, - redirect_handler: Callable[[str], Awaitable[None]], - callback_handler: Callable[[], Awaitable[tuple[str, str | None]]], - timeout: float = 300.0, - ): - """Initialize OAuth2 authentication.""" - self.context = OAuthContext( - server_url=server_url, - client_metadata=client_metadata, - storage=storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - timeout=timeout, - ) - self._initialized = False - - def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: - """ - Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. - - Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise - """ - if not init_response or init_response.status_code != 401: - return None - - www_auth_header = init_response.headers.get("WWW-Authenticate") - if not www_auth_header: - return None - - # Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted) - pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) - - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) - - return None - - async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request: - # RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response - url = self._extract_resource_metadata_from_www_auth(init_response) - - if not url: - # Fallback to well-known discovery - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") - - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _handle_protected_resource_response(self, 
response: httpx.Response) -> None: - """Handle discovery response.""" - if response.status_code == 200: - try: - content = await response.aread() - metadata = ProtectedResourceMetadata.model_validate_json(content) - self.context.protected_resource_metadata = metadata - if metadata.authorization_servers: - self.context.auth_server_url = str(metadata.authorization_servers[0]) - except ValidationError: - pass - - def _get_discovery_urls(self) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts.""" - urls: list[str] = [] - auth_server_url = self.context.auth_server_url or self.context.server_url - parsed = urlparse(auth_server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # RFC 8414: Path-aware OAuth discovery - if parsed.path and parsed.path != "/": - oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oauth_path)) - - # OAuth root fallback - urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - - # RFC 8414 section 5: Path-aware OIDC discovery - # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 - if parsed.path and parsed.path != "/": - oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oidc_path)) - - # OIDC 1.0 fallback (appends to full URL per OIDC spec) - oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" - urls.append(oidc_fallback) - - return urls - - async def _register_client(self) -> httpx.Request | None: - """Build registration request or skip if already registered.""" - if self.context.client_info: - return None - - if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: - registration_url = str(self.context.oauth_metadata.registration_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - registration_url = urljoin(auth_base_url, "/register") - - 
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - - return httpx.Request( - "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} - ) - - async def _handle_registration_response(self, response: httpx.Response) -> None: - """Handle registration response.""" - if response.status_code not in (200, 201): - await response.aread() - raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") - - try: - content = await response.aread() - client_info = OAuthClientInformationFull.model_validate_json(content) - self.context.client_info = client_info - await self.context.storage.set_client_info(client_info) - except ValidationError as e: - raise OAuthRegistrationError(f"Invalid registration response: {e}") - - async def _perform_authorization(self) -> tuple[str, str]: - """Perform the authorization redirect and get auth code.""" - if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - auth_endpoint = urljoin(auth_base_url, "/authorize") - - if not self.context.client_info: - raise OAuthFlowError("No client info available for authorization") - - # Generate PKCE parameters - pkce_params = PKCEParameters.generate() - state = secrets.token_urlsafe(32) - - auth_params = { - "response_type": "code", - "client_id": self.context.client_info.client_id, - "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), - "state": state, - "code_challenge": pkce_params.code_challenge, - "code_challenge_method": "S256", - } - - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 - - if 
self.context.client_metadata.scope: - auth_params["scope"] = self.context.client_metadata.scope - - authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" - await self.context.redirect_handler(authorization_url) - - # Wait for callback - auth_code, returned_state = await self.context.callback_handler() - - if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") - - if not auth_code: - raise OAuthFlowError("No authorization code received") - - # Return auth code and code verifier for token exchange - return auth_code, pkce_params.code_verifier - - async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Request: - """Build token exchange request.""" - if not self.context.client_info: - raise OAuthFlowError("Missing client info") - - if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: - token_url = str(self.context.oauth_metadata.token_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - token_url = urljoin(auth_base_url, "/token") - - token_data = { - "grant_type": "authorization_code", - "code": auth_code, - "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), - "client_id": self.context.client_info.client_id, - "code_verifier": code_verifier, - } - - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - token_data["resource"] = self.context.get_resource_url() # RFC 8707 - - if self.context.client_info.client_secret: - token_data["client_secret"] = self.context.client_info.client_secret - - return httpx.Request( - "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} - ) - - async def _handle_token_response(self, response: httpx.Response) -> None: - """Handle token exchange response.""" - if response.status_code != 200: - 
raise OAuthTokenError(f"Token exchange failed: {response.status_code}") - - try: - content = await response.aread() - token_response = OAuthToken.model_validate_json(content) - - # Validate scopes - if token_response.scope and self.context.client_metadata.scope: - requested_scopes = set(self.context.client_metadata.scope.split()) - returned_scopes = set(token_response.scope.split()) - unauthorized_scopes = returned_scopes - requested_scopes - if unauthorized_scopes: - raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}") - - self.context.current_tokens = token_response - self.context.update_token_expiry(token_response) - await self.context.storage.set_tokens(token_response) - except ValidationError as e: - raise OAuthTokenError(f"Invalid token response: {e}") - - async def _refresh_token(self) -> httpx.Request: - """Build token refresh request.""" - if not self.context.current_tokens or not self.context.current_tokens.refresh_token: - raise OAuthTokenError("No refresh token available") - - if not self.context.client_info: - raise OAuthTokenError("No client info available") - - if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: - token_url = str(self.context.oauth_metadata.token_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - token_url = urljoin(auth_base_url, "/token") - - refresh_data = { - "grant_type": "refresh_token", - "refresh_token": self.context.current_tokens.refresh_token, - "client_id": self.context.client_info.client_id, - } - - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - - if self.context.client_info.client_secret: - refresh_data["client_secret"] = self.context.client_info.client_secret - - return httpx.Request( - "POST", token_url, data=refresh_data, headers={"Content-Type": 
"application/x-www-form-urlencoded"} - ) - - async def _handle_refresh_response(self, response: httpx.Response) -> bool: - """Handle token refresh response. Returns True if successful.""" - if response.status_code != 200: - logger.warning(f"Token refresh failed: {response.status_code}") - self.context.clear_tokens() - return False - - try: - content = await response.aread() - token_response = OAuthToken.model_validate_json(content) - - self.context.current_tokens = token_response - self.context.update_token_expiry(token_response) - await self.context.storage.set_tokens(token_response) - - return True - except ValidationError: - logger.exception("Invalid refresh response") - self.context.clear_tokens() - return False - - async def _initialize(self) -> None: - """Load stored tokens and client info.""" - self.context.current_tokens = await self.context.storage.get_tokens() - self.context.client_info = await self.context.storage.get_client_info() - self._initialized = True - - def _add_auth_header(self, request: httpx.Request) -> None: - """Add authorization header to request if we have valid tokens.""" - if self.context.current_tokens and self.context.current_tokens.access_token: - request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" - - def _create_oauth_metadata_request(self, url: str) -> httpx.Request: - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self.context.oauth_metadata = metadata - # Apply default scope if needed - if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: - self.context.client_metadata.scope = " ".join(metadata.scopes_supported) - - async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: - """HTTPX auth 
flow integration.""" - async with self.context.lock: - if not self._initialized: - await self._initialize() - - # Capture protocol version from request headers - self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - - if not self.context.is_token_valid() and self.context.can_refresh_token(): - # Try to refresh token - refresh_request = await self._refresh_token() - refresh_response = yield refresh_request - - if not await self._handle_refresh_response(refresh_response): - # Refresh failed, need full re-authentication - self._initialized = False - - if self.context.is_token_valid(): - self._add_auth_header(request) - - response = yield request - - if response.status_code == 401: - # Perform full OAuth flow - try: - # OAuth flow must be inline due to generator constraints - # Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support) - discovery_request = await self._discover_protected_resource(response) - discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) - - # Step 2: Discover OAuth metadata (with fallback for legacy servers) - discovery_urls = self._get_discovery_urls() - for url in discovery_urls: - oauth_metadata_request = self._create_oauth_metadata_request(url) - oauth_metadata_response = yield oauth_metadata_request - - if oauth_metadata_response.status_code == 200: - try: - await self._handle_oauth_metadata_response(oauth_metadata_response) - break - except ValidationError: - continue - elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: - break # Non-4XX error, stop trying - - # Step 3: Register client if needed - registration_request = await self._register_client() - if registration_request: - registration_response = yield registration_request - await self._handle_registration_response(registration_response) - - # Step 4: Perform authorization - auth_code, code_verifier = await 
self._perform_authorization() - - # Step 5: Exchange authorization code for tokens - token_request = await self._exchange_token(auth_code, code_verifier) - token_response = yield token_request - await self._handle_token_response(token_response) - except Exception: - logger.exception("OAuth flow error") - raise - - # Retry with new tokens - self._add_auth_header(request) - yield request diff --git a/src/mcp/mcp/client/session.py b/src/mcp/mcp/client/session.py deleted file mode 100644 index 1853ce7c1..000000000 --- a/src/mcp/mcp/client/session.py +++ /dev/null @@ -1,458 +0,0 @@ -import logging -from datetime import timedelta -from typing import Any, Protocol - -import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from jsonschema import SchemaError, ValidationError, validate -from pydantic import AnyUrl, TypeAdapter - -import mcp.types as types -from mcp.shared.context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS - -DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") - -logger = logging.getLogger("client") - - -class SamplingFnT(Protocol): - async def __call__( - self, - context: RequestContext["ClientSession", Any], - params: types.CreateMessageRequestParams, - ) -> types.CreateMessageResult | types.ErrorData: ... - - -class ElicitationFnT(Protocol): - async def __call__( - self, - context: RequestContext["ClientSession", Any], - params: types.ElicitRequestParams, - ) -> types.ElicitResult | types.ErrorData: ... - - -class ListRootsFnT(Protocol): - async def __call__( - self, context: RequestContext["ClientSession", Any] - ) -> types.ListRootsResult | types.ErrorData: ... - - -class LoggingFnT(Protocol): - async def __call__( - self, - params: types.LoggingMessageNotificationParams, - ) -> None: ... 
- - -class MessageHandlerFnT(Protocol): - async def __call__( - self, - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: ... - - -async def _default_message_handler( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, -) -> None: - await anyio.lowlevel.checkpoint() - - -async def _default_sampling_callback( - context: RequestContext["ClientSession", Any], - params: types.CreateMessageRequestParams, -) -> types.CreateMessageResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Sampling not supported", - ) - - -async def _default_elicitation_callback( - context: RequestContext["ClientSession", Any], - params: types.ElicitRequestParams, -) -> types.ElicitResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="Elicitation not supported", - ) - - -async def _default_list_roots_callback( - context: RequestContext["ClientSession", Any], -) -> types.ListRootsResult | types.ErrorData: - return types.ErrorData( - code=types.INVALID_REQUEST, - message="List roots not supported", - ) - - -async def _default_logging_callback( - params: types.LoggingMessageNotificationParams, -) -> None: - pass - - -ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData) - - -class ClientSession( - BaseSession[ - types.ClientRequest, - types.ClientNotification, - types.ClientResult, - types.ServerRequest, - types.ServerNotification, - ] -): - def __init__( - self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], - read_timeout_seconds: timedelta | None = None, - sampling_callback: SamplingFnT | None = None, - elicitation_callback: ElicitationFnT | None = None, - list_roots_callback: ListRootsFnT | None = None, - logging_callback: LoggingFnT | None = None, - 
message_handler: MessageHandlerFnT | None = None, - client_info: types.Implementation | None = None, - ) -> None: - super().__init__( - read_stream, - write_stream, - types.ServerRequest, - types.ServerNotification, - read_timeout_seconds=read_timeout_seconds, - ) - self._client_info = client_info or DEFAULT_CLIENT_INFO - self._sampling_callback = sampling_callback or _default_sampling_callback - self._elicitation_callback = elicitation_callback or _default_elicitation_callback - self._list_roots_callback = list_roots_callback or _default_list_roots_callback - self._logging_callback = logging_callback or _default_logging_callback - self._message_handler = message_handler or _default_message_handler - self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} - - async def initialize(self) -> types.InitializeResult: - sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None - elicitation = ( - types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None - ) - roots = ( - # TODO: Should this be based on whether we - # _will_ send notifications, or only whether - # they're supported? 
- types.RootsCapability(listChanged=True) - if self._list_roots_callback is not _default_list_roots_callback - else None - ) - - result = await self.send_request( - types.ClientRequest( - types.InitializeRequest( - method="initialize", - params=types.InitializeRequestParams( - protocolVersion=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities( - sampling=sampling, - elicitation=elicitation, - experimental=None, - roots=roots, - ), - clientInfo=self._client_info, - ), - ) - ), - types.InitializeResult, - ) - - if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS: - raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}") - - await self.send_notification( - types.ClientNotification(types.InitializedNotification(method="notifications/initialized")) - ) - - return result - - async def send_ping(self) -> types.EmptyResult: - """Send a ping request.""" - return await self.send_request( - types.ClientRequest( - types.PingRequest( - method="ping", - ) - ), - types.EmptyResult, - ) - - async def send_progress_notification( - self, - progress_token: str | int, - progress: float, - total: float | None = None, - message: str | None = None, - ) -> None: - """Send a progress notification.""" - await self.send_notification( - types.ClientNotification( - types.ProgressNotification( - method="notifications/progress", - params=types.ProgressNotificationParams( - progressToken=progress_token, - progress=progress, - total=total, - message=message, - ), - ), - ) - ) - - async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult: - """Send a logging/setLevel request.""" - return await self.send_request( - types.ClientRequest( - types.SetLevelRequest( - method="logging/setLevel", - params=types.SetLevelRequestParams(level=level), - ) - ), - types.EmptyResult, - ) - - async def list_resources(self, cursor: str | None = None) -> types.ListResourcesResult: - """Send a resources/list request.""" - return 
await self.send_request( - types.ClientRequest( - types.ListResourcesRequest( - method="resources/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), - types.ListResourcesResult, - ) - - async def list_resource_templates(self, cursor: str | None = None) -> types.ListResourceTemplatesResult: - """Send a resources/templates/list request.""" - return await self.send_request( - types.ClientRequest( - types.ListResourceTemplatesRequest( - method="resources/templates/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), - types.ListResourceTemplatesResult, - ) - - async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: - """Send a resources/read request.""" - return await self.send_request( - types.ClientRequest( - types.ReadResourceRequest( - method="resources/read", - params=types.ReadResourceRequestParams(uri=uri), - ) - ), - types.ReadResourceResult, - ) - - async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: - """Send a resources/subscribe request.""" - return await self.send_request( - types.ClientRequest( - types.SubscribeRequest( - method="resources/subscribe", - params=types.SubscribeRequestParams(uri=uri), - ) - ), - types.EmptyResult, - ) - - async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: - """Send a resources/unsubscribe request.""" - return await self.send_request( - types.ClientRequest( - types.UnsubscribeRequest( - method="resources/unsubscribe", - params=types.UnsubscribeRequestParams(uri=uri), - ) - ), - types.EmptyResult, - ) - - async def call_tool( - self, - name: str, - arguments: dict[str, Any] | None = None, - read_timeout_seconds: timedelta | None = None, - progress_callback: ProgressFnT | None = None, - ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" - - result = await self.send_request( - types.ClientRequest( - types.CallToolRequest( - 
method="tools/call", - params=types.CallToolRequestParams( - name=name, - arguments=arguments, - ), - ) - ), - types.CallToolResult, - request_read_timeout_seconds=read_timeout_seconds, - progress_callback=progress_callback, - ) - - if not result.isError: - await self._validate_tool_result(name, result) - - return result - - async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None: - """Validate the structured content of a tool result against its output schema.""" - if name not in self._tool_output_schemas: - # refresh output schema cache - await self.list_tools() - - output_schema = None - if name in self._tool_output_schemas: - output_schema = self._tool_output_schemas.get(name) - else: - logger.warning(f"Tool {name} not listed by server, cannot validate any structured content") - - if output_schema is not None: - if result.structuredContent is None: - raise RuntimeError(f"Tool {name} has an output schema but did not return structured content") - try: - validate(result.structuredContent, output_schema) - except ValidationError as e: - raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") - except SchemaError as e: - raise RuntimeError(f"Invalid schema for tool {name}: {e}") - - async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult: - """Send a prompts/list request.""" - return await self.send_request( - types.ClientRequest( - types.ListPromptsRequest( - method="prompts/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), - types.ListPromptsResult, - ) - - async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult: - """Send a prompts/get request.""" - return await self.send_request( - types.ClientRequest( - types.GetPromptRequest( - method="prompts/get", - params=types.GetPromptRequestParams(name=name, arguments=arguments), - ) - ), - types.GetPromptResult, - ) - - async def complete( - 
self, - ref: types.ResourceTemplateReference | types.PromptReference, - argument: dict[str, str], - context_arguments: dict[str, str] | None = None, - ) -> types.CompleteResult: - """Send a completion/complete request.""" - context = None - if context_arguments is not None: - context = types.CompletionContext(arguments=context_arguments) - - return await self.send_request( - types.ClientRequest( - types.CompleteRequest( - method="completion/complete", - params=types.CompleteRequestParams( - ref=ref, - argument=types.CompletionArgument(**argument), - context=context, - ), - ) - ), - types.CompleteResult, - ) - - async def list_tools(self, cursor: str | None = None) -> types.ListToolsResult: - """Send a tools/list request.""" - result = await self.send_request( - types.ClientRequest( - types.ListToolsRequest( - method="tools/list", - params=types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None, - ) - ), - types.ListToolsResult, - ) - - # Cache tool output schemas for future validation - # Note: don't clear the cache, as we may be using a cursor - for tool in result.tools: - self._tool_output_schemas[tool.name] = tool.outputSchema - - return result - - async def send_roots_list_changed(self) -> None: - """Send a roots/list_changed notification.""" - await self.send_notification( - types.ClientNotification( - types.RootsListChangedNotification( - method="notifications/roots/list_changed", - ) - ) - ) - - async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession, Any]( - request_id=responder.request_id, - meta=responder.request_meta, - session=self, - lifespan_context=None, - ) - - match responder.request.root: - case types.CreateMessageRequest(params=params): - with responder: - response = await self._sampling_callback(ctx, params) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - - case 
types.ElicitRequest(params=params): - with responder: - response = await self._elicitation_callback(ctx, params) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - - case types.ListRootsRequest(): - with responder: - response = await self._list_roots_callback(ctx) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - - case types.PingRequest(): - with responder: - return await responder.respond(types.ClientResult(root=types.EmptyResult())) - - async def _handle_incoming( - self, - req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - """Handle incoming messages by forwarding to the message handler.""" - await self._message_handler(req) - - async def _received_notification(self, notification: types.ServerNotification) -> None: - """Handle notifications from the server.""" - # Process specific notification types - match notification.root: - case types.LoggingMessageNotification(params=params): - await self._logging_callback(params) - case _: - pass diff --git a/src/mcp/mcp/client/session_group.py b/src/mcp/mcp/client/session_group.py deleted file mode 100644 index 700b5417f..000000000 --- a/src/mcp/mcp/client/session_group.py +++ /dev/null @@ -1,366 +0,0 @@ -""" -SessionGroup concurrently manages multiple MCP session connections. - -Tools, resources, and prompts are aggregated across servers. Servers may -be connected to or disconnected from at any point after initialization. - -This abstractions can handle naming collisions using a custom user-provided -hook. 
-""" - -import contextlib -import logging -from collections.abc import Callable -from datetime import timedelta -from types import TracebackType -from typing import Any, TypeAlias - -import anyio -from pydantic import BaseModel -from typing_extensions import Self - -import mcp -from mcp import types -from mcp.client.sse import sse_client -from mcp.client.stdio import StdioServerParameters -from mcp.client.streamable_http import streamablehttp_client -from mcp.shared.exceptions import McpError - - -class SseServerParameters(BaseModel): - """Parameters for intializing a sse_client.""" - - # The endpoint URL. - url: str - - # Optional headers to include in requests. - headers: dict[str, Any] | None = None - - # HTTP timeout for regular operations. - timeout: float = 5 - - # Timeout for SSE read operations. - sse_read_timeout: float = 60 * 5 - - -class StreamableHttpParameters(BaseModel): - """Parameters for intializing a streamablehttp_client.""" - - # The endpoint URL. - url: str - - # Optional headers to include in requests. - headers: dict[str, Any] | None = None - - # HTTP timeout for regular operations. - timeout: timedelta = timedelta(seconds=30) - - # Timeout for SSE read operations. - sse_read_timeout: timedelta = timedelta(seconds=60 * 5) - - # Close the client session when the transport closes. - terminate_on_close: bool = True - - -ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters - - -class ClientSessionGroup: - """Client for managing connections to multiple MCP servers. - - This class is responsible for encapsulating management of server connections. - It aggregates tools, resources, and prompts from all connected servers. - - For auxiliary handlers, such as resource subscription, this is delegated to - the client and can be accessed via the session. 
- - Example Usage: - name_fn = lambda name, server_info: f"{(server_info.name)}_{name}" - async with ClientSessionGroup(component_name_hook=name_fn) as group: - for server_params in server_params: - await group.connect_to_server(server_param) - ... - - """ - - class _ComponentNames(BaseModel): - """Used for reverse index to find components.""" - - prompts: set[str] = set() - resources: set[str] = set() - tools: set[str] = set() - - # Standard MCP components. - _prompts: dict[str, types.Prompt] - _resources: dict[str, types.Resource] - _tools: dict[str, types.Tool] - - # Client-server connection management. - _sessions: dict[mcp.ClientSession, _ComponentNames] - _tool_to_session: dict[str, mcp.ClientSession] - _exit_stack: contextlib.AsyncExitStack - _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] - - # Optional fn consuming (component_name, serverInfo) for custom names. - # This is provide a means to mitigate naming conflicts across servers. - # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}" - _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str] - _component_name_hook: _ComponentNameHook | None - - def __init__( - self, - exit_stack: contextlib.AsyncExitStack | None = None, - component_name_hook: _ComponentNameHook | None = None, - ) -> None: - """Initializes the MCP client.""" - - self._tools = {} - self._resources = {} - self._prompts = {} - - self._sessions = {} - self._tool_to_session = {} - if exit_stack is None: - self._exit_stack = contextlib.AsyncExitStack() - self._owns_exit_stack = True - else: - self._exit_stack = exit_stack - self._owns_exit_stack = False - self._session_exit_stacks = {} - self._component_name_hook = component_name_hook - - async def __aenter__(self) -> Self: - # Enter the exit stack only if we created it ourselves - if self._owns_exit_stack: - await self._exit_stack.__aenter__() - return self - - async def __aexit__( - self, - _exc_type: type[BaseException] | 
None, - _exc_val: BaseException | None, - _exc_tb: TracebackType | None, - ) -> bool | None: - """Closes session exit stacks and main exit stack upon completion.""" - - # Only close the main exit stack if we created it - if self._owns_exit_stack: - await self._exit_stack.aclose() - - # Concurrently close session stacks. - async with anyio.create_task_group() as tg: - for exit_stack in self._session_exit_stacks.values(): - tg.start_soon(exit_stack.aclose) - - @property - def sessions(self) -> list[mcp.ClientSession]: - """Returns the list of sessions being managed.""" - return list(self._sessions.keys()) - - @property - def prompts(self) -> dict[str, types.Prompt]: - """Returns the prompts as a dictionary of names to prompts.""" - return self._prompts - - @property - def resources(self) -> dict[str, types.Resource]: - """Returns the resources as a dictionary of names to resources.""" - return self._resources - - @property - def tools(self) -> dict[str, types.Tool]: - """Returns the tools as a dictionary of names to tools.""" - return self._tools - - async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: - """Executes a tool given its name and arguments.""" - session = self._tool_to_session[name] - session_tool_name = self.tools[name].name - return await session.call_tool(session_tool_name, args) - - async def disconnect_from_server(self, session: mcp.ClientSession) -> None: - """Disconnects from a single MCP server.""" - - session_known_for_components = session in self._sessions - session_known_for_stack = session in self._session_exit_stacks - - if not session_known_for_components and not session_known_for_stack: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message="Provided session is not managed or already disconnected.", - ) - ) - - if session_known_for_components: - component_names = self._sessions.pop(session) # Pop from _sessions tracking - - # Remove prompts associated with the session. 
- for name in component_names.prompts: - if name in self._prompts: - del self._prompts[name] - # Remove resources associated with the session. - for name in component_names.resources: - if name in self._resources: - del self._resources[name] - # Remove tools associated with the session. - for name in component_names.tools: - if name in self._tools: - del self._tools[name] - if name in self._tool_to_session: - del self._tool_to_session[name] - - # Clean up the session's resources via its dedicated exit stack - if session_known_for_stack: - session_stack_to_close = self._session_exit_stacks.pop(session) - await session_stack_to_close.aclose() - - async def connect_with_session( - self, server_info: types.Implementation, session: mcp.ClientSession - ) -> mcp.ClientSession: - """Connects to a single MCP server.""" - await self._aggregate_components(server_info, session) - return session - - async def connect_to_server( - self, - server_params: ServerParameters, - ) -> mcp.ClientSession: - """Connects to a single MCP server.""" - server_info, session = await self._establish_session(server_params) - return await self.connect_with_session(server_info, session) - - async def _establish_session( - self, server_params: ServerParameters - ) -> tuple[types.Implementation, mcp.ClientSession]: - """Establish a client session to an MCP server.""" - - session_stack = contextlib.AsyncExitStack() - try: - # Create read and write streams that facilitate io with the server. 
- if isinstance(server_params, StdioServerParameters): - client = mcp.stdio_client(server_params) - read, write = await session_stack.enter_async_context(client) - elif isinstance(server_params, SseServerParameters): - client = sse_client( - url=server_params.url, - headers=server_params.headers, - timeout=server_params.timeout, - sse_read_timeout=server_params.sse_read_timeout, - ) - read, write = await session_stack.enter_async_context(client) - else: - client = streamablehttp_client( - url=server_params.url, - headers=server_params.headers, - timeout=server_params.timeout, - sse_read_timeout=server_params.sse_read_timeout, - terminate_on_close=server_params.terminate_on_close, - ) - read, write, _ = await session_stack.enter_async_context(client) - - session = await session_stack.enter_async_context(mcp.ClientSession(read, write)) - result = await session.initialize() - - # Session successfully initialized. - # Store its stack and register the stack with the main group stack. - self._session_exit_stacks[session] = session_stack - # session_stack itself becomes a resource managed by the - # main _exit_stack. - await self._exit_stack.enter_async_context(session_stack) - - return result.serverInfo, session - except Exception: - # If anything during this setup fails, ensure the session-specific - # stack is closed. - await session_stack.aclose() - raise - - async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None: - """Aggregates prompts, resources, and tools from a given session.""" - - # Create a reverse index so we can find all prompts, resources, and - # tools belonging to this session. Used for removing components from - # the session group via self.disconnect_from_server. - component_names = self._ComponentNames() - - # Temporary components dicts. We do not want to modify the aggregate - # lists in case of an intermediate failure. 
- prompts_temp: dict[str, types.Prompt] = {} - resources_temp: dict[str, types.Resource] = {} - tools_temp: dict[str, types.Tool] = {} - tool_to_session_temp: dict[str, mcp.ClientSession] = {} - - # Query the server for its prompts and aggregate to list. - try: - prompts = (await session.list_prompts()).prompts - for prompt in prompts: - name = self._component_name(prompt.name, server_info) - prompts_temp[name] = prompt - component_names.prompts.add(name) - except McpError as err: - logging.warning(f"Could not fetch prompts: {err}") - - # Query the server for its resources and aggregate to list. - try: - resources = (await session.list_resources()).resources - for resource in resources: - name = self._component_name(resource.name, server_info) - resources_temp[name] = resource - component_names.resources.add(name) - except McpError as err: - logging.warning(f"Could not fetch resources: {err}") - - # Query the server for its tools and aggregate to list. - try: - tools = (await session.list_tools()).tools - for tool in tools: - name = self._component_name(tool.name, server_info) - tools_temp[name] = tool - tool_to_session_temp[name] = session - component_names.tools.add(name) - except McpError as err: - logging.warning(f"Could not fetch tools: {err}") - - # Clean up exit stack for session if we couldn't retrieve anything - # from the server. - if not any((prompts_temp, resources_temp, tools_temp)): - del self._session_exit_stacks[session] - - # Check for duplicates. 
- matching_prompts = prompts_temp.keys() & self._prompts.keys() - if matching_prompts: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message=f"{matching_prompts} already exist in group prompts.", - ) - ) - matching_resources = resources_temp.keys() & self._resources.keys() - if matching_resources: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message=f"{matching_resources} already exist in group resources.", - ) - ) - matching_tools = tools_temp.keys() & self._tools.keys() - if matching_tools: - raise McpError( - types.ErrorData( - code=types.INVALID_PARAMS, - message=f"{matching_tools} already exist in group tools.", - ) - ) - - # Aggregate components. - self._sessions[session] = component_names - self._prompts.update(prompts_temp) - self._resources.update(resources_temp) - self._tools.update(tools_temp) - self._tool_to_session.update(tool_to_session_temp) - - def _component_name(self, name: str, server_info: types.Implementation) -> str: - if self._component_name_hook: - return self._component_name_hook(name, server_info) - return name diff --git a/src/mcp/mcp/client/sse.py b/src/mcp/mcp/client/sse.py deleted file mode 100644 index 7ca8d19af..000000000 --- a/src/mcp/mcp/client/sse.py +++ /dev/null @@ -1,144 +0,0 @@ -import logging -from contextlib import asynccontextmanager -from typing import Any -from urllib.parse import urljoin, urlparse - -import anyio -import httpx -from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from httpx_sse import aconnect_sse - -import mcp.types as types -from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client -from mcp.shared.message import SessionMessage - -logger = logging.getLogger(__name__) - - -def remove_request_params(url: str) -> str: - return urljoin(url, urlparse(url).path) - - -@asynccontextmanager -async def sse_client( - url: str, - headers: dict[str, Any] | None = None, - timeout: 
float = 5, - sse_read_timeout: float = 60 * 5, - httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, - auth: httpx.Auth | None = None, -): - """ - Client transport for SSE. - - `sse_read_timeout` determines how long (in seconds) the client will wait for a new - event before disconnecting. All other HTTP operations are controlled by `timeout`. - - Args: - url: The SSE endpoint URL. - headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations. - sse_read_timeout: Timeout for SSE read operations. - auth: Optional HTTPX authentication handler. - """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) - ) as client: - async with aconnect_sse( - client, - "GET", - url, - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def sse_reader( - task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ): - try: - async for sse in event_source.aiter_sse(): - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.debug(f"Received endpoint URL: {endpoint_url}") - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( - f"Endpoint origin does 
not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - - task_status.started(endpoint_url) - - case "message": - try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data - ) - logger.debug(f"Received server message: {message}") - except Exception as exc: - logger.exception("Error parsing server message") - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - case _: - logger.warning(f"Unknown SSE event: {sse.event}") - except Exception as exc: - logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) - finally: - await read_stream_writer.aclose() - - async def post_writer(endpoint_url: str): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug(f"Sending client message: {session_message}") - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), - ) - response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: - logger.exception("Error in post_writer") - finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") - tg.start_soon(post_writer, endpoint_url) - - try: - yield read_stream, write_stream - finally: - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() diff --git a/src/mcp/mcp/client/stdio/__init__.py b/src/mcp/mcp/client/stdio/__init__.py deleted file mode 100644 index e3532e988..000000000 --- a/src/mcp/mcp/client/stdio/__init__.py +++ /dev/null @@ -1,277 +0,0 @@ -import logging -import os -import sys -from contextlib import asynccontextmanager -from pathlib import Path -from typing import Literal, TextIO - -import 
anyio -import anyio.lowlevel -from anyio.abc import Process -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from anyio.streams.text import TextReceiveStream -from pydantic import BaseModel, Field - -import mcp.types as types -from mcp.os.posix.utilities import terminate_posix_process_tree -from mcp.os.win32.utilities import ( - FallbackProcess, - create_windows_process, - get_windows_executable_command, - terminate_windows_process_tree, -) -from mcp.shared.message import SessionMessage - -logger = logging.getLogger(__name__) - -# Environment variables to inherit by default -DEFAULT_INHERITED_ENV_VARS = ( - [ - "APPDATA", - "HOMEDRIVE", - "HOMEPATH", - "LOCALAPPDATA", - "PATH", - "PATHEXT", - "PROCESSOR_ARCHITECTURE", - "SYSTEMDRIVE", - "SYSTEMROOT", - "TEMP", - "USERNAME", - "USERPROFILE", - ] - if sys.platform == "win32" - else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"] -) - -# Timeout for process termination before falling back to force kill -PROCESS_TERMINATION_TIMEOUT = 2.0 - - -def get_default_environment() -> dict[str, str]: - """ - Returns a default environment object including only environment variables deemed - safe to inherit. - """ - env: dict[str, str] = {} - - for key in DEFAULT_INHERITED_ENV_VARS: - value = os.environ.get(key) - if value is None: - continue - - if value.startswith("()"): - # Skip functions, which are a security risk - continue - - env[key] = value - - return env - - -class StdioServerParameters(BaseModel): - command: str - """The executable to run to start the server.""" - - args: list[str] = Field(default_factory=list) - """Command line arguments to pass to the executable.""" - - env: dict[str, str] | None = None - """ - The environment to use when spawning the process. - - If not specified, the result of get_default_environment() will be used. 
- """ - - cwd: str | Path | None = None - """The working directory to use when spawning the process.""" - - encoding: str = "utf-8" - """ - The text encoding used when sending/receiving messages to the server - - defaults to utf-8 - """ - - encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict" - """ - The text encoding error handler. - - See https://docs.python.org/3/library/codecs.html#codec-base-classes for - explanations of possible values - """ - - -@asynccontextmanager -async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): - """ - Client transport for stdio: this will connect to a server by spawning a - process and communicating with it over stdin/stdout. - """ - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - try: - command = _get_executable_command(server.command) - - # Open process with stderr piped for capture - process = await _create_platform_compatible_process( - command=command, - args=server.args, - env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()), - errlog=errlog, - cwd=server.cwd, - ) - except OSError: - # Clean up streams if process creation fails - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - raise - - async def stdout_reader(): - assert process.stdout, "Opened process is missing stdout" - - try: - async with read_stream_writer: - buffer = "" - async for chunk in TextReceiveStream( - process.stdout, - encoding=server.encoding, - errors=server.encoding_error_handler, - ): - lines = (buffer + 
chunk).split("\n") - buffer = lines.pop() - - for line in lines: - try: - message = types.JSONRPCMessage.model_validate_json(line) - except Exception as exc: - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: - await anyio.lowlevel.checkpoint() - - async def stdin_writer(): - assert process.stdin, "Opened process is missing stdin" - - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) - await process.stdin.send( - (json + "\n").encode( - encoding=server.encoding, - errors=server.encoding_error_handler, - ) - ) - except anyio.ClosedResourceError: - await anyio.lowlevel.checkpoint() - - async with ( - anyio.create_task_group() as tg, - process, - ): - tg.start_soon(stdout_reader) - tg.start_soon(stdin_writer) - try: - yield read_stream, write_stream - finally: - # MCP spec: stdio shutdown sequence - # 1. Close input stream to server - # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time - # 3. 
Send SIGKILL if still not exited - if process.stdin: - try: - await process.stdin.aclose() - except Exception: - # stdin might already be closed, which is fine - pass - - try: - # Give the process time to exit gracefully after stdin closes - with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT): - await process.wait() - except TimeoutError: - # Process didn't exit from stdin closure, use platform-specific termination - # which handles SIGTERM -> SIGKILL escalation - await _terminate_process_tree(process) - except ProcessLookupError: - # Process already exited, which is fine - pass - await read_stream.aclose() - await write_stream.aclose() - await read_stream_writer.aclose() - await write_stream_reader.aclose() - - -def _get_executable_command(command: str) -> str: - """ - Get the correct executable command normalized for the current platform. - - Args: - command: Base command (e.g., 'uvx', 'npx') - - Returns: - str: Platform-appropriate command - """ - if sys.platform == "win32": - return get_windows_executable_command(command) - else: - return command - - -async def _create_platform_compatible_process( - command: str, - args: list[str], - env: dict[str, str] | None = None, - errlog: TextIO = sys.stderr, - cwd: Path | str | None = None, -): - """ - Creates a subprocess in a platform-compatible way. - - Unix: Creates process in a new session/process group for killpg support - Windows: Creates process in a Job Object for reliable child termination - """ - if sys.platform == "win32": - process = await create_windows_process(command, args, env, errlog, cwd) - else: - process = await anyio.open_process( - [command, *args], - env=env, - stderr=errlog, - cwd=cwd, - start_new_session=True, - ) - - return process - - -async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None: - """ - Terminate a process and all its children using platform-specific methods. 
- - Unix: Uses os.killpg() for atomic process group termination - Windows: Uses Job Objects via pywin32 for reliable child process cleanup - - Args: - process: The process to terminate - timeout_seconds: Timeout in seconds before force killing (default: 2.0) - """ - if sys.platform == "win32": - await terminate_windows_process_tree(process, timeout_seconds) - else: - # FallbackProcess should only be used for Windows compatibility - assert isinstance(process, Process) - await terminate_posix_process_tree(process, timeout_seconds) diff --git a/src/mcp/mcp/client/streamable_http.py b/src/mcp/mcp/client/streamable_http.py deleted file mode 100644 index b1ab2c079..000000000 --- a/src/mcp/mcp/client/streamable_http.py +++ /dev/null @@ -1,511 +0,0 @@ -""" -StreamableHTTP Client Transport Module - -This module implements the StreamableHTTP transport for MCP clients, -providing support for HTTP POST requests with optional SSE streaming responses -and session management. -""" - -import logging -from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager -from dataclasses import dataclass -from datetime import timedelta - -import anyio -import httpx -from anyio.abc import TaskGroup -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from httpx_sse import EventSource, ServerSentEvent, aconnect_sse - -from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client -from mcp.shared.message import ClientMessageMetadata, SessionMessage -from mcp.types import ( - ErrorData, - InitializeResult, - JSONRPCError, - JSONRPCMessage, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - RequestId, -) - -logger = logging.getLogger(__name__) - - -SessionMessageOrError = SessionMessage | Exception -StreamWriter = MemoryObjectSendStream[SessionMessageOrError] -StreamReader = MemoryObjectReceiveStream[SessionMessage] -GetSessionIdCallback = Callable[[], str | None] - -MCP_SESSION_ID = 
"mcp-session-id" -MCP_PROTOCOL_VERSION = "mcp-protocol-version" -LAST_EVENT_ID = "last-event-id" -CONTENT_TYPE = "content-type" -ACCEPT = "accept" - - -JSON = "application/json" -SSE = "text/event-stream" - - -class StreamableHTTPError(Exception): - """Base exception for StreamableHTTP transport errors.""" - - -class ResumptionError(StreamableHTTPError): - """Raised when resumption request is invalid.""" - - -@dataclass -class RequestContext: - """Context for a request operation.""" - - client: httpx.AsyncClient - headers: dict[str, str] - session_id: str | None - session_message: SessionMessage - metadata: ClientMessageMetadata | None - read_stream_writer: StreamWriter - sse_read_timeout: float - - -class StreamableHTTPTransport: - """StreamableHTTP client transport implementation.""" - - def __init__( - self, - url: str, - headers: dict[str, str] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, - auth: httpx.Auth | None = None, - ) -> None: - """Initialize the StreamableHTTP transport. - - Args: - url: The endpoint URL. - headers: Optional headers to include in requests. - timeout: HTTP timeout for regular operations. - sse_read_timeout: Timeout for SSE read operations. - auth: Optional HTTPX authentication handler. 
- """ - self.url = url - self.headers = headers or {} - self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout - self.sse_read_timeout = ( - sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout - ) - self.auth = auth - self.session_id = None - self.protocol_version = None - self.request_headers = { - ACCEPT: f"{JSON}, {SSE}", - CONTENT_TYPE: JSON, - **self.headers, - } - - def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]: - """Update headers with session ID and protocol version if available.""" - headers = base_headers.copy() - if self.session_id: - headers[MCP_SESSION_ID] = self.session_id - if self.protocol_version: - headers[MCP_PROTOCOL_VERSION] = self.protocol_version - return headers - - def _is_initialization_request(self, message: JSONRPCMessage) -> bool: - """Check if the message is an initialization request.""" - return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" - - def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: - """Check if the message is an initialized notification.""" - return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" - - def _maybe_extract_session_id_from_response( - self, - response: httpx.Response, - ) -> None: - """Extract and store session ID from response headers.""" - new_session_id = response.headers.get(MCP_SESSION_ID) - if new_session_id: - self.session_id = new_session_id - logger.info(f"Received session ID: {self.session_id}") - - def _maybe_extract_protocol_version_from_message( - self, - message: JSONRPCMessage, - ) -> None: - """Extract protocol version from initialization response message.""" - if isinstance(message.root, JSONRPCResponse) and message.root.result: - try: - # Parse the result as InitializeResult for type safety - init_result = InitializeResult.model_validate(message.root.result) - 
self.protocol_version = str(init_result.protocolVersion) - logger.info(f"Negotiated protocol version: {self.protocol_version}") - except Exception as exc: - logger.warning(f"Failed to parse initialization response as InitializeResult: {exc}") - logger.warning(f"Raw result: {message.root.result}") - - async def _handle_sse_event( - self, - sse: ServerSentEvent, - read_stream_writer: StreamWriter, - original_request_id: RequestId | None = None, - resumption_callback: Callable[[str], Awaitable[None]] | None = None, - is_initialization: bool = False, - ) -> bool: - """Handle an SSE event, returning True if the response is complete.""" - if sse.event == "message": - try: - message = JSONRPCMessage.model_validate_json(sse.data) - logger.debug(f"SSE message: {message}") - - # Extract protocol version from initialization response - if is_initialization: - self._maybe_extract_protocol_version_from_message(message) - - # If this is a response and we have original_request_id, replace it - if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): - message.root.id = original_request_id - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - - # Call resumption token callback if we have an ID - if sse.id and resumption_callback: - await resumption_callback(sse.id) - - # If this is a response or error return True indicating completion - # Otherwise, return False to continue listening - return isinstance(message.root, JSONRPCResponse | JSONRPCError) - - except Exception as exc: - logger.exception("Error parsing SSE message") - await read_stream_writer.send(exc) - return False - else: - logger.warning(f"Unknown SSE event: {sse.event}") - return False - - async def handle_get_stream( - self, - client: httpx.AsyncClient, - read_stream_writer: StreamWriter, - ) -> None: - """Handle GET stream for server-initiated messages.""" - try: - if not self.session_id: - return - - headers = 
self._prepare_request_headers(self.request_headers) - - async with aconnect_sse( - client, - "GET", - self.url, - headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("GET SSE connection established") - - async for sse in event_source.aiter_sse(): - await self._handle_sse_event(sse, read_stream_writer) - - except Exception as exc: - logger.debug(f"GET stream error (non-fatal): {exc}") - - async def _handle_resumption_request(self, ctx: RequestContext) -> None: - """Handle a resumption request using GET with SSE.""" - headers = self._prepare_request_headers(ctx.headers) - if ctx.metadata and ctx.metadata.resumption_token: - headers[LAST_EVENT_ID] = ctx.metadata.resumption_token - else: - raise ResumptionError("Resumption request requires a resumption token") - - # Extract original request ID to map responses - original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): - original_request_id = ctx.session_message.message.root.id - - async with aconnect_sse( - ctx.client, - "GET", - self.url, - headers=headers, - timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("Resumption GET SSE connection established") - - async for sse in event_source.aiter_sse(): - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - original_request_id, - ctx.metadata.on_resumption_token_update if ctx.metadata else None, - ) - if is_complete: - await event_source.response.aclose() - break - - async def _handle_post_request(self, ctx: RequestContext) -> None: - """Handle a POST request with response processing.""" - headers = self._prepare_request_headers(ctx.headers) - message = ctx.session_message.message - is_initialization = self._is_initialization_request(message) - - async with ctx.client.stream( - "POST", - self.url, - 
json=message.model_dump(by_alias=True, mode="json", exclude_none=True), - headers=headers, - ) as response: - if response.status_code == 202: - logger.debug("Received 202 Accepted") - return - - if response.status_code == 404: - if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( - ctx.read_stream_writer, - message.root.id, - ) - return - - response.raise_for_status() - if is_initialization: - self._maybe_extract_session_id_from_response(response) - - content_type = response.headers.get(CONTENT_TYPE, "").lower() - - if content_type.startswith(JSON): - await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) - elif content_type.startswith(SSE): - await self._handle_sse_response(response, ctx, is_initialization) - else: - await self._handle_unexpected_content_type( - content_type, - ctx.read_stream_writer, - ) - - async def _handle_json_response( - self, - response: httpx.Response, - read_stream_writer: StreamWriter, - is_initialization: bool = False, - ) -> None: - """Handle JSON response from the server.""" - try: - content = await response.aread() - message = JSONRPCMessage.model_validate_json(content) - - # Extract protocol version from initialization response - if is_initialization: - self._maybe_extract_protocol_version_from_message(message) - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except Exception as exc: - logger.exception("Error parsing JSON response") - await read_stream_writer.send(exc) - - async def _handle_sse_response( - self, - response: httpx.Response, - ctx: RequestContext, - is_initialization: bool = False, - ) -> None: - """Handle SSE response from the server.""" - try: - event_source = EventSource(response) - async for sse in event_source.aiter_sse(): - is_complete = await self._handle_sse_event( - sse, - ctx.read_stream_writer, - resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None), - 
is_initialization=is_initialization, - ) - # If the SSE event indicates completion, like returning respose/error - # break the loop - if is_complete: - await response.aclose() - break - except Exception as e: - logger.exception("Error reading SSE stream:") - await ctx.read_stream_writer.send(e) - - async def _handle_unexpected_content_type( - self, - content_type: str, - read_stream_writer: StreamWriter, - ) -> None: - """Handle unexpected content type in response.""" - error_msg = f"Unexpected content type: {content_type}" - logger.error(error_msg) - await read_stream_writer.send(ValueError(error_msg)) - - async def _send_session_terminated_error( - self, - read_stream_writer: StreamWriter, - request_id: RequestId, - ) -> None: - """Send a session terminated error response.""" - jsonrpc_error = JSONRPCError( - jsonrpc="2.0", - id=request_id, - error=ErrorData(code=32600, message="Session terminated"), - ) - session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) - await read_stream_writer.send(session_message) - - async def post_writer( - self, - client: httpx.AsyncClient, - write_stream_reader: StreamReader, - read_stream_writer: StreamWriter, - write_stream: MemoryObjectSendStream[SessionMessage], - start_get_stream: Callable[[], None], - tg: TaskGroup, - ) -> None: - """Handle writing requests to the server.""" - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - message = session_message.message - metadata = ( - session_message.metadata - if isinstance(session_message.metadata, ClientMessageMetadata) - else None - ) - - # Check if this is a resumption request - is_resumption = bool(metadata and metadata.resumption_token) - - logger.debug(f"Sending client message: {message}") - - # Handle initialized notification - if self._is_initialized_notification(message): - start_get_stream() - - ctx = RequestContext( - client=client, - headers=self.request_headers, - session_id=self.session_id, - 
session_message=session_message, - metadata=metadata, - read_stream_writer=read_stream_writer, - sse_read_timeout=self.sse_read_timeout, - ) - - async def handle_request_async(): - if is_resumption: - await self._handle_resumption_request(ctx) - else: - await self._handle_post_request(ctx) - - # If this is a request, start a new task to handle it - if isinstance(message.root, JSONRPCRequest): - tg.start_soon(handle_request_async) - else: - await handle_request_async() - - except Exception: - logger.exception("Error in post_writer") - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - - async def terminate_session(self, client: httpx.AsyncClient) -> None: - """Terminate the session by sending a DELETE request.""" - if not self.session_id: - return - - try: - headers = self._prepare_request_headers(self.request_headers) - response = await client.delete(self.url, headers=headers) - - if response.status_code == 405: - logger.debug("Server does not allow session termination") - elif response.status_code not in (200, 204): - logger.warning(f"Session termination failed: {response.status_code}") - except Exception as exc: - logger.warning(f"Session termination failed: {exc}") - - def get_session_id(self) -> str | None: - """Get the current session ID.""" - return self.session_id - - -@asynccontextmanager -async def streamablehttp_client( - url: str, - headers: dict[str, str] | None = None, - timeout: float | timedelta = 30, - sse_read_timeout: float | timedelta = 60 * 5, - terminate_on_close: bool = True, - httpx_client_factory: McpHttpClientFactory = create_mcp_http_client, - auth: httpx.Auth | None = None, -) -> AsyncGenerator[ - tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], - GetSessionIdCallback, - ], - None, -]: - """ - Client transport for StreamableHTTP. - - `sse_read_timeout` determines how long (in seconds) the client will wait for a new - event before disconnecting. 
All other HTTP operations are controlled by `timeout`. - - Yields: - Tuple containing: - - read_stream: Stream for reading messages from the server - - write_stream: Stream for sending messages to the server - - get_session_id_callback: Function to retrieve the current session ID - """ - transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth) - - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - - async with httpx_client_factory( - headers=transport.request_headers, - timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout), - auth=transport.auth, - ) as client: - # Define callbacks that need access to tg - def start_get_stream() -> None: - tg.start_soon(transport.handle_get_stream, client, read_stream_writer) - - tg.start_soon( - transport.post_writer, - client, - write_stream_reader, - read_stream_writer, - write_stream, - start_get_stream, - tg, - ) - - try: - yield ( - read_stream, - write_stream, - transport.get_session_id, - ) - finally: - if transport.session_id and terminate_on_close: - await transport.terminate_session(client) - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() diff --git a/src/mcp/mcp/client/websocket.py b/src/mcp/mcp/client/websocket.py deleted file mode 100644 index 0a371610b..000000000 --- a/src/mcp/mcp/client/websocket.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import logging -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import ValidationError -from websockets.asyncio.client import connect as ws_connect -from 
websockets.typing import Subprotocol - -import mcp.types as types -from mcp.shared.message import SessionMessage - -logger = logging.getLogger(__name__) - - -@asynccontextmanager -async def websocket_client( - url: str, -) -> AsyncGenerator[ - tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]], - None, -]: - """ - WebSocket client transport for MCP, symmetrical to the server version. - - Connects to 'url' using the 'mcp' subprotocol, then yields: - (read_stream, write_stream) - - - read_stream: As you read from this stream, you'll receive either valid - JSONRPCMessage objects or Exception objects (when validation fails). - - write_stream: Write JSONRPCMessage objects to this stream to send them - over the WebSocket to the server. - """ - - # Create two in-memory streams: - # - One for incoming messages (read_stream, written by ws_reader) - # - One for outgoing messages (write_stream, read by ws_writer) - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - # Connect using websockets, requesting the "mcp" subprotocol - async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws: - - async def ws_reader(): - """ - Reads text messages from the WebSocket, parses them as JSON-RPC messages, - and sends them into read_stream_writer. 
- """ - async with read_stream_writer: - async for raw_text in ws: - try: - message = types.JSONRPCMessage.model_validate_json(raw_text) - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except ValidationError as exc: - # If JSON parse or model validation fails, send the exception - await read_stream_writer.send(exc) - - async def ws_writer(): - """ - Reads JSON-RPC messages from write_stream_reader and - sends them to the server. - """ - async with write_stream_reader: - async for session_message in write_stream_reader: - # Convert to a dict, then to JSON - msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True) - await ws.send(json.dumps(msg_dict)) - - async with anyio.create_task_group() as tg: - # Start reader and writer tasks - tg.start_soon(ws_reader) - tg.start_soon(ws_writer) - - # Yield the receive/send streams - yield (read_stream, write_stream) - - # Once the caller's 'async with' block exits, we shut down - tg.cancel_scope.cancel() diff --git a/src/mcp/mcp/os/__init__.py b/src/mcp/mcp/os/__init__.py deleted file mode 100644 index fa5dbc809..000000000 --- a/src/mcp/mcp/os/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Platform-specific utilities for MCP.""" diff --git a/src/mcp/mcp/os/posix/__init__.py b/src/mcp/mcp/os/posix/__init__.py deleted file mode 100644 index 23aff8bb0..000000000 --- a/src/mcp/mcp/os/posix/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""POSIX-specific utilities for MCP.""" diff --git a/src/mcp/mcp/os/posix/utilities.py b/src/mcp/mcp/os/posix/utilities.py deleted file mode 100644 index dd1aea363..000000000 --- a/src/mcp/mcp/os/posix/utilities.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -POSIX-specific functionality for stdio client operations. 
-""" - -import logging -import os -import signal - -import anyio -from anyio.abc import Process - -logger = logging.getLogger(__name__) - - -async def terminate_posix_process_tree(process: Process, timeout_seconds: float = 2.0) -> None: - """ - Terminate a process and all its children on POSIX systems. - - Uses os.killpg() for atomic process group termination. - - Args: - process: The process to terminate - timeout_seconds: Timeout in seconds before force killing (default: 2.0) - """ - pid = getattr(process, "pid", None) or getattr(getattr(process, "popen", None), "pid", None) - if not pid: - # No PID means there's no process to terminate - it either never started, - # already exited, or we have an invalid process object - return - - try: - pgid = os.getpgid(pid) - os.killpg(pgid, signal.SIGTERM) - - with anyio.move_on_after(timeout_seconds): - while True: - try: - # Check if process group still exists (signal 0 = check only) - os.killpg(pgid, 0) - await anyio.sleep(0.1) - except ProcessLookupError: - return - - try: - os.killpg(pgid, signal.SIGKILL) - except ProcessLookupError: - pass - - except (ProcessLookupError, PermissionError, OSError) as e: - logger.warning(f"Process group termination failed for PID {pid}: {e}, falling back to simple terminate") - try: - process.terminate() - with anyio.fail_after(timeout_seconds): - await process.wait() - except Exception: - logger.warning(f"Process termination failed for PID {pid}, attempting force kill") - try: - process.kill() - except Exception: - logger.exception(f"Failed to kill process {pid}") diff --git a/src/mcp/mcp/os/win32/__init__.py b/src/mcp/mcp/os/win32/__init__.py deleted file mode 100644 index f1ebab98d..000000000 --- a/src/mcp/mcp/os/win32/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Windows-specific utilities for MCP.""" diff --git a/src/mcp/mcp/os/win32/utilities.py b/src/mcp/mcp/os/win32/utilities.py deleted file mode 100644 index 962be0229..000000000 --- a/src/mcp/mcp/os/win32/utilities.py +++ 
/dev/null @@ -1,338 +0,0 @@ -""" -Windows-specific functionality for stdio client operations. -""" - -import logging -import shutil -import subprocess -import sys -from pathlib import Path -from typing import BinaryIO, TextIO, cast - -import anyio -from anyio import to_thread -from anyio.abc import Process -from anyio.streams.file import FileReadStream, FileWriteStream -from typing_extensions import deprecated - -logger = logging.getLogger("client.stdio.win32") - -# Windows-specific imports for Job Objects -if sys.platform == "win32": - import pywintypes - import win32api - import win32con - import win32job -else: - # Type stubs for non-Windows platforms - win32api = None - win32con = None - win32job = None - pywintypes = None - -JobHandle = int - - -def get_windows_executable_command(command: str) -> str: - """ - Get the correct executable command normalized for Windows. - - On Windows, commands might exist with specific extensions (.exe, .cmd, etc.) - that need to be located for proper execution. - - Args: - command: Base command (e.g., 'uvx', 'npx') - - Returns: - str: Windows-appropriate command path - """ - try: - # First check if command exists in PATH as-is - if command_path := shutil.which(command): - return command_path - - # Check for Windows-specific extensions - for ext in [".cmd", ".bat", ".exe", ".ps1"]: - ext_version = f"{command}{ext}" - if ext_path := shutil.which(ext_version): - return ext_path - - # For regular commands or if we couldn't find special versions - return command - except OSError: - # Handle file system errors during path resolution - # (permissions, broken symlinks, etc.) - return command - - -class FallbackProcess: - """ - A fallback process wrapper for Windows to handle async I/O - when using subprocess.Popen, which provides sync-only FileIO objects. - - This wraps stdin and stdout into async-compatible - streams (FileReadStream, FileWriteStream), - so that MCP clients expecting async streams can work properly. 
- """ - - def __init__(self, popen_obj: subprocess.Popen[bytes]): - self.popen: subprocess.Popen[bytes] = popen_obj - self.stdin_raw = popen_obj.stdin # type: ignore[assignment] - self.stdout_raw = popen_obj.stdout # type: ignore[assignment] - self.stderr = popen_obj.stderr # type: ignore[assignment] - - self.stdin = FileWriteStream(cast(BinaryIO, self.stdin_raw)) if self.stdin_raw else None - self.stdout = FileReadStream(cast(BinaryIO, self.stdout_raw)) if self.stdout_raw else None - - async def __aenter__(self): - """Support async context manager entry.""" - return self - - async def __aexit__( - self, - exc_type: BaseException | None, - exc_val: BaseException | None, - exc_tb: object | None, - ) -> None: - """Terminate and wait on process exit inside a thread.""" - self.popen.terminate() - await to_thread.run_sync(self.popen.wait) - - # Close the file handles to prevent ResourceWarning - if self.stdin: - await self.stdin.aclose() - if self.stdout: - await self.stdout.aclose() - if self.stdin_raw: - self.stdin_raw.close() - if self.stdout_raw: - self.stdout_raw.close() - if self.stderr: - self.stderr.close() - - async def wait(self): - """Async wait for process completion.""" - return await to_thread.run_sync(self.popen.wait) - - def terminate(self): - """Terminate the subprocess immediately.""" - return self.popen.terminate() - - def kill(self) -> None: - """Kill the subprocess immediately (alias for terminate).""" - self.terminate() - - @property - def pid(self) -> int: - """Return the process ID.""" - return self.popen.pid - - -# ------------------------ -# Updated function -# ------------------------ - - -async def create_windows_process( - command: str, - args: list[str], - env: dict[str, str] | None = None, - errlog: TextIO | None = sys.stderr, - cwd: Path | str | None = None, -) -> Process | FallbackProcess: - """ - Creates a subprocess in a Windows-compatible way with Job Object support. 
- - Attempt to use anyio's open_process for async subprocess creation. - In some cases this will throw NotImplementedError on Windows, e.g. - when using the SelectorEventLoop which does not support async subprocesses. - In that case, we fall back to using subprocess.Popen. - - The process is automatically added to a Job Object to ensure all child - processes are terminated when the parent is terminated. - - Args: - command (str): The executable to run - args (list[str]): List of command line arguments - env (dict[str, str] | None): Environment variables - errlog (TextIO | None): Where to send stderr output (defaults to sys.stderr) - cwd (Path | str | None): Working directory for the subprocess - - Returns: - Process | FallbackProcess: Async-compatible subprocess with stdin and stdout streams - """ - job = _create_job_object() - process = None - - try: - # First try using anyio with Windows-specific flags to hide console window - process = await anyio.open_process( - [command, *args], - env=env, - # Ensure we don't create console windows for each process - creationflags=subprocess.CREATE_NO_WINDOW # type: ignore - if hasattr(subprocess, "CREATE_NO_WINDOW") - else 0, - stderr=errlog, - cwd=cwd, - ) - except NotImplementedError: - # If Windows doesn't support async subprocess creation, use fallback - process = await _create_windows_fallback_process(command, args, env, errlog, cwd) - except Exception: - # Try again without creation flags - process = await anyio.open_process( - [command, *args], - env=env, - stderr=errlog, - cwd=cwd, - ) - - _maybe_assign_process_to_job(process, job) - return process - - -async def _create_windows_fallback_process( - command: str, - args: list[str], - env: dict[str, str] | None = None, - errlog: TextIO | None = sys.stderr, - cwd: Path | str | None = None, -) -> FallbackProcess: - """ - Create a subprocess using subprocess.Popen as a fallback when anyio fails. 
- - This function wraps the sync subprocess.Popen in an async-compatible interface. - """ - try: - # Try launching with creationflags to avoid opening a new console window - popen_obj = subprocess.Popen( - [command, *args], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=errlog, - env=env, - cwd=cwd, - bufsize=0, # Unbuffered output - creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), - ) - except Exception: - # If creationflags failed, fallback without them - popen_obj = subprocess.Popen( - [command, *args], - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=errlog, - env=env, - cwd=cwd, - bufsize=0, - ) - return FallbackProcess(popen_obj) - - -def _create_job_object() -> int | None: - """ - Create a Windows Job Object configured to terminate all processes when closed. - """ - if sys.platform != "win32" or not win32job: - return None - - try: - job = win32job.CreateJobObject(None, "") - extended_info = win32job.QueryInformationJobObject(job, win32job.JobObjectExtendedLimitInformation) - - extended_info["BasicLimitInformation"]["LimitFlags"] |= win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE - win32job.SetInformationJobObject(job, win32job.JobObjectExtendedLimitInformation, extended_info) - return job - except Exception as e: - logger.warning(f"Failed to create Job Object for process tree management: {e}") - return None - - -def _maybe_assign_process_to_job(process: Process | FallbackProcess, job: JobHandle | None) -> None: - """ - Try to assign a process to a job object. If assignment fails - for any reason, the job handle is closed. 
- """ - if not job: - return - - if sys.platform != "win32" or not win32api or not win32con or not win32job: - return - - try: - process_handle = win32api.OpenProcess( - win32con.PROCESS_SET_QUOTA | win32con.PROCESS_TERMINATE, False, process.pid - ) - if not process_handle: - raise Exception("Failed to open process handle") - - try: - win32job.AssignProcessToJobObject(job, process_handle) - process._job_object = job - finally: - win32api.CloseHandle(process_handle) - except Exception as e: - logger.warning(f"Failed to assign process {process.pid} to Job Object: {e}") - if win32api: - win32api.CloseHandle(job) - - -async def terminate_windows_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None: - """ - Terminate a process and all its children on Windows. - - If the process has an associated job object, it will be terminated. - Otherwise, falls back to basic process termination. - - Args: - process: The process to terminate - timeout_seconds: Timeout in seconds before force killing (default: 2.0) - """ - if sys.platform != "win32": - return - - job = getattr(process, "_job_object", None) - if job and win32job: - try: - win32job.TerminateJobObject(job, 1) - except Exception: - # Job might already be terminated - pass - finally: - if win32api: - try: - win32api.CloseHandle(job) - except Exception: - pass - - # Always try to terminate the process itself as well - try: - process.terminate() - except Exception: - pass - - -@deprecated( - "terminate_windows_process is deprecated and will be removed in a future version. " - "Process termination is now handled internally by the stdio_client context manager." -) -async def terminate_windows_process(process: Process | FallbackProcess): - """ - Terminate a Windows process. - - Note: On Windows, terminating a process with process.terminate() doesn't - always guarantee immediate process termination. - So we give it 2s to exit, or we call process.kill() - which sends a SIGKILL equivalent signal. 
- - Args: - process: The process to terminate - """ - try: - process.terminate() - with anyio.fail_after(2.0): - await process.wait() - except TimeoutError: - # Force kill if it doesn't terminate - process.kill() diff --git a/src/mcp/mcp/py.typed b/src/mcp/mcp/py.typed deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/mcp/mcp/server/__init__.py b/src/mcp/mcp/server/__init__.py deleted file mode 100644 index 0feed368e..000000000 --- a/src/mcp/mcp/server/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .fastmcp import FastMCP -from .lowlevel import NotificationOptions, Server -from .models import InitializationOptions - -__all__ = ["Server", "FastMCP", "NotificationOptions", "InitializationOptions"] diff --git a/src/mcp/mcp/server/__main__.py b/src/mcp/mcp/server/__main__.py deleted file mode 100644 index 1970eca7d..000000000 --- a/src/mcp/mcp/server/__main__.py +++ /dev/null @@ -1,50 +0,0 @@ -import importlib.metadata -import logging -import sys - -import anyio - -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.server.stdio import stdio_server -from mcp.types import ServerCapabilities - -if not sys.warnoptions: - import warnings - - warnings.simplefilter("ignore") - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger("server") - - -async def receive_loop(session: ServerSession): - logger.info("Starting receive loop") - async for message in session.incoming_messages: - if isinstance(message, Exception): - logger.error("Error: %s", message) - continue - - logger.info("Received message from client: %s", message) - - -async def main(): - version = importlib.metadata.version("mcp") - async with stdio_server() as (read_stream, write_stream): - async with ( - ServerSession( - read_stream, - write_stream, - InitializationOptions( - server_name="mcp", - server_version=version, - capabilities=ServerCapabilities(), - ), - ) as session, - write_stream, - ): - await 
receive_loop(session) - - -if __name__ == "__main__": - anyio.run(main, backend="trio") diff --git a/src/mcp/mcp/server/auth/__init__.py b/src/mcp/mcp/server/auth/__init__.py deleted file mode 100644 index 6888ffe8d..000000000 --- a/src/mcp/mcp/server/auth/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -MCP OAuth server authorization components. -""" diff --git a/src/mcp/mcp/server/auth/errors.py b/src/mcp/mcp/server/auth/errors.py deleted file mode 100644 index 117deea83..000000000 --- a/src/mcp/mcp/server/auth/errors.py +++ /dev/null @@ -1,5 +0,0 @@ -from pydantic import ValidationError - - -def stringify_pydantic_error(validation_error: ValidationError) -> str: - return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors()) diff --git a/src/mcp/mcp/server/auth/handlers/__init__.py b/src/mcp/mcp/server/auth/handlers/__init__.py deleted file mode 100644 index e99a62de1..000000000 --- a/src/mcp/mcp/server/auth/handlers/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Request handlers for MCP authorization endpoints. 
-""" diff --git a/src/mcp/mcp/server/auth/handlers/authorize.py b/src/mcp/mcp/server/auth/handlers/authorize.py deleted file mode 100644 index f484cf886..000000000 --- a/src/mcp/mcp/server/auth/handlers/authorize.py +++ /dev/null @@ -1,224 +0,0 @@ -import logging -from dataclasses import dataclass -from typing import Any, Literal - -from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError -from starlette.datastructures import FormData, QueryParams -from starlette.requests import Request -from starlette.responses import RedirectResponse, Response - -from mcp.server.auth.errors import stringify_pydantic_error -from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.provider import ( - AuthorizationErrorCode, - AuthorizationParams, - AuthorizeError, - OAuthAuthorizationServerProvider, - construct_redirect_uri, -) -from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError - -logger = logging.getLogger(__name__) - - -class AuthorizationRequest(BaseModel): - # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 - client_id: str = Field(..., description="The client ID") - redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization") - - # see OAuthClientMetadata; we only support `code` - response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow") - code_challenge: str = Field(..., description="PKCE code challenge") - code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256") - state: str | None = Field(None, description="Optional state parameter") - scope: str | None = Field( - None, - description="Optional scope; if specified, should be a space-separated list of scope strings", - ) - resource: str | None = Field( - None, - description="RFC 8707 resource indicator - the MCP server this token will be used with", - ) - - -class AuthorizationErrorResponse(BaseModel): - error: 
AuthorizationErrorCode - error_description: str | None - error_uri: AnyUrl | None = None - # must be set if provided in the request - state: str | None = None - - -def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None: - if params is None: - return None - value = params.get(key) - if isinstance(value, str): - return value - return None - - -class AnyUrlModel(RootModel[AnyUrl]): - root: AnyUrl - - -@dataclass -class AuthorizationHandler: - provider: OAuthAuthorizationServerProvider[Any, Any, Any] - - async def handle(self, request: Request) -> Response: - # implements authorization requests for grant_type=code; - # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 - - state = None - redirect_uri = None - client = None - params = None - - async def error_response( - error: AuthorizationErrorCode, - error_description: str | None, - attempt_load_client: bool = True, - ): - # Error responses take two different formats: - # 1. The request has a valid client ID & redirect_uri: we issue a redirect - # back to the redirect_uri with the error response fields as query - # parameters. This allows the client to be notified of the error. - # 2. Otherwise, we return an error response directly to the end user; - # we choose to do so in JSON, but this is left undefined in the - # specification. - # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1 - # - # This logic is a bit awkward to handle, because the error might be thrown - # very early in request validation, before we've done the usual Pydantic - # validation, loaded the client, etc. To handle this, error_response() - # contains fallback logic which attempts to load the parameters directly - # from the request. 
- - nonlocal client, redirect_uri, state - if client is None and attempt_load_client: - # make last-ditch attempt to load the client - client_id = best_effort_extract_string("client_id", params) - client = client_id and await self.provider.get_client(client_id) - if redirect_uri is None and client: - # make last-ditch effort to load the redirect uri - try: - if params is not None and "redirect_uri" not in params: - raw_redirect_uri = None - else: - raw_redirect_uri = AnyUrlModel.model_validate( - best_effort_extract_string("redirect_uri", params) - ).root - redirect_uri = client.validate_redirect_uri(raw_redirect_uri) - except (ValidationError, InvalidRedirectUriError): - # if the redirect URI is invalid, ignore it & just return the - # initial error - pass - - # the error response MUST contain the state specified by the client, if any - if state is None: - # make last-ditch effort to load state - state = best_effort_extract_string("state", params) - - error_resp = AuthorizationErrorResponse( - error=error, - error_description=error_description, - state=state, - ) - - if redirect_uri and client: - return RedirectResponse( - url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)), - status_code=302, - headers={"Cache-Control": "no-store"}, - ) - else: - return PydanticJSONResponse( - status_code=400, - content=error_resp, - headers={"Cache-Control": "no-store"}, - ) - - try: - # Parse request parameters - if request.method == "GET": - # Convert query_params to dict for pydantic validation - params = request.query_params - else: - # Parse form data for POST requests - params = await request.form() - - # Save state if it exists, even before validation - state = best_effort_extract_string("state", params) - - try: - auth_request = AuthorizationRequest.model_validate(params) - state = auth_request.state # Update with validated state - except ValidationError as validation_error: - error: AuthorizationErrorCode = "invalid_request" - for e 
in validation_error.errors(): - if e["loc"] == ("response_type",) and e["type"] == "literal_error": - error = "unsupported_response_type" - break - return await error_response(error, stringify_pydantic_error(validation_error)) - - # Get client information - client = await self.provider.get_client( - auth_request.client_id, - ) - if not client: - # For client_id validation errors, return direct error (no redirect) - return await error_response( - error="invalid_request", - error_description=f"Client ID '{auth_request.client_id}' not found", - attempt_load_client=False, - ) - - # Validate redirect_uri against client's registered URIs - try: - redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri) - except InvalidRedirectUriError as validation_error: - # For redirect_uri validation errors, return direct error (no redirect) - return await error_response( - error="invalid_request", - error_description=validation_error.message, - ) - - # Validate scope - for scope errors, we can redirect - try: - scopes = client.validate_scope(auth_request.scope) - except InvalidScopeError as validation_error: - # For scope errors, redirect with error parameters - return await error_response( - error="invalid_scope", - error_description=validation_error.message, - ) - - # Setup authorization parameters - auth_params = AuthorizationParams( - state=state, - scopes=scopes, - code_challenge=auth_request.code_challenge, - redirect_uri=redirect_uri, - redirect_uri_provided_explicitly=auth_request.redirect_uri is not None, - resource=auth_request.resource, # RFC 8707 - ) - - try: - # Let the provider pick the next URI to redirect to - return RedirectResponse( - url=await self.provider.authorize( - client, - auth_params, - ), - status_code=302, - headers={"Cache-Control": "no-store"}, - ) - except AuthorizeError as e: - # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 - return await error_response(error=e.error, error_description=e.error_description) - - 
except Exception as validation_error: - # Catch-all for unexpected errors - logger.exception("Unexpected error in authorization_handler", exc_info=validation_error) - return await error_response(error="server_error", error_description="An unexpected error occurred") diff --git a/src/mcp/mcp/server/auth/handlers/metadata.py b/src/mcp/mcp/server/auth/handlers/metadata.py deleted file mode 100644 index f12644215..000000000 --- a/src/mcp/mcp/server/auth/handlers/metadata.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass - -from starlette.requests import Request -from starlette.responses import Response - -from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata - - -@dataclass -class MetadataHandler: - metadata: OAuthMetadata - - async def handle(self, request: Request) -> Response: - return PydanticJSONResponse( - content=self.metadata, - headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour - ) - - -@dataclass -class ProtectedResourceMetadataHandler: - metadata: ProtectedResourceMetadata - - async def handle(self, request: Request) -> Response: - return PydanticJSONResponse( - content=self.metadata, - headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour - ) diff --git a/src/mcp/mcp/server/auth/handlers/register.py b/src/mcp/mcp/server/auth/handlers/register.py deleted file mode 100644 index e6d99e66d..000000000 --- a/src/mcp/mcp/server/auth/handlers/register.py +++ /dev/null @@ -1,120 +0,0 @@ -import secrets -import time -from dataclasses import dataclass -from typing import Any -from uuid import uuid4 - -from pydantic import BaseModel, RootModel, ValidationError -from starlette.requests import Request -from starlette.responses import Response - -from mcp.server.auth.errors import stringify_pydantic_error -from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, 
RegistrationError, RegistrationErrorCode -from mcp.server.auth.settings import ClientRegistrationOptions -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata - - -class RegistrationRequest(RootModel[OAuthClientMetadata]): - # this wrapper is a no-op; it's just to separate out the types exposed to the - # provider from what we use in the HTTP handler - root: OAuthClientMetadata - - -class RegistrationErrorResponse(BaseModel): - error: RegistrationErrorCode - error_description: str | None - - -@dataclass -class RegistrationHandler: - provider: OAuthAuthorizationServerProvider[Any, Any, Any] - options: ClientRegistrationOptions - - async def handle(self, request: Request) -> Response: - # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 - try: - # Parse request body as JSON - body = await request.json() - client_metadata = OAuthClientMetadata.model_validate(body) - - # Scope validation is handled below - except ValidationError as validation_error: - return PydanticJSONResponse( - content=RegistrationErrorResponse( - error="invalid_client_metadata", - error_description=stringify_pydantic_error(validation_error), - ), - status_code=400, - ) - - client_id = str(uuid4()) - client_secret = None - if client_metadata.token_endpoint_auth_method != "none": - # cryptographically secure random 32-byte hex string - client_secret = secrets.token_hex(32) - - if client_metadata.scope is None and self.options.default_scopes is not None: - client_metadata.scope = " ".join(self.options.default_scopes) - elif client_metadata.scope is not None and self.options.valid_scopes is not None: - requested_scopes = set(client_metadata.scope.split()) - valid_scopes = set(self.options.valid_scopes) - if not requested_scopes.issubset(valid_scopes): - return PydanticJSONResponse( - content=RegistrationErrorResponse( - error="invalid_client_metadata", - error_description="Requested scopes are not valid: " - f"{', 
'.join(requested_scopes - valid_scopes)}", - ), - status_code=400, - ) - if set(client_metadata.grant_types) != {"authorization_code", "refresh_token"}: - return PydanticJSONResponse( - content=RegistrationErrorResponse( - error="invalid_client_metadata", - error_description="grant_types must be authorization_code and refresh_token", - ), - status_code=400, - ) - - client_id_issued_at = int(time.time()) - client_secret_expires_at = ( - client_id_issued_at + self.options.client_secret_expiry_seconds - if self.options.client_secret_expiry_seconds is not None - else None - ) - - client_info = OAuthClientInformationFull( - client_id=client_id, - client_id_issued_at=client_id_issued_at, - client_secret=client_secret, - client_secret_expires_at=client_secret_expires_at, - # passthrough information from the client request - redirect_uris=client_metadata.redirect_uris, - token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, - grant_types=client_metadata.grant_types, - response_types=client_metadata.response_types, - client_name=client_metadata.client_name, - client_uri=client_metadata.client_uri, - logo_uri=client_metadata.logo_uri, - scope=client_metadata.scope, - contacts=client_metadata.contacts, - tos_uri=client_metadata.tos_uri, - policy_uri=client_metadata.policy_uri, - jwks_uri=client_metadata.jwks_uri, - jwks=client_metadata.jwks, - software_id=client_metadata.software_id, - software_version=client_metadata.software_version, - ) - try: - # Register client - await self.provider.register_client(client_info) - - # Return client information - return PydanticJSONResponse(content=client_info, status_code=201) - except RegistrationError as e: - # Handle registration errors as defined in RFC 7591 Section 3.2.2 - return PydanticJSONResponse( - content=RegistrationErrorResponse(error=e.error, error_description=e.error_description), - status_code=400, - ) diff --git a/src/mcp/mcp/server/auth/handlers/revoke.py b/src/mcp/mcp/server/auth/handlers/revoke.py 
deleted file mode 100644 index 478ad7a01..000000000 --- a/src/mcp/mcp/server/auth/handlers/revoke.py +++ /dev/null @@ -1,94 +0,0 @@ -from dataclasses import dataclass -from functools import partial -from typing import Any, Literal - -from pydantic import BaseModel, ValidationError -from starlette.requests import Request -from starlette.responses import Response - -from mcp.server.auth.errors import ( - stringify_pydantic_error, -) -from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator -from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken - - -class RevocationRequest(BaseModel): - """ - # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 - """ - - token: str - token_type_hint: Literal["access_token", "refresh_token"] | None = None - client_id: str - client_secret: str | None - - -class RevocationErrorResponse(BaseModel): - error: Literal["invalid_request", "unauthorized_client"] - error_description: str | None = None - - -@dataclass -class RevocationHandler: - provider: OAuthAuthorizationServerProvider[Any, Any, Any] - client_authenticator: ClientAuthenticator - - async def handle(self, request: Request) -> Response: - """ - Handler for the OAuth 2.0 Token Revocation endpoint. 
- """ - try: - form_data = await request.form() - revocation_request = RevocationRequest.model_validate(dict(form_data)) - except ValidationError as e: - return PydanticJSONResponse( - status_code=400, - content=RevocationErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(e), - ), - ) - - # Authenticate client - try: - client = await self.client_authenticator.authenticate( - revocation_request.client_id, revocation_request.client_secret - ) - except AuthenticationError as e: - return PydanticJSONResponse( - status_code=401, - content=RevocationErrorResponse( - error="unauthorized_client", - error_description=e.message, - ), - ) - - loaders = [ - self.provider.load_access_token, - partial(self.provider.load_refresh_token, client), - ] - if revocation_request.token_type_hint == "refresh_token": - loaders = reversed(loaders) - - token: None | AccessToken | RefreshToken = None - for loader in loaders: - token = await loader(revocation_request.token) - if token is not None: - break - - # if token is not found, just return HTTP 200 per the RFC - if token and token.client_id == client.client_id: - # Revoke token; provider is not meant to be able to do validation - # at this point that would result in an error - await self.provider.revoke_token(token) - - # Return successful empty response - return Response( - status_code=200, - headers={ - "Cache-Control": "no-store", - "Pragma": "no-cache", - }, - ) diff --git a/src/mcp/mcp/server/auth/handlers/token.py b/src/mcp/mcp/server/auth/handlers/token.py deleted file mode 100644 index 4e15e6265..000000000 --- a/src/mcp/mcp/server/auth/handlers/token.py +++ /dev/null @@ -1,238 +0,0 @@ -import base64 -import hashlib -import time -from dataclasses import dataclass -from typing import Annotated, Any, Literal - -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError -from starlette.requests import Request - -from mcp.server.auth.errors import stringify_pydantic_error 
-from mcp.server.auth.json_response import PydanticJSONResponse -from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode -from mcp.shared.auth import OAuthToken - - -class AuthorizationCodeRequest(BaseModel): - # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 - grant_type: Literal["authorization_code"] - code: str = Field(..., description="The authorization code") - redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize") - client_id: str - # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 - client_secret: str | None = None - # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 - code_verifier: str = Field(..., description="PKCE code verifier") - # RFC 8707 resource indicator - resource: str | None = Field(None, description="Resource indicator for the token") - - -class RefreshTokenRequest(BaseModel): - # See https://datatracker.ietf.org/doc/html/rfc6749#section-6 - grant_type: Literal["refresh_token"] - refresh_token: str = Field(..., description="The refresh token") - scope: str | None = Field(None, description="Optional scope parameter") - client_id: str - # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 - client_secret: str | None = None - # RFC 8707 resource indicator - resource: str | None = Field(None, description="Resource indicator for the token") - - -class TokenRequest( - RootModel[ - Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, - Field(discriminator="grant_type"), - ] - ] -): - root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, - Field(discriminator="grant_type"), - ] - - -class TokenErrorResponse(BaseModel): - """ - See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 - """ - - error: TokenErrorCode - 
error_description: str | None = None - error_uri: AnyHttpUrl | None = None - - -class TokenSuccessResponse(RootModel[OAuthToken]): - # this is just a wrapper over OAuthToken; the only reason we do this - # is to have some separation between the HTTP response type, and the - # type returned by the provider - root: OAuthToken - - -@dataclass -class TokenHandler: - provider: OAuthAuthorizationServerProvider[Any, Any, Any] - client_authenticator: ClientAuthenticator - - def response(self, obj: TokenSuccessResponse | TokenErrorResponse): - status_code = 200 - if isinstance(obj, TokenErrorResponse): - status_code = 400 - - return PydanticJSONResponse( - content=obj, - status_code=status_code, - headers={ - "Cache-Control": "no-store", - "Pragma": "no-cache", - }, - ) - - async def handle(self, request: Request): - try: - form_data = await request.form() - token_request = TokenRequest.model_validate(dict(form_data)).root - except ValidationError as validation_error: - return self.response( - TokenErrorResponse( - error="invalid_request", - error_description=stringify_pydantic_error(validation_error), - ) - ) - - try: - client_info = await self.client_authenticator.authenticate( - client_id=token_request.client_id, - client_secret=token_request.client_secret, - ) - except AuthenticationError as e: - return self.response( - TokenErrorResponse( - error="unauthorized_client", - error_description=e.message, - ) - ) - - if token_request.grant_type not in client_info.grant_types: - return self.response( - TokenErrorResponse( - error="unsupported_grant_type", - error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"), - ) - ) - - tokens: OAuthToken - - match token_request: - case AuthorizationCodeRequest(): - auth_code = await self.provider.load_authorization_code(client_info, token_request.code) - if auth_code is None or auth_code.client_id != token_request.client_id: - # if code belongs to different client, pretend it doesn't exist - 
return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="authorization code does not exist", - ) - ) - - # make auth codes expire after a deadline - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 - if auth_code.expires_at < time.time(): - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="authorization code has expired", - ) - ) - - # verify redirect_uri doesn't change between /authorize and /tokens - # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 - if auth_code.redirect_uri_provided_explicitly: - authorize_request_redirect_uri = auth_code.redirect_uri - else: - authorize_request_redirect_uri = None - - # Convert both sides to strings for comparison to handle AnyUrl vs string issues - token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None - auth_redirect_str = ( - str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None - ) - - if token_redirect_str != auth_redirect_str: - return self.response( - TokenErrorResponse( - error="invalid_request", - error_description=("redirect_uri did not match the one used when creating auth code"), - ) - ) - - # Verify PKCE code verifier - sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() - hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=") - - if hashed_code_verifier != auth_code.code_challenge: - # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="incorrect code_verifier", - ) - ) - - try: - # Exchange authorization code for tokens - tokens = await self.provider.exchange_authorization_code(client_info, auth_code) - except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) - - case RefreshTokenRequest(): - 
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) - if refresh_token is None or refresh_token.client_id != token_request.client_id: - # if token belongs to different client, pretend it doesn't exist - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="refresh token does not exist", - ) - ) - - if refresh_token.expires_at and refresh_token.expires_at < time.time(): - # if the refresh token has expired, pretend it doesn't exist - return self.response( - TokenErrorResponse( - error="invalid_grant", - error_description="refresh token has expired", - ) - ) - - # Parse scopes if provided - scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes - - for scope in scopes: - if scope not in refresh_token.scopes: - return self.response( - TokenErrorResponse( - error="invalid_scope", - error_description=(f"cannot request scope `{scope}` not provided by refresh token"), - ) - ) - - try: - # Exchange refresh token for new tokens - tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes) - except TokenError as e: - return self.response( - TokenErrorResponse( - error=e.error, - error_description=e.error_description, - ) - ) - - return self.response(TokenSuccessResponse(root=tokens)) diff --git a/src/mcp/mcp/server/auth/json_response.py b/src/mcp/mcp/server/auth/json_response.py deleted file mode 100644 index bd95bd693..000000000 --- a/src/mcp/mcp/server/auth/json_response.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any - -from starlette.responses import JSONResponse - - -class PydanticJSONResponse(JSONResponse): - # use pydantic json serialization instead of the stock `json.dumps`, - # so that we can handle serializing pydantic models like AnyHttpUrl - def render(self, content: Any) -> bytes: - return content.model_dump_json(exclude_none=True).encode("utf-8") diff --git a/src/mcp/mcp/server/auth/middleware/__init__.py 
b/src/mcp/mcp/server/auth/middleware/__init__.py deleted file mode 100644 index ba3ff63c3..000000000 --- a/src/mcp/mcp/server/auth/middleware/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Middleware for MCP authorization. -""" diff --git a/src/mcp/mcp/server/auth/middleware/auth_context.py b/src/mcp/mcp/server/auth/middleware/auth_context.py deleted file mode 100644 index e2116c3bf..000000000 --- a/src/mcp/mcp/server/auth/middleware/auth_context.py +++ /dev/null @@ -1,48 +0,0 @@ -import contextvars - -from starlette.types import ASGIApp, Receive, Scope, Send - -from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser -from mcp.server.auth.provider import AccessToken - -# Create a contextvar to store the authenticated user -# The default is None, indicating no authenticated user is present -auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None) - - -def get_access_token() -> AccessToken | None: - """ - Get the access token from the current context. - - Returns: - The access token if an authenticated user is available, None otherwise. - """ - auth_user = auth_context_var.get() - return auth_user.access_token if auth_user else None - - -class AuthContextMiddleware: - """ - Middleware that extracts the authenticated user from the request - and sets it in a contextvar for easy access throughout the request lifecycle. - - This middleware should be added after the AuthenticationMiddleware in the - middleware stack to ensure that the user is properly authenticated before - being stored in the context. 
- """ - - def __init__(self, app: ASGIApp): - self.app = app - - async def __call__(self, scope: Scope, receive: Receive, send: Send): - user = scope.get("user") - if isinstance(user, AuthenticatedUser): - # Set the authenticated user in the contextvar - token = auth_context_var.set(user) - try: - await self.app(scope, receive, send) - finally: - auth_context_var.reset(token) - else: - # No authenticated user, just process the request - await self.app(scope, receive, send) diff --git a/src/mcp/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/mcp/server/auth/middleware/bearer_auth.py deleted file mode 100644 index 6251e5ad5..000000000 --- a/src/mcp/mcp/server/auth/middleware/bearer_auth.py +++ /dev/null @@ -1,128 +0,0 @@ -import json -import time -from typing import Any - -from pydantic import AnyHttpUrl -from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser -from starlette.requests import HTTPConnection -from starlette.types import Receive, Scope, Send - -from mcp.server.auth.provider import AccessToken, TokenVerifier - - -class AuthenticatedUser(SimpleUser): - """User with authentication info.""" - - def __init__(self, auth_info: AccessToken): - super().__init__(auth_info.client_id) - self.access_token = auth_info - self.scopes = auth_info.scopes - - -class BearerAuthBackend(AuthenticationBackend): - """ - Authentication backend that validates Bearer tokens using a TokenVerifier. 
- """ - - def __init__(self, token_verifier: TokenVerifier): - self.token_verifier = token_verifier - - async def authenticate(self, conn: HTTPConnection): - auth_header = next( - (conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"), - None, - ) - if not auth_header or not auth_header.lower().startswith("bearer "): - return None - - token = auth_header[7:] # Remove "Bearer " prefix - - # Validate the token with the verifier - auth_info = await self.token_verifier.verify_token(token) - - if not auth_info: - return None - - if auth_info.expires_at and auth_info.expires_at < int(time.time()): - return None - - return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) - - -class RequireAuthMiddleware: - """ - Middleware that requires a valid Bearer token in the Authorization header. - - This will validate the token with the auth provider and store the resulting - auth info in the request state. - """ - - def __init__( - self, - app: Any, - required_scopes: list[str], - resource_metadata_url: AnyHttpUrl | None = None, - ): - """ - Initialize the middleware. 
- - Args: - app: ASGI application - required_scopes: List of scopes that the token must have - resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header - """ - self.app = app - self.required_scopes = required_scopes - self.resource_metadata_url = resource_metadata_url - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - auth_user = scope.get("user") - if not isinstance(auth_user, AuthenticatedUser): - await self._send_auth_error( - send, status_code=401, error="invalid_token", description="Authentication required" - ) - return - - auth_credentials = scope.get("auth") - - for required_scope in self.required_scopes: - # auth_credentials should always be provided; this is just paranoia - if auth_credentials is None or required_scope not in auth_credentials.scopes: - await self._send_auth_error( - send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}" - ) - return - - await self.app(scope, receive, send) - - async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None: - """Send an authentication error response with WWW-Authenticate header.""" - # Build WWW-Authenticate header value - www_auth_parts = [f'error="{error}"', f'error_description="{description}"'] - if self.resource_metadata_url: - www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"') - - www_authenticate = f"Bearer {', '.join(www_auth_parts)}" - - # Send response - body = {"error": error, "error_description": description} - body_bytes = json.dumps(body).encode() - - await send( - { - "type": "http.response.start", - "status": status_code, - "headers": [ - (b"content-type", b"application/json"), - (b"content-length", str(len(body_bytes)).encode()), - (b"www-authenticate", www_authenticate.encode()), - ], - } - ) - - await send( - { - "type": "http.response.body", - "body": body_bytes, - } - ) diff --git 
a/src/mcp/mcp/server/auth/middleware/client_auth.py b/src/mcp/mcp/server/auth/middleware/client_auth.py deleted file mode 100644 index d5f473b48..000000000 --- a/src/mcp/mcp/server/auth/middleware/client_auth.py +++ /dev/null @@ -1,51 +0,0 @@ -import time -from typing import Any - -from mcp.server.auth.provider import OAuthAuthorizationServerProvider -from mcp.shared.auth import OAuthClientInformationFull - - -class AuthenticationError(Exception): - def __init__(self, message: str): - self.message = message - - -class ClientAuthenticator: - """ - ClientAuthenticator is a callable which validates requests from a client - application, used to verify /token calls. - If, during registration, the client requested to be issued a secret, the - authenticator asserts that /token calls must be authenticated with - that same token. - NOTE: clients can opt for no authentication during registration, in which case this - logic is skipped. - """ - - def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): - """ - Initialize the dependency. 
- - Args: - provider: Provider to look up client information - """ - self.provider = provider - - async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull: - # Look up client information - client = await self.provider.get_client(client_id) - if not client: - raise AuthenticationError("Invalid client_id") - - # If client from the store expects a secret, validate that the request provides - # that secret - if client.client_secret: - if not client_secret: - raise AuthenticationError("Client secret is required") - - if client.client_secret != client_secret: - raise AuthenticationError("Invalid client_secret") - - if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): - raise AuthenticationError("Client secret has expired") - - return client diff --git a/src/mcp/mcp/server/auth/provider.py b/src/mcp/mcp/server/auth/provider.py deleted file mode 100644 index a7b108602..000000000 --- a/src/mcp/mcp/server/auth/provider.py +++ /dev/null @@ -1,306 +0,0 @@ -from dataclasses import dataclass -from typing import Generic, Literal, Protocol, TypeVar -from urllib.parse import parse_qs, urlencode, urlparse, urlunparse - -from pydantic import AnyUrl, BaseModel - -from mcp.shared.auth import OAuthClientInformationFull, OAuthToken - - -class AuthorizationParams(BaseModel): - state: str | None - scopes: list[str] | None - code_challenge: str - redirect_uri: AnyUrl - redirect_uri_provided_explicitly: bool - resource: str | None = None # RFC 8707 resource indicator - - -class AuthorizationCode(BaseModel): - code: str - scopes: list[str] - expires_at: float - client_id: str - code_challenge: str - redirect_uri: AnyUrl - redirect_uri_provided_explicitly: bool - resource: str | None = None # RFC 8707 resource indicator - - -class RefreshToken(BaseModel): - token: str - client_id: str - scopes: list[str] - expires_at: int | None = None - - -class AccessToken(BaseModel): - token: str - client_id: str - scopes: 
list[str] - expires_at: int | None = None - resource: str | None = None # RFC 8707 resource indicator - - -RegistrationErrorCode = Literal[ - "invalid_redirect_uri", - "invalid_client_metadata", - "invalid_software_statement", - "unapproved_software_statement", -] - - -@dataclass(frozen=True) -class RegistrationError(Exception): - error: RegistrationErrorCode - error_description: str | None = None - - -AuthorizationErrorCode = Literal[ - "invalid_request", - "unauthorized_client", - "access_denied", - "unsupported_response_type", - "invalid_scope", - "server_error", - "temporarily_unavailable", -] - - -@dataclass(frozen=True) -class AuthorizeError(Exception): - error: AuthorizationErrorCode - error_description: str | None = None - - -TokenErrorCode = Literal[ - "invalid_request", - "invalid_client", - "invalid_grant", - "unauthorized_client", - "unsupported_grant_type", - "invalid_scope", -] - - -@dataclass(frozen=True) -class TokenError(Exception): - error: TokenErrorCode - error_description: str | None = None - - -class TokenVerifier(Protocol): - """Protocol for verifying bearer tokens.""" - - async def verify_token(self, token: str) -> AccessToken | None: - """Verify a bearer token and return access info if valid.""" - - -# NOTE: FastMCP doesn't render any of these types in the user response, so it's -# OK to add fields to subclasses which should not be exposed externally. -AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) -RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) -AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) - - -class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]): - async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: - """ - Retrieves client information by client ID. - - Implementors MAY raise NotImplementedError if dynamic client registration is - disabled in ClientRegistrationOptions. 
- - Args: - client_id: The ID of the client to retrieve. - - Returns: - The client information, or None if the client does not exist. - """ - ... - - async def register_client(self, client_info: OAuthClientInformationFull) -> None: - """ - Saves client information as part of registering it. - - Implementors MAY raise NotImplementedError if dynamic client registration is - disabled in ClientRegistrationOptions. - - Args: - client_info: The client metadata to register. - - Raises: - RegistrationError: If the client metadata is invalid. - """ - ... - - async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str: - """ - Called as part of the /authorize endpoint, and returns a URL that the client - will be redirected to. - Many MCP implementations will redirect to a third-party provider to perform - a second OAuth exchange with that provider. In this sort of setup, the client - has an OAuth connection with the MCP server, and the MCP server has an OAuth - connection with the 3rd-party provider. At the end of this flow, the client - should be redirected to the redirect_uri from params.redirect_uri. - - +--------+ +------------+ +-------------------+ - | | | | | | - | Client | --> | MCP Server | --> | 3rd Party OAuth | - | | | | | Server | - +--------+ +------------+ +-------------------+ - | ^ | - +------------+ | | | - | | | | Redirect | - |redirect_uri|<-----+ +------------------+ - | | - +------------+ - - Implementations will need to define another handler on the MCP server return - flow to perform the second redirect, and generate and store an authorization - code as part of completing the OAuth authorization step. - - Implementations SHOULD generate an authorization code with at least 160 bits of - entropy, - and MUST generate an authorization code with at least 128 bits of entropy. - See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. - - Args: - client: The client requesting authorization. 
- params: The parameters of the authorization request. - - Returns: - A URL to redirect the client to for authorization. - - Raises: - AuthorizeError: If the authorization request is invalid. - """ - ... - - async def load_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: str - ) -> AuthorizationCodeT | None: - """ - Loads an AuthorizationCode by its code. - - Args: - client: The client that requested the authorization code. - authorization_code: The authorization code to get the challenge for. - - Returns: - The AuthorizationCode, or None if not found - """ - ... - - async def exchange_authorization_code( - self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT - ) -> OAuthToken: - """ - Exchanges an authorization code for an access token and refresh token. - - Args: - client: The client exchanging the authorization code. - authorization_code: The authorization code to exchange. - - Returns: - The OAuth token, containing access and refresh tokens. - - Raises: - TokenError: If the request is invalid - """ - ... - - async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None: - """ - Loads a RefreshToken by its token string. - - Args: - client: The client that is requesting to load the refresh token. - refresh_token: The refresh token string to load. - - Returns: - The RefreshToken object if found, or None if not found. - """ - - ... - - async def exchange_refresh_token( - self, - client: OAuthClientInformationFull, - refresh_token: RefreshTokenT, - scopes: list[str], - ) -> OAuthToken: - """ - Exchanges a refresh token for an access token and refresh token. - - Implementations SHOULD rotate both the access token and refresh token. - - Args: - client: The client exchanging the refresh token. - refresh_token: The refresh token to exchange. - scopes: Optional scopes to request with the new access token. 
- - Returns: - The OAuth token, containing access and refresh tokens. - - Raises: - TokenError: If the request is invalid - """ - ... - - async def load_access_token(self, token: str) -> AccessTokenT | None: - """ - Loads an access token by its token. - - Args: - token: The access token to verify. - - Returns: - The AuthInfo, or None if the token is invalid. - """ - ... - - async def revoke_token( - self, - token: AccessTokenT | RefreshTokenT, - ) -> None: - """ - Revokes an access or refresh token. - - If the given token is invalid or already revoked, this method should do nothing. - - Implementations SHOULD revoke both the access token and its corresponding - refresh token, regardless of which of the access token or refresh token is - provided. - - Args: - token: the token to revoke - """ - ... - - -def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: - parsed_uri = urlparse(redirect_uri_base) - query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs] - for k, v in params.items(): - if v is not None: - query_params.append((k, v)) - - redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params))) - return redirect_uri - - -class ProviderTokenVerifier(TokenVerifier): - """Token verifier that uses an OAuthAuthorizationServerProvider. - - This is provided for backwards compatibility with existing auth_server_provider - configurations. For new implementations using AS/RS separation, consider using - the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier. 
- """ - - def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"): - self.provider = provider - - async def verify_token(self, token: str) -> AccessToken | None: - """Verify token using the provider's load_access_token method.""" - return await self.provider.load_access_token(token) diff --git a/src/mcp/mcp/server/auth/routes.py b/src/mcp/mcp/server/auth/routes.py deleted file mode 100644 index bce32df52..000000000 --- a/src/mcp/mcp/server/auth/routes.py +++ /dev/null @@ -1,227 +0,0 @@ -from collections.abc import Awaitable, Callable -from typing import Any - -from pydantic import AnyHttpUrl -from starlette.middleware.cors import CORSMiddleware -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Route, request_response # type: ignore -from starlette.types import ASGIApp - -from mcp.server.auth.handlers.authorize import AuthorizationHandler -from mcp.server.auth.handlers.metadata import MetadataHandler -from mcp.server.auth.handlers.register import RegistrationHandler -from mcp.server.auth.handlers.revoke import RevocationHandler -from mcp.server.auth.handlers.token import TokenHandler -from mcp.server.auth.middleware.client_auth import ClientAuthenticator -from mcp.server.auth.provider import OAuthAuthorizationServerProvider -from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions -from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER -from mcp.shared.auth import OAuthMetadata - - -def validate_issuer_url(url: AnyHttpUrl): - """ - Validate that the issuer URL meets OAuth 2.0 requirements. 
- - Args: - url: The issuer URL to validate - - Raises: - ValueError: If the issuer URL is invalid - """ - - # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing - if ( - url.scheme != "https" - and url.host != "localhost" - and (url.host is not None and not url.host.startswith("127.0.0.1")) - ): - raise ValueError("Issuer URL must be HTTPS") - - # No fragments or query parameters allowed - if url.fragment: - raise ValueError("Issuer URL must not have a fragment") - if url.query: - raise ValueError("Issuer URL must not have a query string") - - -AUTHORIZATION_PATH = "/authorize" -TOKEN_PATH = "/token" -REGISTRATION_PATH = "/register" -REVOCATION_PATH = "/revoke" - - -def cors_middleware( - handler: Callable[[Request], Response | Awaitable[Response]], - allow_methods: list[str], -) -> ASGIApp: - cors_app = CORSMiddleware( - app=request_response(handler), - allow_origins="*", - allow_methods=allow_methods, - allow_headers=[MCP_PROTOCOL_VERSION_HEADER], - ) - return cors_app - - -def create_auth_routes( - provider: OAuthAuthorizationServerProvider[Any, Any, Any], - issuer_url: AnyHttpUrl, - service_documentation_url: AnyHttpUrl | None = None, - client_registration_options: ClientRegistrationOptions | None = None, - revocation_options: RevocationOptions | None = None, -) -> list[Route]: - validate_issuer_url(issuer_url) - - client_registration_options = client_registration_options or ClientRegistrationOptions() - revocation_options = revocation_options or RevocationOptions() - metadata = build_metadata( - issuer_url, - service_documentation_url, - client_registration_options, - revocation_options, - ) - client_authenticator = ClientAuthenticator(provider) - - # Create routes - # Allow CORS requests for endpoints meant to be hit by the OAuth client - # (with the client secret). This is intended to support things like MCP Inspector, - # where the client runs in a web browser. 
- routes = [ - Route( - "/.well-known/oauth-authorization-server", - endpoint=cors_middleware( - MetadataHandler(metadata).handle, - ["GET", "OPTIONS"], - ), - methods=["GET", "OPTIONS"], - ), - Route( - AUTHORIZATION_PATH, - # do not allow CORS for authorization endpoint; - # clients should just redirect to this - endpoint=AuthorizationHandler(provider).handle, - methods=["GET", "POST"], - ), - Route( - TOKEN_PATH, - endpoint=cors_middleware( - TokenHandler(provider, client_authenticator).handle, - ["POST", "OPTIONS"], - ), - methods=["POST", "OPTIONS"], - ), - ] - - if client_registration_options.enabled: - registration_handler = RegistrationHandler( - provider, - options=client_registration_options, - ) - routes.append( - Route( - REGISTRATION_PATH, - endpoint=cors_middleware( - registration_handler.handle, - ["POST", "OPTIONS"], - ), - methods=["POST", "OPTIONS"], - ) - ) - - if revocation_options.enabled: - revocation_handler = RevocationHandler(provider, client_authenticator) - routes.append( - Route( - REVOCATION_PATH, - endpoint=cors_middleware( - revocation_handler.handle, - ["POST", "OPTIONS"], - ), - methods=["POST", "OPTIONS"], - ) - ) - - return routes - - -def build_metadata( - issuer_url: AnyHttpUrl, - service_documentation_url: AnyHttpUrl | None, - client_registration_options: ClientRegistrationOptions, - revocation_options: RevocationOptions, -) -> OAuthMetadata: - authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH) - token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH) - - # Create metadata - metadata = OAuthMetadata( - issuer=issuer_url, - authorization_endpoint=authorization_url, - token_endpoint=token_url, - scopes_supported=client_registration_options.valid_scopes, - response_types_supported=["code"], - response_modes_supported=None, - grant_types_supported=["authorization_code", "refresh_token"], - token_endpoint_auth_methods_supported=["client_secret_post"], - 
token_endpoint_auth_signing_alg_values_supported=None, - service_documentation=service_documentation_url, - ui_locales_supported=None, - op_policy_uri=None, - op_tos_uri=None, - introspection_endpoint=None, - code_challenge_methods_supported=["S256"], - ) - - # Add registration endpoint if supported - if client_registration_options.enabled: - metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH) - - # Add revocation endpoint if supported - if revocation_options.enabled: - metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH) - metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] - - return metadata - - -def create_protected_resource_routes( - resource_url: AnyHttpUrl, - authorization_servers: list[AnyHttpUrl], - scopes_supported: list[str] | None = None, - resource_name: str | None = None, - resource_documentation: AnyHttpUrl | None = None, -) -> list[Route]: - """ - Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728). 
- - Args: - resource_url: The URL of this resource server - authorization_servers: List of authorization servers that can issue tokens - scopes_supported: Optional list of scopes supported by this resource - - Returns: - List of Starlette routes for protected resource metadata - """ - from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler - from mcp.shared.auth import ProtectedResourceMetadata - - metadata = ProtectedResourceMetadata( - resource=resource_url, - authorization_servers=authorization_servers, - scopes_supported=scopes_supported, - resource_name=resource_name, - resource_documentation=resource_documentation, - # bearer_methods_supported defaults to ["header"] in the model - ) - - handler = ProtectedResourceMetadataHandler(metadata) - - return [ - Route( - "/.well-known/oauth-protected-resource", - endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]), - methods=["GET", "OPTIONS"], - ) - ] diff --git a/src/mcp/mcp/server/auth/settings.py b/src/mcp/mcp/server/auth/settings.py deleted file mode 100644 index 1649826db..000000000 --- a/src/mcp/mcp/server/auth/settings.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import AnyHttpUrl, BaseModel, Field - - -class ClientRegistrationOptions(BaseModel): - enabled: bool = False - client_secret_expiry_seconds: int | None = None - valid_scopes: list[str] | None = None - default_scopes: list[str] | None = None - - -class RevocationOptions(BaseModel): - enabled: bool = False - - -class AuthSettings(BaseModel): - issuer_url: AnyHttpUrl = Field( - ..., - description="OAuth authorization server URL that issues tokens for this resource server.", - ) - service_documentation_url: AnyHttpUrl | None = None - client_registration_options: ClientRegistrationOptions | None = None - revocation_options: RevocationOptions | None = None - required_scopes: list[str] | None = None - - # Resource Server settings (when operating as RS only) - resource_server_url: AnyHttpUrl | None = Field( - ..., - 
description="The URL of the MCP server to be used as the resource identifier " - "and base route to look up OAuth Protected Resource Metadata.", - ) diff --git a/src/mcp/mcp/server/elicitation.py b/src/mcp/mcp/server/elicitation.py deleted file mode 100644 index 1e48738c8..000000000 --- a/src/mcp/mcp/server/elicitation.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Elicitation utilities for MCP servers.""" - -from __future__ import annotations - -import types -from typing import Generic, Literal, TypeVar, Union, get_args, get_origin - -from pydantic import BaseModel -from pydantic.fields import FieldInfo - -from mcp.server.session import ServerSession -from mcp.types import RequestId - -ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) - - -class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]): - """Result when user accepts the elicitation.""" - - action: Literal["accept"] = "accept" - data: ElicitSchemaModelT - - -class DeclinedElicitation(BaseModel): - """Result when user declines the elicitation.""" - - action: Literal["decline"] = "decline" - - -class CancelledElicitation(BaseModel): - """Result when user cancels the elicitation.""" - - action: Literal["cancel"] = "cancel" - - -ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation - - -# Primitive types allowed in elicitation schemas -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) - - -def _validate_elicitation_schema(schema: type[BaseModel]) -> None: - """Validate that a Pydantic model only contains primitive field types.""" - for field_name, field_info in schema.model_fields.items(): - if not _is_primitive_field(field_info): - raise TypeError( - f"Elicitation schema field '{field_name}' must be a primitive type " - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " - f"Complex types like lists, dicts, or nested models are not allowed." 
- ) - - -def _is_primitive_field(field_info: FieldInfo) -> bool: - """Check if a field is a primitive type allowed in elicitation schemas.""" - annotation = field_info.annotation - - # Handle None type - if annotation is types.NoneType: - return True - - # Handle basic primitive types - if annotation in _ELICITATION_PRIMITIVE_TYPES: - return True - - # Handle Union types - origin = get_origin(annotation) - if origin is Union or origin is types.UnionType: - args = get_args(annotation) - # All args must be primitive types or None - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) - - return False - - -async def elicit_with_validation( - session: ServerSession, - message: str, - schema: type[ElicitSchemaModelT], - related_request_id: RequestId | None = None, -) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user with schema validation. - - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. 
- """ - # Validate that schema only contains primitive types and fail loudly if not - _validate_elicitation_schema(schema) - - json_schema = schema.model_json_schema() - - result = await session.elicit( - message=message, - requestedSchema=json_schema, - related_request_id=related_request_id, - ) - - if result.action == "accept" and result.content: - # Validate and parse the content using the schema - validated_data = schema.model_validate(result.content) - return AcceptedElicitation(data=validated_data) - elif result.action == "decline": - return DeclinedElicitation() - elif result.action == "cancel": - return CancelledElicitation() - else: - # This should never happen, but handle it just in case - raise ValueError(f"Unexpected elicitation action: {result.action}") diff --git a/src/mcp/mcp/server/fastmcp/__init__.py b/src/mcp/mcp/server/fastmcp/__init__.py deleted file mode 100644 index f8f9c1c4c..000000000 --- a/src/mcp/mcp/server/fastmcp/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""FastMCP - A more ergonomic interface for MCP servers.""" - -from importlib.metadata import version - -from .server import Context, FastMCP -from .utilities.types import Audio, Image - -__version__ = version("mcp") -__all__ = ["FastMCP", "Context", "Image", "Audio"] diff --git a/src/mcp/mcp/server/fastmcp/exceptions.py b/src/mcp/mcp/server/fastmcp/exceptions.py deleted file mode 100644 index fb5bda106..000000000 --- a/src/mcp/mcp/server/fastmcp/exceptions.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Custom exceptions for FastMCP.""" - - -class FastMCPError(Exception): - """Base error for FastMCP.""" - - -class ValidationError(FastMCPError): - """Error in validating parameters or return values.""" - - -class ResourceError(FastMCPError): - """Error in resource operations.""" - - -class ToolError(FastMCPError): - """Error in tool operations.""" - - -class InvalidSignature(Exception): - """Invalid signature for use with FastMCP.""" diff --git a/src/mcp/mcp/server/fastmcp/prompts/__init__.py 
b/src/mcp/mcp/server/fastmcp/prompts/__init__.py deleted file mode 100644 index 763726964..000000000 --- a/src/mcp/mcp/server/fastmcp/prompts/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base import Prompt -from .manager import PromptManager - -__all__ = ["Prompt", "PromptManager"] diff --git a/src/mcp/mcp/server/fastmcp/prompts/base.py b/src/mcp/mcp/server/fastmcp/prompts/base.py deleted file mode 100644 index b45cfc917..000000000 --- a/src/mcp/mcp/server/fastmcp/prompts/base.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Base classes for FastMCP prompts.""" - -import inspect -from collections.abc import Awaitable, Callable, Sequence -from typing import Any, Literal - -import pydantic_core -from pydantic import BaseModel, Field, TypeAdapter, validate_call - -from mcp.types import ContentBlock, TextContent - - -class Message(BaseModel): - """Base class for all prompt messages.""" - - role: Literal["user", "assistant"] - content: ContentBlock - - def __init__(self, content: str | ContentBlock, **kwargs: Any): - if isinstance(content, str): - content = TextContent(type="text", text=content) - super().__init__(content=content, **kwargs) - - -class UserMessage(Message): - """A message from the user.""" - - role: Literal["user", "assistant"] = "user" - - def __init__(self, content: str | ContentBlock, **kwargs: Any): - super().__init__(content=content, **kwargs) - - -class AssistantMessage(Message): - """A message from the assistant.""" - - role: Literal["user", "assistant"] = "assistant" - - def __init__(self, content: str | ContentBlock, **kwargs: Any): - super().__init__(content=content, **kwargs) - - -message_validator = TypeAdapter[UserMessage | AssistantMessage](UserMessage | AssistantMessage) - -SyncPromptResult = str | Message | dict[str, Any] | Sequence[str | Message | dict[str, Any]] -PromptResult = SyncPromptResult | Awaitable[SyncPromptResult] - - -class PromptArgument(BaseModel): - """An argument that can be passed to a prompt.""" - - name: str = 
Field(description="Name of the argument") - description: str | None = Field(None, description="Description of what the argument does") - required: bool = Field(default=False, description="Whether the argument is required") - - -class Prompt(BaseModel): - """A prompt template that can be rendered with parameters.""" - - name: str = Field(description="Name of the prompt") - title: str | None = Field(None, description="Human-readable title of the prompt") - description: str | None = Field(None, description="Description of what the prompt does") - arguments: list[PromptArgument] | None = Field(None, description="Arguments that can be passed to the prompt") - fn: Callable[..., PromptResult | Awaitable[PromptResult]] = Field(exclude=True) - - @classmethod - def from_function( - cls, - fn: Callable[..., PromptResult | Awaitable[PromptResult]], - name: str | None = None, - title: str | None = None, - description: str | None = None, - ) -> "Prompt": - """Create a Prompt from a function. - - The function can return: - - A string (converted to a message) - - A Message object - - A dict (converted to a message) - - A sequence of any of the above - """ - func_name = name or fn.__name__ - - if func_name == "": - raise ValueError("You must provide a name for lambda functions") - - # Get schema from TypeAdapter - will fail if function isn't properly typed - parameters = TypeAdapter(fn).json_schema() - - # Convert parameters to PromptArguments - arguments: list[PromptArgument] = [] - if "properties" in parameters: - for param_name, param in parameters["properties"].items(): - required = param_name in parameters.get("required", []) - arguments.append( - PromptArgument( - name=param_name, - description=param.get("description"), - required=required, - ) - ) - - # ensure the arguments are properly cast - fn = validate_call(fn) - - return cls( - name=func_name, - title=title, - description=description or fn.__doc__ or "", - arguments=arguments, - fn=fn, - ) - - async def render(self, 
arguments: dict[str, Any] | None = None) -> list[Message]: - """Render the prompt with arguments.""" - # Validate required arguments - if self.arguments: - required = {arg.name for arg in self.arguments if arg.required} - provided = set(arguments or {}) - missing = required - provided - if missing: - raise ValueError(f"Missing required arguments: {missing}") - - try: - # Call function and check if result is a coroutine - result = self.fn(**(arguments or {})) - if inspect.iscoroutine(result): - result = await result - - # Validate messages - if not isinstance(result, list | tuple): - result = [result] - - # Convert result to messages - messages: list[Message] = [] - for msg in result: # type: ignore[reportUnknownVariableType] - try: - if isinstance(msg, Message): - messages.append(msg) - elif isinstance(msg, dict): - messages.append(message_validator.validate_python(msg)) - elif isinstance(msg, str): - content = TextContent(type="text", text=msg) - messages.append(UserMessage(content=content)) - else: - content = pydantic_core.to_json(msg, fallback=str, indent=2).decode() - messages.append(Message(role="user", content=content)) - except Exception: - raise ValueError(f"Could not convert prompt result to message: {msg}") - - return messages - except Exception as e: - raise ValueError(f"Error rendering prompt {self.name}: {e}") diff --git a/src/mcp/mcp/server/fastmcp/prompts/manager.py b/src/mcp/mcp/server/fastmcp/prompts/manager.py deleted file mode 100644 index 6b01d91cd..000000000 --- a/src/mcp/mcp/server/fastmcp/prompts/manager.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Prompt management functionality.""" - -from typing import Any - -from mcp.server.fastmcp.prompts.base import Message, Prompt -from mcp.server.fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - - -class PromptManager: - """Manages FastMCP prompts.""" - - def __init__(self, warn_on_duplicate_prompts: bool = True): - self._prompts: dict[str, Prompt] = {} - 
self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - - def get_prompt(self, name: str) -> Prompt | None: - """Get prompt by name.""" - return self._prompts.get(name) - - def list_prompts(self) -> list[Prompt]: - """List all registered prompts.""" - return list(self._prompts.values()) - - def add_prompt( - self, - prompt: Prompt, - ) -> Prompt: - """Add a prompt to the manager.""" - - # Check for duplicates - existing = self._prompts.get(prompt.name) - if existing: - if self.warn_on_duplicate_prompts: - logger.warning(f"Prompt already exists: {prompt.name}") - return existing - - self._prompts[prompt.name] = prompt - return prompt - - async def render_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> list[Message]: - """Render a prompt by name with arguments.""" - prompt = self.get_prompt(name) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - return await prompt.render(arguments) diff --git a/src/mcp/mcp/server/fastmcp/resources/__init__.py b/src/mcp/mcp/server/fastmcp/resources/__init__.py deleted file mode 100644 index b5805fb34..000000000 --- a/src/mcp/mcp/server/fastmcp/resources/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -from .base import Resource -from .resource_manager import ResourceManager -from .templates import ResourceTemplate -from .types import ( - BinaryResource, - DirectoryResource, - FileResource, - FunctionResource, - HttpResource, - TextResource, -) - -__all__ = [ - "Resource", - "TextResource", - "BinaryResource", - "FunctionResource", - "FileResource", - "HttpResource", - "DirectoryResource", - "ResourceTemplate", - "ResourceManager", -] diff --git a/src/mcp/mcp/server/fastmcp/resources/base.py b/src/mcp/mcp/server/fastmcp/resources/base.py deleted file mode 100644 index f57631cc1..000000000 --- a/src/mcp/mcp/server/fastmcp/resources/base.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Base classes and interfaces for FastMCP resources.""" - -import abc -from typing import Annotated - -from pydantic import ( - 
AnyUrl, - BaseModel, - ConfigDict, - Field, - UrlConstraints, - ValidationInfo, - field_validator, -) - - -class Resource(BaseModel, abc.ABC): - """Base class for all resources.""" - - model_config = ConfigDict(validate_default=True) - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] = Field(default=..., description="URI of the resource") - name: str | None = Field(description="Name of the resource", default=None) - title: str | None = Field(description="Human-readable title of the resource", default=None) - description: str | None = Field(description="Description of the resource", default=None) - mime_type: str = Field( - default="text/plain", - description="MIME type of the resource content", - pattern=r"^[a-zA-Z0-9]+/[a-zA-Z0-9\-+.]+$", - ) - - @field_validator("name", mode="before") - @classmethod - def set_default_name(cls, name: str | None, info: ValidationInfo) -> str: - """Set default name from URI if not provided.""" - if name: - return name - if uri := info.data.get("uri"): - return str(uri) - raise ValueError("Either name or uri must be provided") - - @abc.abstractmethod - async def read(self) -> str | bytes: - """Read the resource content.""" - pass diff --git a/src/mcp/mcp/server/fastmcp/resources/resource_manager.py b/src/mcp/mcp/server/fastmcp/resources/resource_manager.py deleted file mode 100644 index 35e4ec04d..000000000 --- a/src/mcp/mcp/server/fastmcp/resources/resource_manager.py +++ /dev/null @@ -1,97 +0,0 @@ -"""Resource manager functionality.""" - -from collections.abc import Callable -from typing import Any - -from pydantic import AnyUrl - -from mcp.server.fastmcp.resources.base import Resource -from mcp.server.fastmcp.resources.templates import ResourceTemplate -from mcp.server.fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - - -class ResourceManager: - """Manages FastMCP resources.""" - - def __init__(self, warn_on_duplicate_resources: bool = True): - self._resources: dict[str, Resource] = {} - 
self._templates: dict[str, ResourceTemplate] = {} - self.warn_on_duplicate_resources = warn_on_duplicate_resources - - def add_resource(self, resource: Resource) -> Resource: - """Add a resource to the manager. - - Args: - resource: A Resource instance to add - - Returns: - The added resource. If a resource with the same URI already exists, - returns the existing resource. - """ - logger.debug( - "Adding resource", - extra={ - "uri": resource.uri, - "type": type(resource).__name__, - "resource_name": resource.name, - }, - ) - existing = self._resources.get(str(resource.uri)) - if existing: - if self.warn_on_duplicate_resources: - logger.warning(f"Resource already exists: {resource.uri}") - return existing - self._resources[str(resource.uri)] = resource - return resource - - def add_template( - self, - fn: Callable[..., Any], - uri_template: str, - name: str | None = None, - title: str | None = None, - description: str | None = None, - mime_type: str | None = None, - ) -> ResourceTemplate: - """Add a template from a function.""" - template = ResourceTemplate.from_function( - fn, - uri_template=uri_template, - name=name, - title=title, - description=description, - mime_type=mime_type, - ) - self._templates[template.uri_template] = template - return template - - async def get_resource(self, uri: AnyUrl | str) -> Resource | None: - """Get resource by URI, checking concrete resources first, then templates.""" - uri_str = str(uri) - logger.debug("Getting resource", extra={"uri": uri_str}) - - # First check concrete resources - if resource := self._resources.get(uri_str): - return resource - - # Then check templates - for template in self._templates.values(): - if params := template.matches(uri_str): - try: - return await template.create_resource(uri_str, params) - except Exception as e: - raise ValueError(f"Error creating resource from template: {e}") - - raise ValueError(f"Unknown resource: {uri}") - - def list_resources(self) -> list[Resource]: - """List all registered 
resources.""" - logger.debug("Listing resources", extra={"count": len(self._resources)}) - return list(self._resources.values()) - - def list_templates(self) -> list[ResourceTemplate]: - """List all registered templates.""" - logger.debug("Listing templates", extra={"count": len(self._templates)}) - return list(self._templates.values()) diff --git a/src/mcp/mcp/server/fastmcp/resources/templates.py b/src/mcp/mcp/server/fastmcp/resources/templates.py deleted file mode 100644 index b1c7b2711..000000000 --- a/src/mcp/mcp/server/fastmcp/resources/templates.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Resource template functionality.""" - -from __future__ import annotations - -import inspect -import re -from collections.abc import Callable -from typing import Any - -from pydantic import BaseModel, Field, TypeAdapter, validate_call - -from mcp.server.fastmcp.resources.types import FunctionResource, Resource - - -class ResourceTemplate(BaseModel): - """A template for dynamically creating resources.""" - - uri_template: str = Field(description="URI template with parameters (e.g. 
weather://{city}/current)") - name: str = Field(description="Name of the resource") - title: str | None = Field(description="Human-readable title of the resource", default=None) - description: str | None = Field(description="Description of what the resource does") - mime_type: str = Field(default="text/plain", description="MIME type of the resource content") - fn: Callable[..., Any] = Field(exclude=True) - parameters: dict[str, Any] = Field(description="JSON schema for function parameters") - - @classmethod - def from_function( - cls, - fn: Callable[..., Any], - uri_template: str, - name: str | None = None, - title: str | None = None, - description: str | None = None, - mime_type: str | None = None, - ) -> ResourceTemplate: - """Create a template from a function.""" - func_name = name or fn.__name__ - if func_name == "": - raise ValueError("You must provide a name for lambda functions") - - # Get schema from TypeAdapter - will fail if function isn't properly typed - parameters = TypeAdapter(fn).json_schema() - - # ensure the arguments are properly cast - fn = validate_call(fn) - - return cls( - uri_template=uri_template, - name=func_name, - title=title, - description=description or fn.__doc__ or "", - mime_type=mime_type or "text/plain", - fn=fn, - parameters=parameters, - ) - - def matches(self, uri: str) -> dict[str, Any] | None: - """Check if URI matches template and extract parameters.""" - # Convert template to regex pattern - pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)") - match = re.match(f"^{pattern}$", uri) - if match: - return match.groupdict() - return None - - async def create_resource(self, uri: str, params: dict[str, Any]) -> Resource: - """Create a resource from the template with the given parameters.""" - try: - # Call function and check if result is a coroutine - result = self.fn(**params) - if inspect.iscoroutine(result): - result = await result - - return FunctionResource( - uri=uri, # type: ignore - name=self.name, - 
title=self.title, - description=self.description, - mime_type=self.mime_type, - fn=lambda: result, # Capture result in closure - ) - except Exception as e: - raise ValueError(f"Error creating resource from template: {e}") diff --git a/src/mcp/mcp/server/fastmcp/resources/types.py b/src/mcp/mcp/server/fastmcp/resources/types.py deleted file mode 100644 index f2a330706..000000000 --- a/src/mcp/mcp/server/fastmcp/resources/types.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Concrete resource implementations.""" - -import inspect -import json -from collections.abc import Callable -from pathlib import Path -from typing import Any - -import anyio -import anyio.to_thread -import httpx -import pydantic -import pydantic_core -from pydantic import AnyUrl, Field, ValidationInfo, validate_call - -from mcp.server.fastmcp.resources.base import Resource - - -class TextResource(Resource): - """A resource that reads from a string.""" - - text: str = Field(description="Text content of the resource") - - async def read(self) -> str: - """Read the text content.""" - return self.text - - -class BinaryResource(Resource): - """A resource that reads from bytes.""" - - data: bytes = Field(description="Binary content of the resource") - - async def read(self) -> bytes: - """Read the binary content.""" - return self.data - - -class FunctionResource(Resource): - """A resource that defers data loading by wrapping a function. - - The function is only called when the resource is read, allowing for lazy loading - of potentially expensive data. This is particularly useful when listing resources, - as the function won't be called until the resource is actually accessed. 
- - The function can return: - - str for text content (default) - - bytes for binary content - - other types will be converted to JSON - """ - - fn: Callable[[], Any] = Field(exclude=True) - - async def read(self) -> str | bytes: - """Read the resource by calling the wrapped function.""" - try: - # Call the function first to see if it returns a coroutine - result = self.fn() - # If it's a coroutine, await it - if inspect.iscoroutine(result): - result = await result - - if isinstance(result, Resource): - return await result.read() - elif isinstance(result, bytes): - return result - elif isinstance(result, str): - return result - else: - return pydantic_core.to_json(result, fallback=str, indent=2).decode() - except Exception as e: - raise ValueError(f"Error reading resource {self.uri}: {e}") - - @classmethod - def from_function( - cls, - fn: Callable[..., Any], - uri: str, - name: str | None = None, - title: str | None = None, - description: str | None = None, - mime_type: str | None = None, - ) -> "FunctionResource": - """Create a FunctionResource from a function.""" - func_name = name or fn.__name__ - if func_name == "": - raise ValueError("You must provide a name for lambda functions") - - # ensure the arguments are properly cast - fn = validate_call(fn) - - return cls( - uri=AnyUrl(uri), - name=func_name, - title=title, - description=description or fn.__doc__ or "", - mime_type=mime_type or "text/plain", - fn=fn, - ) - - -class FileResource(Resource): - """A resource that reads from a file. - - Set is_binary=True to read file as binary data instead of text. 
- """ - - path: Path = Field(description="Path to the file") - is_binary: bool = Field( - default=False, - description="Whether to read the file as binary data", - ) - mime_type: str = Field( - default="text/plain", - description="MIME type of the resource content", - ) - - @pydantic.field_validator("path") - @classmethod - def validate_absolute_path(cls, path: Path) -> Path: - """Ensure path is absolute.""" - if not path.is_absolute(): - raise ValueError("Path must be absolute") - return path - - @pydantic.field_validator("is_binary") - @classmethod - def set_binary_from_mime_type(cls, is_binary: bool, info: ValidationInfo) -> bool: - """Set is_binary based on mime_type if not explicitly set.""" - if is_binary: - return True - mime_type = info.data.get("mime_type", "text/plain") - return not mime_type.startswith("text/") - - async def read(self) -> str | bytes: - """Read the file content.""" - try: - if self.is_binary: - return await anyio.to_thread.run_sync(self.path.read_bytes) - return await anyio.to_thread.run_sync(self.path.read_text) - except Exception as e: - raise ValueError(f"Error reading file {self.path}: {e}") - - -class HttpResource(Resource): - """A resource that reads from an HTTP endpoint.""" - - url: str = Field(description="URL to fetch content from") - mime_type: str = Field(default="application/json", description="MIME type of the resource content") - - async def read(self) -> str | bytes: - """Read the HTTP content.""" - async with httpx.AsyncClient() as client: - response = await client.get(self.url) - response.raise_for_status() - return response.text - - -class DirectoryResource(Resource): - """A resource that lists files in a directory.""" - - path: Path = Field(description="Path to the directory") - recursive: bool = Field(default=False, description="Whether to list files recursively") - pattern: str | None = Field(default=None, description="Optional glob pattern to filter files") - mime_type: str = Field(default="application/json", 
description="MIME type of the resource content") - - @pydantic.field_validator("path") - @classmethod - def validate_absolute_path(cls, path: Path) -> Path: - """Ensure path is absolute.""" - if not path.is_absolute(): - raise ValueError("Path must be absolute") - return path - - def list_files(self) -> list[Path]: - """List files in the directory.""" - if not self.path.exists(): - raise FileNotFoundError(f"Directory not found: {self.path}") - if not self.path.is_dir(): - raise NotADirectoryError(f"Not a directory: {self.path}") - - try: - if self.pattern: - return list(self.path.glob(self.pattern)) if not self.recursive else list(self.path.rglob(self.pattern)) - return list(self.path.glob("*")) if not self.recursive else list(self.path.rglob("*")) - except Exception as e: - raise ValueError(f"Error listing directory {self.path}: {e}") - - async def read(self) -> str: # Always returns JSON string - """Read the directory listing.""" - try: - files = await anyio.to_thread.run_sync(self.list_files) - file_list = [str(f.relative_to(self.path)) for f in files if f.is_file()] - return json.dumps({"files": file_list}, indent=2) - except Exception as e: - raise ValueError(f"Error reading directory {self.path}: {e}") diff --git a/src/mcp/mcp/server/fastmcp/server.py b/src/mcp/mcp/server/fastmcp/server.py deleted file mode 100644 index 924baaa9b..000000000 --- a/src/mcp/mcp/server/fastmcp/server.py +++ /dev/null @@ -1,1188 +0,0 @@ -"""FastMCP - A more ergonomic interface for MCP servers.""" - -from __future__ import annotations as _annotations - -import inspect -import re -from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence -from contextlib import AbstractAsyncContextManager, asynccontextmanager -from typing import Any, Generic, Literal - -import anyio -import pydantic_core -from pydantic import BaseModel -from pydantic.networks import AnyUrl -from pydantic_settings import BaseSettings, SettingsConfigDict -from 
starlette.applications import Starlette -from starlette.middleware import Middleware -from starlette.middleware.authentication import AuthenticationMiddleware -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Mount, Route -from starlette.types import Receive, Scope, Send - -from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware -from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier -from mcp.server.auth.settings import AuthSettings -from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation -from mcp.server.fastmcp.exceptions import ResourceError -from mcp.server.fastmcp.prompts import Prompt, PromptManager -from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager -from mcp.server.fastmcp.tools import Tool, ToolManager -from mcp.server.fastmcp.utilities.logging import configure_logging, get_logger -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.lowlevel.server import LifespanResultT -from mcp.server.lowlevel.server import Server as MCPServer -from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.server.session import ServerSession, ServerSessionT -from mcp.server.sse import SseServerTransport -from mcp.server.stdio import stdio_server -from mcp.server.streamable_http import EventStore -from mcp.server.streamable_http_manager import StreamableHTTPSessionManager -from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.context import LifespanContextT, RequestContext, RequestT -from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations -from mcp.types import Prompt as MCPPrompt -from mcp.types import PromptArgument as MCPPromptArgument -from mcp.types import Resource as 
MCPResource -from mcp.types import ResourceTemplate as MCPResourceTemplate -from mcp.types import Tool as MCPTool - -logger = get_logger(__name__) - - -class Settings(BaseSettings, Generic[LifespanResultT]): - """FastMCP server settings. - - All settings can be configured via environment variables with the prefix FASTMCP_. - For example, FASTMCP_DEBUG=true will set debug=True. - """ - - model_config = SettingsConfigDict( - env_prefix="FASTMCP_", - env_file=".env", - env_nested_delimiter="__", - nested_model_default_partial_update=True, - extra="ignore", - ) - - # Server settings - debug: bool - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] - - # HTTP settings - host: str - port: int - mount_path: str - sse_path: str - message_path: str - streamable_http_path: str - - # StreamableHTTP settings - json_response: bool - stateless_http: bool - """Define if the server should create a new transport per request.""" - - # resource settings - warn_on_duplicate_resources: bool - - # tool settings - warn_on_duplicate_tools: bool - - # prompt settings - warn_on_duplicate_prompts: bool - - # TODO(Marcelo): Investigate if this is used. If it is, it's probably a good idea to remove it. 
- dependencies: list[str] - """A list of dependencies to install in the server environment.""" - - lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None - """A async context manager that will be called when the server is started.""" - - auth: AuthSettings | None - - # Transport security settings (DNS rebinding protection) - transport_security: TransportSecuritySettings | None - - -def lifespan_wrapper( - app: FastMCP[LifespanResultT], - lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: - @asynccontextmanager - async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: - async with lifespan(app) as context: - yield context - - return wrap - - -class FastMCP(Generic[LifespanResultT]): - def __init__( - self, - name: str | None = None, - instructions: str | None = None, - auth_server_provider: OAuthAuthorizationServerProvider[Any, Any, Any] | None = None, - token_verifier: TokenVerifier | None = None, - event_store: EventStore | None = None, - *, - tools: list[Tool] | None = None, - debug: bool = False, - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", - host: str = "127.0.0.1", - port: int = 8000, - mount_path: str = "/", - sse_path: str = "/sse", - message_path: str = "/messages/", - streamable_http_path: str = "/mcp", - json_response: bool = False, - stateless_http: bool = False, - warn_on_duplicate_resources: bool = True, - warn_on_duplicate_tools: bool = True, - warn_on_duplicate_prompts: bool = True, - dependencies: Collection[str] = (), - lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, - auth: AuthSettings | None = None, - transport_security: TransportSecuritySettings | None = None, - ): - self.settings = Settings( - debug=debug, - log_level=log_level, - 
host=host, - port=port, - mount_path=mount_path, - sse_path=sse_path, - message_path=message_path, - streamable_http_path=streamable_http_path, - json_response=json_response, - stateless_http=stateless_http, - warn_on_duplicate_resources=warn_on_duplicate_resources, - warn_on_duplicate_tools=warn_on_duplicate_tools, - warn_on_duplicate_prompts=warn_on_duplicate_prompts, - dependencies=list(dependencies), - lifespan=lifespan, - auth=auth, - transport_security=transport_security, - ) - - self._mcp_server = MCPServer( - name=name or "FastMCP", - instructions=instructions, - # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. - # We need to create a Lifespan type that is a generic on the server type, like Starlette does. - lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore - ) - self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) - self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) - self._prompt_manager = PromptManager(warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts) - # Validate auth configuration - if self.settings.auth is not None: - if auth_server_provider and token_verifier: - raise ValueError("Cannot specify both auth_server_provider and token_verifier") - if not auth_server_provider and not token_verifier: - raise ValueError("Must specify either auth_server_provider or token_verifier when auth is enabled") - else: - if auth_server_provider or token_verifier: - raise ValueError("Cannot specify auth_server_provider or token_verifier without auth settings") - - self._auth_server_provider = auth_server_provider - self._token_verifier = token_verifier - - # Create token verifier from provider if needed (backwards compatibility) - if auth_server_provider and not token_verifier: - self._token_verifier = 
ProviderTokenVerifier(auth_server_provider) - self._event_store = event_store - self._custom_starlette_routes: list[Route] = [] - self.dependencies = self.settings.dependencies - self._session_manager: StreamableHTTPSessionManager | None = None - - # Set up MCP protocol handlers - self._setup_handlers() - - # Configure logging - configure_logging(self.settings.log_level) - - @property - def name(self) -> str: - return self._mcp_server.name - - @property - def instructions(self) -> str | None: - return self._mcp_server.instructions - - @property - def session_manager(self) -> StreamableHTTPSessionManager: - """Get the StreamableHTTP session manager. - - This is exposed to enable advanced use cases like mounting multiple - FastMCP servers in a single FastAPI application. - - Raises: - RuntimeError: If called before streamable_http_app() has been called. - """ - if self._session_manager is None: - raise RuntimeError( - "Session manager can only be accessed after" - "calling streamable_http_app()." - "The session manager is created lazily" - "to avoid unnecessary initialization." - ) - return self._session_manager - - def run( - self, - transport: Literal["stdio", "sse", "streamable-http"] = "stdio", - mount_path: str | None = None, - ) -> None: - """Run the FastMCP server. Note this is a synchronous function. 
- - Args: - transport: Transport protocol to use ("stdio", "sse", or "streamable-http") - mount_path: Optional mount path for SSE transport - """ - TRANSPORTS = Literal["stdio", "sse", "streamable-http"] - if transport not in TRANSPORTS.__args__: # type: ignore - raise ValueError(f"Unknown transport: {transport}") - - match transport: - case "stdio": - anyio.run(self.run_stdio_async) - case "sse": - anyio.run(lambda: self.run_sse_async(mount_path)) - case "streamable-http": - anyio.run(self.run_streamable_http_async) - - def _setup_handlers(self) -> None: - """Set up core MCP protocol handlers.""" - self._mcp_server.list_tools()(self.list_tools) - # Note: we disable the lowlevel server's input validation. - # FastMCP does ad hoc conversion of incoming data before validating - - # for now we preserve this for backwards compatibility. - self._mcp_server.call_tool(validate_input=False)(self.call_tool) - self._mcp_server.list_resources()(self.list_resources) - self._mcp_server.read_resource()(self.read_resource) - self._mcp_server.list_prompts()(self.list_prompts) - self._mcp_server.get_prompt()(self.get_prompt) - self._mcp_server.list_resource_templates()(self.list_resource_templates) - - async def list_tools(self) -> list[MCPTool]: - """List all available tools.""" - tools = self._tool_manager.list_tools() - return [ - MCPTool( - name=info.name, - title=info.title, - description=info.description, - inputSchema=info.parameters, - outputSchema=info.output_schema, - annotations=info.annotations, - ) - for info in tools - ] - - def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: - """ - Returns a Context object. Note that the context will only be valid - during a request; outside a request, most methods will error. 
- """ - try: - request_context = self._mcp_server.request_context - except LookupError: - request_context = None - return Context(request_context=request_context, fastmcp=self) - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> Sequence[ContentBlock] | dict[str, Any]: - """Call a tool by name with arguments.""" - context = self.get_context() - return await self._tool_manager.call_tool(name, arguments, context=context, convert_result=True) - - async def list_resources(self) -> list[MCPResource]: - """List all available resources.""" - - resources = self._resource_manager.list_resources() - return [ - MCPResource( - uri=resource.uri, - name=resource.name or "", - title=resource.title, - description=resource.description, - mimeType=resource.mime_type, - ) - for resource in resources - ] - - async def list_resource_templates(self) -> list[MCPResourceTemplate]: - templates = self._resource_manager.list_templates() - return [ - MCPResourceTemplate( - uriTemplate=template.uri_template, - name=template.name, - title=template.title, - description=template.description, - ) - for template in templates - ] - - async def read_resource(self, uri: AnyUrl | str) -> Iterable[ReadResourceContents]: - """Read a resource by URI.""" - - resource = await self._resource_manager.get_resource(uri) - if not resource: - raise ResourceError(f"Unknown resource: {uri}") - - try: - content = await resource.read() - return [ReadResourceContents(content=content, mime_type=resource.mime_type)] - except Exception as e: - logger.exception(f"Error reading resource {uri}") - raise ResourceError(str(e)) - - def add_tool( - self, - fn: AnyFunction, - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - structured_output: bool | None = None, - ) -> None: - """Add a tool to the server. - - The tool function can optionally request a Context object by adding a parameter - with the Context type annotation. 
See the @tool decorator for examples. - - Args: - fn: The function to register as a tool - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, unconditionally creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool - """ - self._tool_manager.add_tool( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - structured_output=structured_output, - ) - - def tool( - self, - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - structured_output: bool | None = None, - ) -> Callable[[AnyFunction], AnyFunction]: - """Decorator to register a tool. - - Tools can optionally request a Context object by adding a parameter with the - Context type annotation. The context provides access to MCP capabilities like - logging, progress reporting, and resource access. 
- - Args: - name: Optional name for the tool (defaults to function name) - title: Optional human-readable title for the tool - description: Optional description of what the tool does - annotations: Optional ToolAnnotations providing additional tool information - structured_output: Controls whether the tool's output is structured or unstructured - - If None, auto-detects based on the function's return type annotation - - If True, unconditionally creates a structured tool (return type annotation permitting) - - If False, unconditionally creates an unstructured tool - - Example: - @server.tool() - def my_tool(x: int) -> str: - return str(x) - - @server.tool() - def tool_with_context(x: int, ctx: Context) -> str: - ctx.info(f"Processing {x}") - return str(x) - - @server.tool() - async def async_tool(x: int, context: Context) -> str: - await context.report_progress(50, 100) - return str(x) - """ - # Check if user passed function directly instead of calling decorator - if callable(name): - raise TypeError( - "The @tool decorator was used incorrectly. Did you forget to call it? Use @tool() instead of @tool" - ) - - def decorator(fn: AnyFunction) -> AnyFunction: - self.add_tool( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - structured_output=structured_output, - ) - return fn - - return decorator - - def completion(self): - """Decorator to register a completion handler. 
- - The completion handler receives: - - ref: PromptReference or ResourceTemplateReference - - argument: CompletionArgument with name and partial value - - context: Optional CompletionContext with previously resolved arguments - - Example: - @mcp.completion() - async def handle_completion(ref, argument, context): - if isinstance(ref, ResourceTemplateReference): - # Return completions based on ref, argument, and context - return Completion(values=["option1", "option2"]) - return None - """ - return self._mcp_server.completion() - - def add_resource(self, resource: Resource) -> None: - """Add a resource to the server. - - Args: - resource: A Resource instance to add - """ - self._resource_manager.add_resource(resource) - - def resource( - self, - uri: str, - *, - name: str | None = None, - title: str | None = None, - description: str | None = None, - mime_type: str | None = None, - ) -> Callable[[AnyFunction], AnyFunction]: - """Decorator to register a function as a resource. - - The function will be called when the resource is read to generate its content. - The function can return: - - str for text content - - bytes for binary content - - other types will be converted to JSON - - If the URI contains parameters (e.g. "resource://{param}") or the function - has parameters, it will be registered as a template resource. - - Args: - uri: URI for the resource (e.g. "resource://my-resource" or "resource://{param}") - name: Optional name for the resource - title: Optional human-readable title for the resource - description: Optional description of the resource - mime_type: Optional MIME type for the resource - - Example: - @server.resource("resource://my-resource") - def get_data() -> str: - return "Hello, world!" - - @server.resource("resource://my-resource") - async get_data() -> str: - data = await fetch_data() - return f"Hello, world! 
{data}" - - @server.resource("resource://{city}/weather") - def get_weather(city: str) -> str: - return f"Weather for {city}" - - @server.resource("resource://{city}/weather") - async def get_weather(city: str) -> str: - data = await fetch_weather(city) - return f"Weather for {city}: {data}" - """ - # Check if user passed function directly instead of calling decorator - if callable(uri): - raise TypeError( - "The @resource decorator was used incorrectly. " - "Did you forget to call it? Use @resource('uri') instead of @resource" - ) - - def decorator(fn: AnyFunction) -> AnyFunction: - # Check if this should be a template - has_uri_params = "{" in uri and "}" in uri - has_func_params = bool(inspect.signature(fn).parameters) - - if has_uri_params or has_func_params: - # Validate that URI params match function params - uri_params = set(re.findall(r"{(\w+)}", uri)) - func_params = set(inspect.signature(fn).parameters.keys()) - - if uri_params != func_params: - raise ValueError( - f"Mismatch between URI parameters {uri_params} and function parameters {func_params}" - ) - - # Register as template - self._resource_manager.add_template( - fn=fn, - uri_template=uri, - name=name, - title=title, - description=description, - mime_type=mime_type, - ) - else: - # Register as regular resource - resource = FunctionResource.from_function( - fn=fn, - uri=uri, - name=name, - title=title, - description=description, - mime_type=mime_type, - ) - self.add_resource(resource) - return fn - - return decorator - - def add_prompt(self, prompt: Prompt) -> None: - """Add a prompt to the server. - - Args: - prompt: A Prompt instance to add - """ - self._prompt_manager.add_prompt(prompt) - - def prompt( - self, name: str | None = None, title: str | None = None, description: str | None = None - ) -> Callable[[AnyFunction], AnyFunction]: - """Decorator to register a prompt. 
- - Args: - name: Optional name for the prompt (defaults to function name) - title: Optional human-readable title for the prompt - description: Optional description of what the prompt does - - Example: - @server.prompt() - def analyze_table(table_name: str) -> list[Message]: - schema = read_table_schema(table_name) - return [ - { - "role": "user", - "content": f"Analyze this schema:\n{schema}" - } - ] - - @server.prompt() - async def analyze_file(path: str) -> list[Message]: - content = await read_file(path) - return [ - { - "role": "user", - "content": { - "type": "resource", - "resource": { - "uri": f"file://{path}", - "text": content - } - } - } - ] - """ - # Check if user passed function directly instead of calling decorator - if callable(name): - raise TypeError( - "The @prompt decorator was used incorrectly. " - "Did you forget to call it? Use @prompt() instead of @prompt" - ) - - def decorator(func: AnyFunction) -> AnyFunction: - prompt = Prompt.from_function(func, name=name, title=title, description=description) - self.add_prompt(prompt) - return func - - return decorator - - def custom_route( - self, - path: str, - methods: list[str], - name: str | None = None, - include_in_schema: bool = True, - ): - """ - Decorator to register a custom HTTP route on the FastMCP server. - - Allows adding arbitrary HTTP endpoints outside the standard MCP protocol, - which can be useful for OAuth callbacks, health checks, or admin APIs. - The handler function must be an async function that accepts a Starlette - Request and returns a Response. 
- - Args: - path: URL path for the route (e.g., "/oauth/callback") - methods: List of HTTP methods to support (e.g., ["GET", "POST"]) - name: Optional name for the route (to reference this route with - Starlette's reverse URL lookup feature) - include_in_schema: Whether to include in OpenAPI schema, defaults to True - - Example: - @server.custom_route("/health", methods=["GET"]) - async def health_check(request: Request) -> Response: - return JSONResponse({"status": "ok"}) - """ - - def decorator( - func: Callable[[Request], Awaitable[Response]], - ) -> Callable[[Request], Awaitable[Response]]: - self._custom_starlette_routes.append( - Route( - path, - endpoint=func, - methods=methods, - name=name, - include_in_schema=include_in_schema, - ) - ) - return func - - return decorator - - async def run_stdio_async(self) -> None: - """Run the server using stdio transport.""" - async with stdio_server() as (read_stream, write_stream): - await self._mcp_server.run( - read_stream, - write_stream, - self._mcp_server.create_initialization_options(), - ) - - async def run_sse_async(self, mount_path: str | None = None) -> None: - """Run the server using SSE transport.""" - import uvicorn - - starlette_app = self.sse_app(mount_path) - - config = uvicorn.Config( - starlette_app, - host=self.settings.host, - port=self.settings.port, - log_level=self.settings.log_level.lower(), - ) - server = uvicorn.Server(config) - await server.serve() - - async def run_streamable_http_async(self) -> None: - """Run the server using StreamableHTTP transport.""" - import uvicorn - - starlette_app = self.streamable_http_app() - - config = uvicorn.Config( - starlette_app, - host=self.settings.host, - port=self.settings.port, - log_level=self.settings.log_level.lower(), - ) - server = uvicorn.Server(config) - await server.serve() - - def _normalize_path(self, mount_path: str, endpoint: str) -> str: - """ - Combine mount path and endpoint to return a normalized path. 
- - Args: - mount_path: The mount path (e.g. "/github" or "/") - endpoint: The endpoint path (e.g. "/messages/") - - Returns: - Normalized path (e.g. "/github/messages/") - """ - # Special case: root path - if mount_path == "/": - return endpoint - - # Remove trailing slash from mount path - if mount_path.endswith("/"): - mount_path = mount_path[:-1] - - # Ensure endpoint starts with slash - if not endpoint.startswith("/"): - endpoint = "/" + endpoint - - # Combine paths - return mount_path + endpoint - - def sse_app(self, mount_path: str | None = None) -> Starlette: - """Return an instance of the SSE server app.""" - from starlette.middleware import Middleware - from starlette.routing import Mount, Route - - # Update mount_path in settings if provided - if mount_path is not None: - self.settings.mount_path = mount_path - - # Create normalized endpoint considering the mount path - normalized_message_endpoint = self._normalize_path(self.settings.mount_path, self.settings.message_path) - - # Set up auth context and dependencies - - sse = SseServerTransport( - normalized_message_endpoint, - security_settings=self.settings.transport_security, - ) - - async def handle_sse(scope: Scope, receive: Receive, send: Send): - # Add client ID from auth context into request context if available - - async with sse.connect_sse( - scope, - receive, - send, - ) as streams: - await self._mcp_server.run( - streams[0], - streams[1], - self._mcp_server.create_initialization_options(), - ) - return Response() - - # Create routes - routes: list[Route | Mount] = [] - middleware: list[Middleware] = [] - required_scopes = [] - - # Set up auth if configured - if self.settings.auth: - required_scopes = self.settings.auth.required_scopes or [] - - # Add auth middleware if token verifier is available - if self._token_verifier: - middleware = [ - # extract auth info from request (but do not require it) - Middleware( - AuthenticationMiddleware, - backend=BearerAuthBackend(self._token_verifier), - 
), - # Add the auth context middleware to store - # authenticated user in a contextvar - Middleware(AuthContextMiddleware), - ] - - # Add auth endpoints if auth server provider is configured - if self._auth_server_provider: - from mcp.server.auth.routes import create_auth_routes - - routes.extend( - create_auth_routes( - provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, - ) - ) - - # When auth is configured, require authentication - if self._token_verifier: - # Determine resource metadata URL - resource_metadata_url = None - if self.settings.auth and self.settings.auth.resource_server_url: - from pydantic import AnyHttpUrl - - resource_metadata_url = AnyHttpUrl( - str(self.settings.auth.resource_server_url).rstrip("/") + "/.well-known/oauth-protected-resource" - ) - - # Auth is enabled, wrap the endpoints with RequireAuthMiddleware - routes.append( - Route( - self.settings.sse_path, - endpoint=RequireAuthMiddleware(handle_sse, required_scopes, resource_metadata_url), - methods=["GET"], - ) - ) - routes.append( - Mount( - self.settings.message_path, - app=RequireAuthMiddleware(sse.handle_post_message, required_scopes, resource_metadata_url), - ) - ) - else: - # Auth is disabled, no need for RequireAuthMiddleware - # Since handle_sse is an ASGI app, we need to create a compatible endpoint - async def sse_endpoint(request: Request) -> Response: - # Convert the Starlette request to ASGI parameters - return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] - - routes.append( - Route( - self.settings.sse_path, - endpoint=sse_endpoint, - methods=["GET"], - ) - ) - routes.append( - Mount( - self.settings.message_path, - app=sse.handle_post_message, - ) - ) - # Add protected resource 
metadata endpoint if configured as RS - if self.settings.auth and self.settings.auth.resource_server_url: - from mcp.server.auth.routes import create_protected_resource_routes - - routes.extend( - create_protected_resource_routes( - resource_url=self.settings.auth.resource_server_url, - authorization_servers=[self.settings.auth.issuer_url], - scopes_supported=self.settings.auth.required_scopes, - ) - ) - - # mount these routes last, so they have the lowest route matching precedence - routes.extend(self._custom_starlette_routes) - - # Create Starlette app with routes and middleware - return Starlette(debug=self.settings.debug, routes=routes, middleware=middleware) - - def streamable_http_app(self) -> Starlette: - """Return an instance of the StreamableHTTP server app.""" - from starlette.middleware import Middleware - - # Create session manager on first call (lazy initialization) - if self._session_manager is None: - self._session_manager = StreamableHTTPSessionManager( - app=self._mcp_server, - event_store=self._event_store, - json_response=self.settings.json_response, - stateless=self.settings.stateless_http, # Use the stateless setting - security_settings=self.settings.transport_security, - ) - - # Create the ASGI handler - streamable_http_app = StreamableHTTPASGIApp(self._session_manager) - - # Create routes - routes: list[Route | Mount] = [] - middleware: list[Middleware] = [] - required_scopes = [] - - # Set up auth if configured - if self.settings.auth: - required_scopes = self.settings.auth.required_scopes or [] - - # Add auth middleware if token verifier is available - if self._token_verifier: - middleware = [ - Middleware( - AuthenticationMiddleware, - backend=BearerAuthBackend(self._token_verifier), - ), - Middleware(AuthContextMiddleware), - ] - - # Add auth endpoints if auth server provider is configured - if self._auth_server_provider: - from mcp.server.auth.routes import create_auth_routes - - routes.extend( - create_auth_routes( - 
provider=self._auth_server_provider, - issuer_url=self.settings.auth.issuer_url, - service_documentation_url=self.settings.auth.service_documentation_url, - client_registration_options=self.settings.auth.client_registration_options, - revocation_options=self.settings.auth.revocation_options, - ) - ) - - # Set up routes with or without auth - if self._token_verifier: - # Determine resource metadata URL - resource_metadata_url = None - if self.settings.auth and self.settings.auth.resource_server_url: - from pydantic import AnyHttpUrl - - resource_metadata_url = AnyHttpUrl( - str(self.settings.auth.resource_server_url).rstrip("/") + "/.well-known/oauth-protected-resource" - ) - - routes.append( - Route( - self.settings.streamable_http_path, - endpoint=RequireAuthMiddleware(streamable_http_app, required_scopes, resource_metadata_url), - ) - ) - else: - # Auth is disabled, no wrapper needed - routes.append( - Route( - self.settings.streamable_http_path, - endpoint=streamable_http_app, - ) - ) - - # Add protected resource metadata endpoint if configured as RS - if self.settings.auth and self.settings.auth.resource_server_url: - from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler - from mcp.server.auth.routes import cors_middleware - from mcp.shared.auth import ProtectedResourceMetadata - - protected_resource_metadata = ProtectedResourceMetadata( - resource=self.settings.auth.resource_server_url, - authorization_servers=[self.settings.auth.issuer_url], - scopes_supported=self.settings.auth.required_scopes, - ) - routes.append( - Route( - "/.well-known/oauth-protected-resource", - endpoint=cors_middleware( - ProtectedResourceMetadataHandler(protected_resource_metadata).handle, - ["GET", "OPTIONS"], - ), - methods=["GET", "OPTIONS"], - ) - ) - - routes.extend(self._custom_starlette_routes) - - return Starlette( - debug=self.settings.debug, - routes=routes, - middleware=middleware, - lifespan=lambda app: self.session_manager.run(), - ) - - async 
def list_prompts(self) -> list[MCPPrompt]: - """List all available prompts.""" - prompts = self._prompt_manager.list_prompts() - return [ - MCPPrompt( - name=prompt.name, - title=prompt.title, - description=prompt.description, - arguments=[ - MCPPromptArgument( - name=arg.name, - description=arg.description, - required=arg.required, - ) - for arg in (prompt.arguments or []) - ], - ) - for prompt in prompts - ] - - async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> GetPromptResult: - """Get a prompt by name with arguments.""" - try: - prompt = self._prompt_manager.get_prompt(name) - if not prompt: - raise ValueError(f"Unknown prompt: {name}") - - messages = await prompt.render(arguments) - - return GetPromptResult( - description=prompt.description, - messages=pydantic_core.to_jsonable_python(messages), - ) - except Exception as e: - logger.exception(f"Error getting prompt {name}") - raise ValueError(str(e)) - - -class StreamableHTTPASGIApp: - """ - ASGI application for Streamable HTTP server transport. - """ - - def __init__(self, session_manager: StreamableHTTPSessionManager): - self.session_manager = session_manager - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - await self.session_manager.handle_request(scope, receive, send) - - -class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]): - """Context object providing access to MCP capabilities. - - This provides a cleaner interface to MCP's RequestContext functionality. - It gets injected into tool and resource functions that request it via type hints. 
- - To use context in a tool function, add a parameter with the Context type annotation: - - ```python - @server.tool() - def my_tool(x: int, ctx: Context) -> str: - # Log messages to the client - ctx.info(f"Processing {x}") - ctx.debug("Debug info") - ctx.warning("Warning message") - ctx.error("Error message") - - # Report progress - ctx.report_progress(50, 100) - - # Access resources - data = ctx.read_resource("resource://data") - - # Get request info - request_id = ctx.request_id - client_id = ctx.client_id - - return str(x) - ``` - - The context parameter name can be anything as long as it's annotated with Context. - The context is optional - tools that don't need it can omit the parameter. - """ - - _request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None - _fastmcp: FastMCP | None - - def __init__( - self, - *, - request_context: (RequestContext[ServerSessionT, LifespanContextT, RequestT] | None) = None, - fastmcp: FastMCP | None = None, - **kwargs: Any, - ): - super().__init__(**kwargs) - self._request_context = request_context - self._fastmcp = fastmcp - - @property - def fastmcp(self) -> FastMCP: - """Access to the FastMCP server.""" - if self._fastmcp is None: - raise ValueError("Context is not available outside of a request") - return self._fastmcp - - @property - def request_context( - self, - ) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]: - """Access to the underlying request context.""" - if self._request_context is None: - raise ValueError("Context is not available outside of a request") - return self._request_context - - async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - """Report progress for the current operation. - - Args: - progress: Current progress value e.g. 24 - total: Optional total value e.g. 100 - message: Optional message e.g. Starting render... 
- """ - progress_token = self.request_context.meta.progressToken if self.request_context.meta else None - - if progress_token is None: - return - - await self.request_context.session.send_progress_notification( - progress_token=progress_token, - progress=progress, - total=total, - message=message, - ) - - async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContents]: - """Read a resource by URI. - - Args: - uri: Resource URI to read - - Returns: - The resource content as either text or bytes - """ - assert self._fastmcp is not None, "Context is not available outside of a request" - return await self._fastmcp.read_resource(uri) - - async def elicit( - self, - message: str, - schema: type[ElicitSchemaModelT], - ) -> ElicitationResult[ElicitSchemaModelT]: - """Elicit information from the client/user. - - This method can be used to interactively ask for additional information from the - client within a tool's execution. The client might display the message to the - user and collect a response according to the provided schema. Or in case a - client is an agent, it might decide how to handle the elicitation -- either by asking - the user or automatically generating a response. - - Args: - schema: A Pydantic model class defining the expected response structure, according to the specification, - only primive types are allowed. - message: Optional message to present to the user. If not provided, will use - a default message based on the schema - - Returns: - An ElicitationResult containing the action taken and the data if accepted - - Note: - Check the result.action to determine if the user accepted, declined, or cancelled. - The result.data will only be populated if action is "accept" and validation succeeded. 
- """ - - return await elicit_with_validation( - session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id - ) - - async def log( - self, - level: Literal["debug", "info", "warning", "error"], - message: str, - *, - logger_name: str | None = None, - ) -> None: - """Send a log message to the client. - - Args: - level: Log level (debug, info, warning, error) - message: Log message - logger_name: Optional logger name - **extra: Additional structured data to include - """ - await self.request_context.session.send_log_message( - level=level, - data=message, - logger=logger_name, - related_request_id=self.request_id, - ) - - @property - def client_id(self) -> str | None: - """Get the client ID if available.""" - return getattr(self.request_context.meta, "client_id", None) if self.request_context.meta else None - - @property - def request_id(self) -> str: - """Get the unique ID for this request.""" - return str(self.request_context.request_id) - - @property - def session(self): - """Access to the underlying session for advanced usage.""" - return self.request_context.session - - # Convenience methods for common log levels - async def debug(self, message: str, **extra: Any) -> None: - """Send a debug log message.""" - await self.log("debug", message, **extra) - - async def info(self, message: str, **extra: Any) -> None: - """Send an info log message.""" - await self.log("info", message, **extra) - - async def warning(self, message: str, **extra: Any) -> None: - """Send a warning log message.""" - await self.log("warning", message, **extra) - - async def error(self, message: str, **extra: Any) -> None: - """Send an error log message.""" - await self.log("error", message, **extra) diff --git a/src/mcp/mcp/server/fastmcp/tools/__init__.py b/src/mcp/mcp/server/fastmcp/tools/__init__.py deleted file mode 100644 index ae9c65619..000000000 --- a/src/mcp/mcp/server/fastmcp/tools/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .base 
import Tool -from .tool_manager import ToolManager - -__all__ = ["Tool", "ToolManager"] diff --git a/src/mcp/mcp/server/fastmcp/tools/base.py b/src/mcp/mcp/server/fastmcp/tools/base.py deleted file mode 100644 index f50126081..000000000 --- a/src/mcp/mcp/server/fastmcp/tools/base.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations as _annotations - -import functools -import inspect -from collections.abc import Callable -from functools import cached_property -from typing import TYPE_CHECKING, Any, get_origin - -from pydantic import BaseModel, Field - -from mcp.server.fastmcp.exceptions import ToolError -from mcp.server.fastmcp.utilities.func_metadata import FuncMetadata, func_metadata -from mcp.types import ToolAnnotations - -if TYPE_CHECKING: - from mcp.server.fastmcp.server import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT, RequestT - - -class Tool(BaseModel): - """Internal tool registration info.""" - - fn: Callable[..., Any] = Field(exclude=True) - name: str = Field(description="Name of the tool") - title: str | None = Field(None, description="Human-readable title of the tool") - description: str = Field(description="Description of what the tool does") - parameters: dict[str, Any] = Field(description="JSON schema for tool parameters") - fn_metadata: FuncMetadata = Field( - description="Metadata about the function including a pydantic model for tool arguments" - ) - is_async: bool = Field(description="Whether the tool is async") - context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context") - annotations: ToolAnnotations | None = Field(None, description="Optional annotations for the tool") - - @cached_property - def output_schema(self) -> dict[str, Any] | None: - return self.fn_metadata.output_schema - - @classmethod - def from_function( - cls, - fn: Callable[..., Any], - name: str | None = None, - title: str | None = None, - description: str 
| None = None, - context_kwarg: str | None = None, - annotations: ToolAnnotations | None = None, - structured_output: bool | None = None, - ) -> Tool: - """Create a Tool from a function.""" - from mcp.server.fastmcp.server import Context - - func_name = name or fn.__name__ - - if func_name == "": - raise ValueError("You must provide a name for lambda functions") - - func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) - - if context_kwarg is None: - sig = inspect.signature(fn) - for param_name, param in sig.parameters.items(): - if get_origin(param.annotation) is not None: - continue - if issubclass(param.annotation, Context): - context_kwarg = param_name - break - - func_arg_metadata = func_metadata( - fn, - skip_names=[context_kwarg] if context_kwarg is not None else [], - structured_output=structured_output, - ) - parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) - - return cls( - fn=fn, - name=func_name, - title=title, - description=func_doc, - parameters=parameters, - fn_metadata=func_arg_metadata, - is_async=is_async, - context_kwarg=context_kwarg, - annotations=annotations, - ) - - async def run( - self, - arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, - convert_result: bool = False, - ) -> Any: - """Run the tool with arguments.""" - try: - result = await self.fn_metadata.call_fn_with_arg_validation( - self.fn, - self.is_async, - arguments, - {self.context_kwarg: context} if self.context_kwarg is not None else None, - ) - - if convert_result: - result = self.fn_metadata.convert_result(result) - - return result - except Exception as e: - raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git 
a/src/mcp/mcp/server/fastmcp/tools/tool_manager.py b/src/mcp/mcp/server/fastmcp/tools/tool_manager.py deleted file mode 100644 index bfa8b2382..000000000 --- a/src/mcp/mcp/server/fastmcp/tools/tool_manager.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import annotations as _annotations - -from collections.abc import Callable -from typing import TYPE_CHECKING, Any - -from mcp.server.fastmcp.exceptions import ToolError -from mcp.server.fastmcp.tools.base import Tool -from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.shared.context import LifespanContextT, RequestT -from mcp.types import ToolAnnotations - -if TYPE_CHECKING: - from mcp.server.fastmcp.server import Context - from mcp.server.session import ServerSessionT - -logger = get_logger(__name__) - - -class ToolManager: - """Manages FastMCP tools.""" - - def __init__( - self, - warn_on_duplicate_tools: bool = True, - *, - tools: list[Tool] | None = None, - ): - self._tools: dict[str, Tool] = {} - if tools is not None: - for tool in tools: - if warn_on_duplicate_tools and tool.name in self._tools: - logger.warning(f"Tool already exists: {tool.name}") - self._tools[tool.name] = tool - - self.warn_on_duplicate_tools = warn_on_duplicate_tools - - def get_tool(self, name: str) -> Tool | None: - """Get tool by name.""" - return self._tools.get(name) - - def list_tools(self) -> list[Tool]: - """List all registered tools.""" - return list(self._tools.values()) - - def add_tool( - self, - fn: Callable[..., Any], - name: str | None = None, - title: str | None = None, - description: str | None = None, - annotations: ToolAnnotations | None = None, - structured_output: bool | None = None, - ) -> Tool: - """Add a tool to the server.""" - tool = Tool.from_function( - fn, - name=name, - title=title, - description=description, - annotations=annotations, - structured_output=structured_output, - ) - existing = self._tools.get(tool.name) - if existing: - if self.warn_on_duplicate_tools: - 
logger.warning(f"Tool already exists: {tool.name}") - return existing - self._tools[tool.name] = tool - return tool - - async def call_tool( - self, - name: str, - arguments: dict[str, Any], - context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None, - convert_result: bool = False, - ) -> Any: - """Call a tool by name with arguments.""" - tool = self.get_tool(name) - if not tool: - raise ToolError(f"Unknown tool: {name}") - - return await tool.run(arguments, context=context, convert_result=convert_result) diff --git a/src/mcp/mcp/server/fastmcp/utilities/__init__.py b/src/mcp/mcp/server/fastmcp/utilities/__init__.py deleted file mode 100644 index be448f97a..000000000 --- a/src/mcp/mcp/server/fastmcp/utilities/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""FastMCP utility modules.""" diff --git a/src/mcp/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/mcp/server/fastmcp/utilities/func_metadata.py deleted file mode 100644 index a4cb8ac5b..000000000 --- a/src/mcp/mcp/server/fastmcp/utilities/func_metadata.py +++ /dev/null @@ -1,523 +0,0 @@ -import inspect -import json -from collections.abc import Awaitable, Callable, Sequence -from itertools import chain -from types import GenericAlias -from typing import Annotated, Any, ForwardRef, cast, get_args, get_origin, get_type_hints - -import pydantic_core -from pydantic import ( - BaseModel, - ConfigDict, - Field, - RootModel, - WithJsonSchema, - create_model, -) -from pydantic._internal._typing_extra import eval_type_backport -from pydantic.fields import FieldInfo -from pydantic.json_schema import GenerateJsonSchema, JsonSchemaWarningKind -from pydantic_core import PydanticUndefined - -from mcp.server.fastmcp.exceptions import InvalidSignature -from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.server.fastmcp.utilities.types import Audio, Image -from mcp.types import ContentBlock, TextContent - -logger = get_logger(__name__) - - -class StrictJsonSchema(GenerateJsonSchema): - """A 
class StrictJsonSchema(GenerateJsonSchema):
    """A JSON schema generator that raises exceptions instead of emitting warnings.

    This is used to detect non-serializable types during schema generation.
    """

    def emit_warning(self, kind: JsonSchemaWarningKind, detail: str) -> None:
        # Raise an exception instead of emitting a warning
        raise ValueError(f"JSON schema warning: {kind} - {detail}")


class ArgModelBase(BaseModel):
    """A model representing the arguments to a function."""

    def model_dump_one_level(self) -> dict[str, Any]:
        """Return a dict of the model's fields, one level deep.

        That is, sub-models etc are not dumped - they are kept as pydantic models.
        """
        kwargs: dict[str, Any] = {}
        for field_name, field_info in self.__class__.model_fields.items():
            value = getattr(self, field_name)
            # Use the alias if it exists, otherwise use the field name
            output_name = field_info.alias if field_info.alias else field_name
            kwargs[output_name] = value
        return kwargs

    # Tool parameters may use arbitrary (non-pydantic) Python types.
    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )


class FuncMetadata(BaseModel):
    """Validation/conversion metadata derived from a tool function's signature."""

    # Pydantic model mirroring the function's parameters (WithJsonSchema(None)
    # keeps the model type itself out of generated schemas).
    arg_model: Annotated[type[ArgModelBase], WithJsonSchema(None)]
    # JSON schema of the structured output; None for unstructured tools.
    output_schema: dict[str, Any] | None = None
    # Pydantic model used to validate/serialize structured output.
    output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
    # True when the raw return value must be wrapped as {"result": value}.
    wrap_output: bool = False

    async def call_fn_with_arg_validation(
        self,
        fn: Callable[..., Any | Awaitable[Any]],
        fn_is_async: bool,
        arguments_to_validate: dict[str, Any],
        arguments_to_pass_directly: dict[str, Any] | None,
    ) -> Any:
        """Call the given function with arguments validated and injected.

        Arguments are first attempted to be parsed from JSON, then validated against
        the argument model, before being passed to the function.
        """
        arguments_pre_parsed = self.pre_parse_json(arguments_to_validate)
        arguments_parsed_model = self.arg_model.model_validate(arguments_pre_parsed)
        arguments_parsed_dict = arguments_parsed_model.model_dump_one_level()

        # Directly-passed arguments (e.g. the injected Context) bypass validation.
        arguments_parsed_dict |= arguments_to_pass_directly or {}

        if fn_is_async:
            return await fn(**arguments_parsed_dict)
        else:
            return fn(**arguments_parsed_dict)

    def convert_result(self, result: Any) -> Any:
        """
        Convert the result of a function call to the appropriate format for
        the lowlevel server tool call handler:

        - If output_model is None, return the unstructured content directly.
        - If output_model is not None, convert the result to structured output format
        (dict[str, Any]) and return both unstructured and structured content.

        Note: we return unstructured content here **even though the lowlevel server
        tool call handler provides generic backwards compatibility serialization of
        structured content**. This is for FastMCP backwards compatibility: we need to
        retain FastMCP's ad hoc conversion logic for constructing unstructured output
        from function return values, whereas the lowlevel server simply serializes
        the structured output.
        """
        unstructured_content = _convert_to_content(result)

        if self.output_schema is None:
            return unstructured_content
        else:
            if self.wrap_output:
                result = {"result": result}

            assert self.output_model is not None, "Output model must be set if output schema is defined"
            validated = self.output_model.model_validate(result)
            structured_content = validated.model_dump(mode="json", by_alias=True)

            return (unstructured_content, structured_content)

    def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
        """Pre-parse data from JSON.

        Return a dict with same keys as input but with values parsed from JSON
        if appropriate.

        This is to handle cases like `["a", "b", "c"]` being passed in as JSON inside
        a string rather than an actual list. Claude desktop is prone to this - in fact
        it seems incapable of NOT doing this. For sub-models, it tends to pass
        dicts (JSON objects) as JSON strings, which can be pre-parsed here.
        """
        new_data = data.copy()  # Shallow copy

        # Build a mapping from input keys (including aliases) to field info
        key_to_field_info: dict[str, FieldInfo] = {}
        for field_name, field_info in self.arg_model.model_fields.items():
            # Map both the field name and its alias (if any) to the field info
            key_to_field_info[field_name] = field_info
            if field_info.alias:
                key_to_field_info[field_info.alias] = field_info

        for data_key in data.keys():
            if data_key not in key_to_field_info:
                continue

            field_info = key_to_field_info[data_key]
            # Only attempt JSON pre-parsing for string values whose declared
            # type is not str (a str field legitimately contains raw text).
            if isinstance(data[data_key], str) and field_info.annotation is not str:
                try:
                    pre_parsed = json.loads(data[data_key])
                except json.JSONDecodeError:
                    continue  # Not JSON - skip
                if isinstance(pre_parsed, str | int | float):
                    # This is likely that the raw value is e.g. `"hello"` which we
                    # Should really be parsed as '"hello"' in Python - but if we parse
                    # it as JSON it'll turn into just 'hello'. So we skip it.
                    continue
                new_data[data_key] = pre_parsed
        # Pre-parsing must never add or drop keys, only re-type values.
        assert new_data.keys() == data.keys()
        return new_data

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )
def func_metadata(
    func: Callable[..., Any],
    skip_names: Sequence[str] = (),
    structured_output: bool | None = None,
) -> FuncMetadata:
    """Given a function, return metadata including a pydantic model representing its
    signature.

    The use case for this is
    ```
    meta = func_metadata(func)
    validated_args = meta.arg_model.model_validate(some_raw_data_dict)
    return func(**validated_args.model_dump_one_level())
    ```

    **critically** it also provides pre-parse helper to attempt to parse things from
    JSON.

    Args:
        func: The function to convert to a pydantic model
        skip_names: A list of parameter names to skip. These will not be included in
            the model.
        structured_output: Controls whether the tool's output is structured or unstructured

            - If None, auto-detects based on the function's return type annotation
            - If True, unconditionally creates a structured tool (return type annotation permitting)
            - If False, unconditionally creates an unstructured tool

        If structured, creates a Pydantic model for the function's result based on its annotation.
        Supports various return types:
        - BaseModel subclasses (used directly)
        - Primitive types (str, int, float, bool, bytes, None) - wrapped in a
            model with a 'result' field
        - TypedDict - converted to a Pydantic model with same fields
        - Dataclasses and other annotated classes - converted to Pydantic models
        - Generic types (list, dict, Union, etc.) - wrapped in a model with a 'result' field

    Returns:
        A FuncMetadata object containing:
        - arg_model: A pydantic model representing the function's arguments
        - output_model: A pydantic model for the return type if output is structured
        - output_schema: JSON schema for the structured output, if any
        - wrap_output: whether results must be wrapped as {"result": ...}
    """
    sig = _get_typed_signature(func)
    params = sig.parameters
    dynamic_pydantic_model_params: dict[str, Any] = {}
    globalns = getattr(func, "__globals__", {})
    for param in params.values():
        if param.name.startswith("_"):
            raise InvalidSignature(f"Parameter {param.name} of {func.__name__} cannot start with '_'")
        if param.name in skip_names:
            continue
        annotation = param.annotation

        # `x: None` / `x: None = None`
        if annotation is None:
            annotation = Annotated[
                None,
                Field(default=param.default if param.default is not inspect.Parameter.empty else PydanticUndefined),
            ]

        # Untyped field
        if annotation is inspect.Parameter.empty:
            annotation = Annotated[
                Any,
                Field(),
                # 🤷
                WithJsonSchema({"title": param.name, "type": "string"}),
            ]

        field_info = FieldInfo.from_annotated_attribute(
            _get_typed_annotation(annotation, globalns),
            param.default if param.default is not inspect.Parameter.empty else PydanticUndefined,
        )

        # Check if the parameter name conflicts with BaseModel attributes
        # This is necessary because Pydantic warns about shadowing parent attributes
        if hasattr(BaseModel, param.name) and callable(getattr(BaseModel, param.name)):
            # Use an alias to avoid the shadowing warning
            field_info.alias = param.name
            field_info.validation_alias = param.name
            field_info.serialization_alias = param.name
            # Use a prefixed internal name
            internal_name = f"field_{param.name}"
            dynamic_pydantic_model_params[internal_name] = (field_info.annotation, field_info)
        else:
            dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info)
        continue  # no-op: end of loop body either way

    arguments_model = create_model(
        f"{func.__name__}Arguments",
        **dynamic_pydantic_model_params,
        __base__=ArgModelBase,
    )

    if structured_output is False:
        return FuncMetadata(arg_model=arguments_model)

    # set up structured output support based on return type annotation

    if sig.return_annotation is inspect.Parameter.empty and structured_output is True:
        raise InvalidSignature(f"Function {func.__name__}: return annotation required for structured output")

    output_info = FieldInfo.from_annotation(_get_typed_annotation(sig.return_annotation, globalns))
    annotation = output_info.annotation

    output_model, output_schema, wrap_output = _try_create_model_and_schema(annotation, func.__name__, output_info)

    if output_model is None and structured_output is True:
        # Model creation failed or produced warnings - no structured output
        raise InvalidSignature(
            f"Function {func.__name__}: return type {annotation} is not serializable for structured output"
        )

    return FuncMetadata(
        arg_model=arguments_model,
        output_schema=output_schema,
        output_model=output_model,
        wrap_output=wrap_output,
    )


def _try_create_model_and_schema(
    annotation: Any, func_name: str, field_info: FieldInfo
) -> tuple[type[BaseModel] | None, dict[str, Any] | None, bool]:
    """Try to create a model and schema for the given annotation without warnings.

    Returns:
        tuple of (model or None, schema or None, wrap_output)
        Model and schema are None if warnings occur or creation fails.
        wrap_output is True if the result needs to be wrapped in {"result": ...}
    """
    model = None
    wrap_output = False

    # First handle special case: None
    if annotation is None:
        model = _create_wrapped_model(func_name, annotation, field_info)
        wrap_output = True

    # Handle GenericAlias types (list[str], dict[str, int], Union[str, int], etc.)
    elif isinstance(annotation, GenericAlias):
        origin = get_origin(annotation)

        # Special case: dict with string keys can use RootModel
        if origin is dict:
            args = get_args(annotation)
            if len(args) == 2 and args[0] is str:
                model = _create_dict_model(func_name, annotation)
            else:
                # dict with non-str keys needs wrapping
                model = _create_wrapped_model(func_name, annotation, field_info)
                wrap_output = True
        else:
            # All other generic types need wrapping (list, tuple, Union, Optional, etc.)
            model = _create_wrapped_model(func_name, annotation, field_info)
            wrap_output = True

    # Handle regular type objects
    elif isinstance(annotation, type):
        type_annotation: type[Any] = cast(type[Any], annotation)

        # Case 1: BaseModel subclasses (can be used directly)
        if issubclass(annotation, BaseModel):
            model = annotation

        # Case 2: TypedDict (special dict subclass with __annotations__)
        elif hasattr(type_annotation, "__annotations__") and issubclass(annotation, dict):
            model = _create_model_from_typeddict(type_annotation)

        # Case 3: Primitive types that need wrapping
        elif annotation in (str, int, float, bool, bytes, type(None)):
            model = _create_wrapped_model(func_name, annotation, field_info)
            wrap_output = True

        # Case 4: Other class types (dataclasses, regular classes with annotations)
        else:
            type_hints = get_type_hints(type_annotation)
            if type_hints:
                # Classes with type hints can be converted to Pydantic models
                model = _create_model_from_class(type_annotation)
            # Classes without type hints are not serializable - model remains None

    # Handle any other types not covered above
    else:
        # This includes typing constructs that aren't GenericAlias in Python 3.10
        # (e.g., Union, Optional in some Python versions)
        model = _create_wrapped_model(func_name, annotation, field_info)
        wrap_output = True

    if model:
        # If we successfully created a model, try to get its schema
        # Use StrictJsonSchema to raise exceptions instead of warnings
        try:
            schema = model.model_json_schema(schema_generator=StrictJsonSchema)
        except (TypeError, ValueError, pydantic_core.SchemaError, pydantic_core.ValidationError) as e:
            # These are expected errors when a type can't be converted to a Pydantic schema
            # TypeError: When Pydantic can't handle the type
            # ValueError: When there are issues with the type definition (including our custom warnings)
            # SchemaError: When Pydantic can't build a schema
            # ValidationError: When validation fails
            logger.info(f"Cannot create schema for type {annotation} in {func_name}: {type(e).__name__}: {e}")
            return None, None, False

        return model, schema, wrap_output

    return None, None, False


def _create_model_from_class(cls: type[Any]) -> type[BaseModel]:
    """Create a Pydantic model from an ordinary class.

    The created model will:
    - Have the same name as the class
    - Have fields with the same names and types as the class's fields
    - Include all fields whose type does not include None in the set of required fields

    Precondition: cls must have type hints (i.e., get_type_hints(cls) is non-empty)
    """
    type_hints = get_type_hints(cls)

    model_fields: dict[str, Any] = {}
    for field_name, field_type in type_hints.items():
        if field_name.startswith("_"):
            continue

        # Class-level attribute values become field defaults when present.
        default = getattr(cls, field_name, PydanticUndefined)
        field_info = FieldInfo.from_annotated_attribute(field_type, default)
        model_fields[field_name] = (field_info.annotation, field_info)

    # Create a base class with the config
    class BaseWithConfig(BaseModel):
        model_config = ConfigDict(from_attributes=True)

    return create_model(cls.__name__, **model_fields, __base__=BaseWithConfig)
- """ - type_hints = get_type_hints(td_type) - required_keys = getattr(td_type, "__required_keys__", set(type_hints.keys())) - - model_fields: dict[str, Any] = {} - for field_name, field_type in type_hints.items(): - field_info = FieldInfo.from_annotation(field_type) - - if field_name not in required_keys: - # For optional TypedDict fields, set default=None - # This makes them not required in the Pydantic model - # The model should use exclude_unset=True when dumping to get TypedDict semantics - field_info.default = None - - model_fields[field_name] = (field_info.annotation, field_info) - - return create_model(td_type.__name__, **model_fields, __base__=BaseModel) - - -def _create_wrapped_model(func_name: str, annotation: Any, field_info: FieldInfo) -> type[BaseModel]: - """Create a model that wraps a type in a 'result' field. - - This is used for primitive types, generic types like list/dict, etc. - """ - model_name = f"{func_name}Output" - - # Pydantic needs type(None) instead of None for the type annotation - if annotation is None: - annotation = type(None) - - return create_model(model_name, result=(annotation, field_info), __base__=BaseModel) - - -def _create_dict_model(func_name: str, dict_annotation: Any) -> type[BaseModel]: - """Create a RootModel for dict[str, T] types.""" - - class DictModel(RootModel[dict_annotation]): - pass - - # Give it a meaningful name - DictModel.__name__ = f"{func_name}DictOutput" - DictModel.__qualname__ = f"{func_name}DictOutput" - - return DictModel - - -def _get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any: - def try_eval_type(value: Any, globalns: dict[str, Any], localns: dict[str, Any]) -> tuple[Any, bool]: - try: - return eval_type_backport(value, globalns, localns), True - except NameError: - return value, False - - if isinstance(annotation, str): - annotation = ForwardRef(annotation) - annotation, status = try_eval_type(annotation, globalns, globalns) - - # This check and raise could perhaps be 
skipped, and we (FastMCP) just call - # model_rebuild right before using it 🤷 - if status is False: - raise InvalidSignature(f"Unable to evaluate type annotation {annotation}") - - return annotation - - -def _get_typed_signature(call: Callable[..., Any]) -> inspect.Signature: - """Get function signature while evaluating forward references""" - signature = inspect.signature(call) - globalns = getattr(call, "__globals__", {}) - typed_params = [ - inspect.Parameter( - name=param.name, - kind=param.kind, - default=param.default, - annotation=_get_typed_annotation(param.annotation, globalns), - ) - for param in signature.parameters.values() - ] - typed_return = _get_typed_annotation(signature.return_annotation, globalns) - typed_signature = inspect.Signature(typed_params, return_annotation=typed_return) - return typed_signature - - -def _convert_to_content( - result: Any, -) -> Sequence[ContentBlock]: - """ - Convert a result to a sequence of content objects. - - Note: This conversion logic comes from previous versions of FastMCP and is being - retained for purposes of backwards compatibility. It produces different unstructured - output than the lowlevel server tool call handler, which just serializes structured - content verbatim. 
- """ - if result is None: - return [] - - if isinstance(result, ContentBlock): - return [result] - - if isinstance(result, Image): - return [result.to_image_content()] - - if isinstance(result, Audio): - return [result.to_audio_content()] - - if isinstance(result, list | tuple): - return list( - chain.from_iterable( - _convert_to_content(item) - for item in result # type: ignore - ) - ) - - if not isinstance(result, str): - result = pydantic_core.to_json(result, fallback=str, indent=2).decode() - - return [TextContent(type="text", text=result)] diff --git a/src/mcp/mcp/server/fastmcp/utilities/logging.py b/src/mcp/mcp/server/fastmcp/utilities/logging.py deleted file mode 100644 index 091d57e69..000000000 --- a/src/mcp/mcp/server/fastmcp/utilities/logging.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Logging utilities for FastMCP.""" - -import logging -from typing import Literal - - -def get_logger(name: str) -> logging.Logger: - """Get a logger nested under MCPnamespace. - - Args: - name: the name of the logger, which will be prefixed with 'FastMCP.' - - Returns: - a configured logger instance - """ - return logging.getLogger(name) - - -def configure_logging( - level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", -) -> None: - """Configure logging for MCP. 
- - Args: - level: the log level to use - """ - handlers: list[logging.Handler] = [] - try: - from rich.console import Console - from rich.logging import RichHandler - - handlers.append(RichHandler(console=Console(stderr=True), rich_tracebacks=True)) - except ImportError: - pass - - if not handlers: - handlers.append(logging.StreamHandler()) - - logging.basicConfig( - level=level, - format="%(message)s", - handlers=handlers, - ) diff --git a/src/mcp/mcp/server/fastmcp/utilities/types.py b/src/mcp/mcp/server/fastmcp/utilities/types.py deleted file mode 100644 index 1be6f8274..000000000 --- a/src/mcp/mcp/server/fastmcp/utilities/types.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Common types used across FastMCP.""" - -import base64 -from pathlib import Path - -from mcp.types import AudioContent, ImageContent - - -class Image: - """Helper class for returning images from tools.""" - - def __init__( - self, - path: str | Path | None = None, - data: bytes | None = None, - format: str | None = None, - ): - if path is None and data is None: - raise ValueError("Either path or data must be provided") - if path is not None and data is not None: - raise ValueError("Only one of path or data can be provided") - - self.path = Path(path) if path else None - self.data = data - self._format = format - self._mime_type = self._get_mime_type() - - def _get_mime_type(self) -> str: - """Get MIME type from format or guess from file extension.""" - if self._format: - return f"image/{self._format.lower()}" - - if self.path: - suffix = self.path.suffix.lower() - return { - ".png": "image/png", - ".jpg": "image/jpeg", - ".jpeg": "image/jpeg", - ".gif": "image/gif", - ".webp": "image/webp", - }.get(suffix, "application/octet-stream") - return "image/png" # default for raw binary data - - def to_image_content(self) -> ImageContent: - """Convert to MCP ImageContent.""" - if self.path: - with open(self.path, "rb") as f: - data = base64.b64encode(f.read()).decode() - elif self.data is not None: - data = 
class Audio:
    """Helper class for returning audio from tools."""

    def __init__(
        self,
        path: str | Path | None = None,
        data: bytes | None = None,
        format: str | None = None,
    ):
        """Hold either a file path or raw audio bytes (exactly one).

        Args:
            path: Path to an audio file on disk.
            data: Raw audio bytes.
            format: Explicit audio format (e.g. "wav"); overrides extension-based
                MIME detection.

        Raises:
            ValueError: if neither or both of path/data are provided.
        """
        # Bug fix: the previous `if not bool(path) ^ bool(data)` check rejected
        # falsy-but-valid values (e.g. data=b"") and produced one misleading
        # message for both failure modes. Use `is None` checks, mirroring the
        # sibling Image class.
        if path is None and data is None:
            raise ValueError("Either path or data must be provided")
        if path is not None and data is not None:
            raise ValueError("Only one of path or data can be provided")

        self.path = Path(path) if path else None
        self.data = data
        self._format = format
        self._mime_type = self._get_mime_type()

    def _get_mime_type(self) -> str:
        """Get MIME type from format or guess from file extension."""
        if self._format:
            return f"audio/{self._format.lower()}"

        if self.path:
            suffix = self.path.suffix.lower()
            return {
                ".wav": "audio/wav",
                ".mp3": "audio/mpeg",
                ".ogg": "audio/ogg",
                ".flac": "audio/flac",
                ".aac": "audio/aac",
                ".m4a": "audio/mp4",
            }.get(suffix, "application/octet-stream")
        return "audio/wav"  # default for raw binary data

    def to_audio_content(self) -> AudioContent:
        """Convert to MCP AudioContent (payload is base64-encoded)."""
        if self.path:
            with open(self.path, "rb") as f:
                data = base64.b64encode(f.read()).decode()
        elif self.data is not None:
            data = base64.b64encode(self.data).decode()
        else:
            raise ValueError("No audio data available")

        return AudioContent(type="audio", data=data, mimeType=self._mime_type)
from dataclasses import dataclass


@dataclass
class ReadResourceContents:
    """Contents returned from a read_resource call."""

    # Raw payload: str for text resources, bytes for binary resources.
    content: str | bytes
    # MIME type of `content`; None when the provider did not specify one.
    mime_type: str | None = None
Run the server: - async def main(): - async with mcp.server.stdio.stdio_server() as (read_stream, write_stream): - await server.run( - read_stream, - write_stream, - InitializationOptions( - server_name="your_server_name", - server_version="your_version", - capabilities=server.get_capabilities( - notification_options=NotificationOptions(), - experimental_capabilities={}, - ), - ), - ) - - asyncio.run(main()) - -The Server class provides methods to register handlers for various MCP requests and -notifications. It automatically manages the request context and handles incoming -messages from the client. -""" - -from __future__ import annotations as _annotations - -import contextvars -import json -import logging -import warnings -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager -from typing import Any, Generic, TypeAlias, cast - -import anyio -import jsonschema -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl -from typing_extensions import TypeVar - -import mcp.types as types -from mcp.server.lowlevel.helper_types import ReadResourceContents -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.server.stdio import stdio_server as stdio_server -from mcp.shared.context import RequestContext -from mcp.shared.exceptions import McpError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import RequestResponder - -logger = logging.getLogger(__name__) - -LifespanResultT = TypeVar("LifespanResultT", default=Any) -RequestT = TypeVar("RequestT", default=Any) - -# type aliases for tool call results -StructuredContent: TypeAlias = dict[str, Any] -UnstructuredContent: TypeAlias = Iterable[types.ContentBlock] -CombinationContent: TypeAlias = tuple[UnstructuredContent, StructuredContent] - -# This will be 
properly typed in each Server instance's context -request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = contextvars.ContextVar("request_ctx") - - -class NotificationOptions: - def __init__( - self, - prompts_changed: bool = False, - resources_changed: bool = False, - tools_changed: bool = False, - ): - self.prompts_changed = prompts_changed - self.resources_changed = resources_changed - self.tools_changed = tools_changed - - -@asynccontextmanager -async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: - """Default lifespan context manager that does nothing. - - Args: - server: The server instance this lifespan is managing - - Returns: - An empty context object - """ - yield {} - - -class Server(Generic[LifespanResultT, RequestT]): - def __init__( - self, - name: str, - version: str | None = None, - instructions: str | None = None, - lifespan: Callable[ - [Server[LifespanResultT, RequestT]], - AbstractAsyncContextManager[LifespanResultT], - ] = lifespan, - ): - self.name = name - self.version = version - self.instructions = instructions - self.lifespan = lifespan - self.request_handlers: dict[type, Callable[..., Awaitable[types.ServerResult]]] = { - types.PingRequest: _ping_handler, - } - self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self._tool_cache: dict[str, types.Tool] = {} - logger.debug("Initializing server %r", name) - - def create_initialization_options( - self, - notification_options: NotificationOptions | None = None, - experimental_capabilities: dict[str, dict[str, Any]] | None = None, - ) -> InitializationOptions: - """Create initialization options from this server instance.""" - - def pkg_version(package: str) -> str: - try: - from importlib.metadata import version - - return version(package) - except Exception: - pass - - return "unknown" - - return InitializationOptions( - server_name=self.name, - server_version=self.version if self.version else 
pkg_version("mcp"), - capabilities=self.get_capabilities( - notification_options or NotificationOptions(), - experimental_capabilities or {}, - ), - instructions=self.instructions, - ) - - def get_capabilities( - self, - notification_options: NotificationOptions, - experimental_capabilities: dict[str, dict[str, Any]], - ) -> types.ServerCapabilities: - """Convert existing handlers to a ServerCapabilities object.""" - prompts_capability = None - resources_capability = None - tools_capability = None - logging_capability = None - completions_capability = None - - # Set prompt capabilities if handler exists - if types.ListPromptsRequest in self.request_handlers: - prompts_capability = types.PromptsCapability(listChanged=notification_options.prompts_changed) - - # Set resource capabilities if handler exists - if types.ListResourcesRequest in self.request_handlers: - resources_capability = types.ResourcesCapability( - subscribe=False, listChanged=notification_options.resources_changed - ) - - # Set tool capabilities if handler exists - if types.ListToolsRequest in self.request_handlers: - tools_capability = types.ToolsCapability(listChanged=notification_options.tools_changed) - - # Set logging capabilities if handler exists - if types.SetLevelRequest in self.request_handlers: - logging_capability = types.LoggingCapability() - - # Set completions capabilities if handler exists - if types.CompleteRequest in self.request_handlers: - completions_capability = types.CompletionsCapability() - - return types.ServerCapabilities( - prompts=prompts_capability, - resources=resources_capability, - tools=tools_capability, - logging=logging_capability, - experimental=experimental_capabilities, - completions=completions_capability, - ) - - @property - def request_context( - self, - ) -> RequestContext[ServerSession, LifespanResultT, RequestT]: - """If called outside of a request context, this will raise a LookupError.""" - return request_ctx.get() - - def list_prompts(self): - def 
decorator(func: Callable[[], Awaitable[list[types.Prompt]]]): - logger.debug("Registering handler for PromptListRequest") - - async def handler(_: Any): - prompts = await func() - return types.ServerResult(types.ListPromptsResult(prompts=prompts)) - - self.request_handlers[types.ListPromptsRequest] = handler - return func - - return decorator - - def get_prompt(self): - def decorator( - func: Callable[[str, dict[str, str] | None], Awaitable[types.GetPromptResult]], - ): - logger.debug("Registering handler for GetPromptRequest") - - async def handler(req: types.GetPromptRequest): - prompt_get = await func(req.params.name, req.params.arguments) - return types.ServerResult(prompt_get) - - self.request_handlers[types.GetPromptRequest] = handler - return func - - return decorator - - def list_resources(self): - def decorator(func: Callable[[], Awaitable[list[types.Resource]]]): - logger.debug("Registering handler for ListResourcesRequest") - - async def handler(_: Any): - resources = await func() - return types.ServerResult(types.ListResourcesResult(resources=resources)) - - self.request_handlers[types.ListResourcesRequest] = handler - return func - - return decorator - - def list_resource_templates(self): - def decorator(func: Callable[[], Awaitable[list[types.ResourceTemplate]]]): - logger.debug("Registering handler for ListResourceTemplatesRequest") - - async def handler(_: Any): - templates = await func() - return types.ServerResult(types.ListResourceTemplatesResult(resourceTemplates=templates)) - - self.request_handlers[types.ListResourceTemplatesRequest] = handler - return func - - return decorator - - def read_resource(self): - def decorator( - func: Callable[[AnyUrl], Awaitable[str | bytes | Iterable[ReadResourceContents]]], - ): - logger.debug("Registering handler for ReadResourceRequest") - - async def handler(req: types.ReadResourceRequest): - result = await func(req.params.uri) - - def create_content(data: str | bytes, mime_type: str | None): - match data: - 
case str() as data: - return types.TextResourceContents( - uri=req.params.uri, - text=data, - mimeType=mime_type or "text/plain", - ) - case bytes() as data: - import base64 - - return types.BlobResourceContents( - uri=req.params.uri, - blob=base64.b64encode(data).decode(), - mimeType=mime_type or "application/octet-stream", - ) - - match result: - case str() | bytes() as data: - warnings.warn( - "Returning str or bytes from read_resource is deprecated. " - "Use Iterable[ReadResourceContents] instead.", - DeprecationWarning, - stacklevel=2, - ) - content = create_content(data, None) - case Iterable() as contents: - contents_list = [ - create_content(content_item.content, content_item.mime_type) for content_item in contents - ] - return types.ServerResult( - types.ReadResourceResult( - contents=contents_list, - ) - ) - case _: - raise ValueError(f"Unexpected return type from read_resource: {type(result)}") - - return types.ServerResult( - types.ReadResourceResult( - contents=[content], - ) - ) - - self.request_handlers[types.ReadResourceRequest] = handler - return func - - return decorator - - def set_logging_level(self): - def decorator(func: Callable[[types.LoggingLevel], Awaitable[None]]): - logger.debug("Registering handler for SetLevelRequest") - - async def handler(req: types.SetLevelRequest): - await func(req.params.level) - return types.ServerResult(types.EmptyResult()) - - self.request_handlers[types.SetLevelRequest] = handler - return func - - return decorator - - def subscribe_resource(self): - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): - logger.debug("Registering handler for SubscribeRequest") - - async def handler(req: types.SubscribeRequest): - await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) - - self.request_handlers[types.SubscribeRequest] = handler - return func - - return decorator - - def unsubscribe_resource(self): - def decorator(func: Callable[[AnyUrl], Awaitable[None]]): - logger.debug("Registering 
handler for UnsubscribeRequest") - - async def handler(req: types.UnsubscribeRequest): - await func(req.params.uri) - return types.ServerResult(types.EmptyResult()) - - self.request_handlers[types.UnsubscribeRequest] = handler - return func - - return decorator - - def list_tools(self): - def decorator(func: Callable[[], Awaitable[list[types.Tool]]]): - logger.debug("Registering handler for ListToolsRequest") - - async def handler(_: Any): - tools = await func() - # Refresh the tool cache - self._tool_cache.clear() - for tool in tools: - self._tool_cache[tool.name] = tool - return types.ServerResult(types.ListToolsResult(tools=tools)) - - self.request_handlers[types.ListToolsRequest] = handler - return func - - return decorator - - def _make_error_result(self, error_message: str) -> types.ServerResult: - """Create a ServerResult with an error CallToolResult.""" - return types.ServerResult( - types.CallToolResult( - content=[types.TextContent(type="text", text=error_message)], - isError=True, - ) - ) - - async def _get_cached_tool_definition(self, tool_name: str) -> types.Tool | None: - """Get tool definition from cache, refreshing if necessary. - - Returns the Tool object if found, None otherwise. - """ - if tool_name not in self._tool_cache: - if types.ListToolsRequest in self.request_handlers: - logger.debug("Tool cache miss for %s, refreshing cache", tool_name) - await self.request_handlers[types.ListToolsRequest](None) - - tool = self._tool_cache.get(tool_name) - if tool is None: - logger.warning("Tool '%s' not listed, no validation will be performed", tool_name) - - return tool - - def call_tool(self, *, validate_input: bool = True): - """Register a tool call handler. - - Args: - validate_input: If True, validates input against inputSchema. Default is True. 
- - The handler validates input against inputSchema (if validate_input=True), calls the tool function, - and builds a CallToolResult with the results: - - Unstructured content (iterable of ContentBlock): returned in content - - Structured content (dict): returned in structuredContent, serialized JSON text returned in content - - Both: returned in content and structuredContent - - If outputSchema is defined, validates structuredContent or errors if missing. - """ - - def decorator( - func: Callable[ - ..., - Awaitable[UnstructuredContent | StructuredContent | CombinationContent], - ], - ): - logger.debug("Registering handler for CallToolRequest") - - async def handler(req: types.CallToolRequest): - try: - tool_name = req.params.name - arguments = req.params.arguments or {} - tool = await self._get_cached_tool_definition(tool_name) - - # input validation - if validate_input and tool: - try: - jsonschema.validate(instance=arguments, schema=tool.inputSchema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Input validation error: {e.message}") - - # tool call - results = await func(tool_name, arguments) - - # output normalization - unstructured_content: UnstructuredContent - maybe_structured_content: StructuredContent | None - if isinstance(results, tuple) and len(results) == 2: - # tool returned both structured and unstructured content - unstructured_content, maybe_structured_content = cast(CombinationContent, results) - elif isinstance(results, dict): - # tool returned structured content only - maybe_structured_content = cast(StructuredContent, results) - unstructured_content = [types.TextContent(type="text", text=json.dumps(results, indent=2))] - elif hasattr(results, "__iter__"): - # tool returned unstructured content only - unstructured_content = cast(UnstructuredContent, results) - maybe_structured_content = None - else: - return self._make_error_result(f"Unexpected return type from tool: {type(results).__name__}") - - # output 
validation - if tool and tool.outputSchema is not None: - if maybe_structured_content is None: - return self._make_error_result( - "Output validation error: outputSchema defined but no structured output returned" - ) - else: - try: - jsonschema.validate(instance=maybe_structured_content, schema=tool.outputSchema) - except jsonschema.ValidationError as e: - return self._make_error_result(f"Output validation error: {e.message}") - - # result - return types.ServerResult( - types.CallToolResult( - content=list(unstructured_content), - structuredContent=maybe_structured_content, - isError=False, - ) - ) - except Exception as e: - return self._make_error_result(str(e)) - - self.request_handlers[types.CallToolRequest] = handler - return func - - return decorator - - def progress_notification(self): - def decorator( - func: Callable[[str | int, float, float | None, str | None], Awaitable[None]], - ): - logger.debug("Registering handler for ProgressNotification") - - async def handler(req: types.ProgressNotification): - await func( - req.params.progressToken, - req.params.progress, - req.params.total, - req.params.message, - ) - - self.notification_handlers[types.ProgressNotification] = handler - return func - - return decorator - - def completion(self): - """Provides completions for prompts and resource templates""" - - def decorator( - func: Callable[ - [ - types.PromptReference | types.ResourceTemplateReference, - types.CompletionArgument, - types.CompletionContext | None, - ], - Awaitable[types.Completion | None], - ], - ): - logger.debug("Registering handler for CompleteRequest") - - async def handler(req: types.CompleteRequest): - completion = await func(req.params.ref, req.params.argument, req.params.context) - return types.ServerResult( - types.CompleteResult( - completion=completion - if completion is not None - else types.Completion(values=[], total=None, hasMore=None), - ) - ) - - self.request_handlers[types.CompleteRequest] = handler - return func - - return 
decorator - - async def run( - self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], - initialization_options: InitializationOptions, - # When False, exceptions are returned as messages to the client. - # When True, exceptions are raised, which will cause the server to shut down - # but also make tracing exceptions much easier during testing and when using - # in-process servers. - raise_exceptions: bool = False, - # When True, the server is stateless and - # clients can perform initialization with any node. The client must still follow - # the initialization lifecycle, but can do so with any available node - # rather than requiring initialization for each connection. - stateless: bool = False, - ): - async with AsyncExitStack() as stack: - lifespan_context = await stack.enter_async_context(self.lifespan(self)) - session = await stack.enter_async_context( - ServerSession( - read_stream, - write_stream, - initialization_options, - stateless=stateless, - ) - ) - - async with anyio.create_task_group() as tg: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - tg.start_soon( - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) - - async def _handle_message( - self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, - session: ServerSession, - lifespan_context: LifespanResultT, - raise_exceptions: bool = False, - ): - with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. 
- match message: # type: ignore[reportMatchNotExhaustive] - case RequestResponder(request=types.ClientRequest(root=req)) as responder: - with responder: - await self._handle_request(message, req, session, lifespan_context, raise_exceptions) - case types.ClientNotification(root=notify): - await self._handle_notification(notify) - - for warning in w: - logger.info("Warning: %s: %s", warning.category.__name__, warning.message) - - async def _handle_request( - self, - message: RequestResponder[types.ClientRequest, types.ServerResult], - req: Any, - session: ServerSession, - lifespan_context: LifespanResultT, - raise_exceptions: bool, - ): - logger.info("Processing request of type %s", type(req).__name__) - if handler := self.request_handlers.get(type(req)): # type: ignore - logger.debug("Dispatching request of type %s", type(req).__name__) - - token = None - try: - # Extract request context from message metadata - request_data = None - if message.message_metadata is not None and isinstance(message.message_metadata, ServerMessageMetadata): - request_data = message.message_metadata.request_context - - # Set our global state that can be retrieved via - # app.get_request_context() - token = request_ctx.set( - RequestContext( - message.request_id, - message.request_meta, - session, - lifespan_context, - request=request_data, - ) - ) - response = await handler(req) - except McpError as err: - response = err.error - except anyio.get_cancelled_exc_class(): - logger.info( - "Request %s cancelled - duplicate response suppressed", - message.request_id, - ) - return - except Exception as err: - if raise_exceptions: - raise err - response = types.ErrorData(code=0, message=str(err), data=None) - finally: - # Reset the global state after we are done - if token is not None: - request_ctx.reset(token) - - await message.respond(response) - else: - await message.respond( - types.ErrorData( - code=types.METHOD_NOT_FOUND, - message="Method not found", - ) - ) - - logger.debug("Response 
sent") - - async def _handle_notification(self, notify: Any): - if handler := self.notification_handlers.get(type(notify)): # type: ignore - logger.debug("Dispatching notification of type %s", type(notify).__name__) - - try: - await handler(notify) - except Exception: - logger.exception("Uncaught exception in notification handler") - - -async def _ping_handler(request: types.PingRequest) -> types.ServerResult: - return types.ServerResult(types.EmptyResult()) diff --git a/src/mcp/mcp/server/models.py b/src/mcp/mcp/server/models.py deleted file mode 100644 index 3b5abba78..000000000 --- a/src/mcp/mcp/server/models.py +++ /dev/null @@ -1,17 +0,0 @@ -""" -This module provides simpler types to use with the server for managing prompts -and tools. -""" - -from pydantic import BaseModel - -from mcp.types import ( - ServerCapabilities, -) - - -class InitializationOptions(BaseModel): - server_name: str - server_version: str - capabilities: ServerCapabilities - instructions: str | None = None diff --git a/src/mcp/mcp/server/session.py b/src/mcp/mcp/server/session.py deleted file mode 100644 index 5c696b136..000000000 --- a/src/mcp/mcp/server/session.py +++ /dev/null @@ -1,359 +0,0 @@ -""" -ServerSession Module - -This module provides the ServerSession class, which manages communication between the -server and client in the MCP (Model Context Protocol) framework. It is most commonly -used in MCP servers to interact with the client. 
- -Common usage pattern: -``` - server = Server(name) - - @server.call_tool() - async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any: - # Check client capabilities before proceeding - if ctx.session.check_client_capability( - types.ClientCapabilities(experimental={"advanced_tools": dict()}) - ): - # Perform advanced tool operations - result = await perform_advanced_tool_operation(arguments) - else: - # Fall back to basic tool operations - result = await perform_basic_tool_operation(arguments) - - return result - - @server.list_prompts() - async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: - # Access session for any necessary checks or operations - if ctx.session.client_params: - # Customize prompts based on client initialization parameters - return generate_custom_prompts(ctx.session.client_params) - else: - return default_prompts -``` - -The ServerSession class is typically used internally by the Server class and should not -be instantiated directly by users of the MCP framework. 
-""" - -from enum import Enum -from typing import Any, TypeVar - -import anyio -import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import AnyUrl - -import mcp.types as types -from mcp.server.models import InitializationOptions -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import ( - BaseSession, - RequestResponder, -) -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS - - -class InitializationState(Enum): - NotInitialized = 1 - Initializing = 2 - Initialized = 3 - - -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") - -ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception -) - - -class ServerSession( - BaseSession[ - types.ServerRequest, - types.ServerNotification, - types.ServerResult, - types.ClientRequest, - types.ClientNotification, - ] -): - _initialized: InitializationState = InitializationState.NotInitialized - _client_params: types.InitializeRequestParams | None = None - - def __init__( - self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], - init_options: InitializationOptions, - stateless: bool = False, - ) -> None: - super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) - self._initialization_state = ( - InitializationState.Initialized if stateless else InitializationState.NotInitialized - ) - - self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ - ServerRequestResponder - ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) - - @property - def client_params(self) -> types.InitializeRequestParams | None: - return self._client_params - - def check_client_capability(self, capability: 
types.ClientCapabilities) -> bool: - """Check if the client supports a specific capability.""" - if self._client_params is None: - return False - - # Get client capabilities from initialization params - client_caps = self._client_params.capabilities - - # Check each specified capability in the passed in capability object - if capability.roots is not None: - if client_caps.roots is None: - return False - if capability.roots.listChanged and not client_caps.roots.listChanged: - return False - - if capability.sampling is not None: - if client_caps.sampling is None: - return False - - if capability.elicitation is not None: - if client_caps.elicitation is None: - return False - - if capability.experimental is not None: - if client_caps.experimental is None: - return False - # Check each experimental capability - for exp_key, exp_value in capability.experimental.items(): - if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: - return False - - return True - - async def _receive_loop(self) -> None: - async with self._incoming_message_stream_writer: - await super()._receive_loop() - - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): - match responder.request.root: - case types.InitializeRequest(params=params): - requested_version = params.protocolVersion - self._initialization_state = InitializationState.Initializing - self._client_params = params - with responder: - await responder.respond( - types.ServerResult( - types.InitializeResult( - protocolVersion=requested_version - if requested_version in SUPPORTED_PROTOCOL_VERSIONS - else types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - serverInfo=types.Implementation( - name=self._init_options.server_name, - version=self._init_options.server_version, - ), - instructions=self._init_options.instructions, - ) - ) - ) - case _: - if self._initialization_state != InitializationState.Initialized: - raise 
RuntimeError("Received request before initialization was complete") - - async def _received_notification(self, notification: types.ClientNotification) -> None: - # Need this to avoid ASYNC910 - await anyio.lowlevel.checkpoint() - match notification.root: - case types.InitializedNotification(): - self._initialization_state = InitializationState.Initialized - case _: - if self._initialization_state != InitializationState.Initialized: - raise RuntimeError("Received notification before initialization was complete") - - async def send_log_message( - self, - level: types.LoggingLevel, - data: Any, - logger: str | None = None, - related_request_id: types.RequestId | None = None, - ) -> None: - """Send a log message notification.""" - await self.send_notification( - types.ServerNotification( - types.LoggingMessageNotification( - method="notifications/message", - params=types.LoggingMessageNotificationParams( - level=level, - data=data, - logger=logger, - ), - ) - ), - related_request_id, - ) - - async def send_resource_updated(self, uri: AnyUrl) -> None: - """Send a resource updated notification.""" - await self.send_notification( - types.ServerNotification( - types.ResourceUpdatedNotification( - method="notifications/resources/updated", - params=types.ResourceUpdatedNotificationParams(uri=uri), - ) - ) - ) - - async def create_message( - self, - messages: list[types.SamplingMessage], - *, - max_tokens: int, - system_prompt: str | None = None, - include_context: types.IncludeContext | None = None, - temperature: float | None = None, - stop_sequences: list[str] | None = None, - metadata: dict[str, Any] | None = None, - model_preferences: types.ModelPreferences | None = None, - related_request_id: types.RequestId | None = None, - ) -> types.CreateMessageResult: - """Send a sampling/create_message request.""" - return await self.send_request( - request=types.ServerRequest( - types.CreateMessageRequest( - method="sampling/createMessage", - 
params=types.CreateMessageRequestParams( - messages=messages, - systemPrompt=system_prompt, - includeContext=include_context, - temperature=temperature, - maxTokens=max_tokens, - stopSequences=stop_sequences, - metadata=metadata, - modelPreferences=model_preferences, - ), - ) - ), - result_type=types.CreateMessageResult, - metadata=ServerMessageMetadata( - related_request_id=related_request_id, - ), - ) - - async def list_roots(self) -> types.ListRootsResult: - """Send a roots/list request.""" - return await self.send_request( - types.ServerRequest( - types.ListRootsRequest( - method="roots/list", - ) - ), - types.ListRootsResult, - ) - - async def elicit( - self, - message: str, - requestedSchema: types.ElicitRequestedSchema, - related_request_id: types.RequestId | None = None, - ) -> types.ElicitResult: - """Send an elicitation/create request. - - Args: - message: The message to present to the user - requestedSchema: Schema defining the expected response structure - - Returns: - The client's response - """ - return await self.send_request( - types.ServerRequest( - types.ElicitRequest( - method="elicitation/create", - params=types.ElicitRequestParams( - message=message, - requestedSchema=requestedSchema, - ), - ) - ), - types.ElicitResult, - metadata=ServerMessageMetadata(related_request_id=related_request_id), - ) - - async def send_ping(self) -> types.EmptyResult: - """Send a ping request.""" - return await self.send_request( - types.ServerRequest( - types.PingRequest( - method="ping", - ) - ), - types.EmptyResult, - ) - - async def send_progress_notification( - self, - progress_token: str | int, - progress: float, - total: float | None = None, - message: str | None = None, - related_request_id: str | None = None, - ) -> None: - """Send a progress notification.""" - await self.send_notification( - types.ServerNotification( - types.ProgressNotification( - method="notifications/progress", - params=types.ProgressNotificationParams( - progressToken=progress_token, - 
progress=progress, - total=total, - message=message, - ), - ) - ), - related_request_id, - ) - - async def send_resource_list_changed(self) -> None: - """Send a resource list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.ResourceListChangedNotification( - method="notifications/resources/list_changed", - ) - ) - ) - - async def send_tool_list_changed(self) -> None: - """Send a tool list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.ToolListChangedNotification( - method="notifications/tools/list_changed", - ) - ) - ) - - async def send_prompt_list_changed(self) -> None: - """Send a prompt list changed notification.""" - await self.send_notification( - types.ServerNotification( - types.PromptListChangedNotification( - method="notifications/prompts/list_changed", - ) - ) - ) - - async def _handle_incoming(self, req: ServerRequestResponder) -> None: - await self._incoming_message_stream_writer.send(req) - - @property - def incoming_messages( - self, - ) -> MemoryObjectReceiveStream[ServerRequestResponder]: - return self._incoming_message_stream_reader diff --git a/src/mcp/mcp/server/sse.py b/src/mcp/mcp/server/sse.py deleted file mode 100644 index b7ff33280..000000000 --- a/src/mcp/mcp/server/sse.py +++ /dev/null @@ -1,249 +0,0 @@ -""" -SSE Server Transport Module - -This module implements a Server-Sent Events (SSE) transport layer for MCP servers. 
- -Example usage: -``` - # Create an SSE transport at an endpoint - sse = SseServerTransport("/messages/") - - # Create Starlette routes for SSE and message handling - routes = [ - Route("/sse", endpoint=handle_sse, methods=["GET"]), - Mount("/messages/", app=sse.handle_post_message), - ] - - # Define handler functions - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: - await app.run( - streams[0], streams[1], app.create_initialization_options() - ) - # Return empty response to avoid NoneType error - return Response() - - # Create and run Starlette app - starlette_app = Starlette(routes=routes) - uvicorn.run(starlette_app, host="127.0.0.1", port=port) -``` - -Note: The handle_sse function must return a Response to avoid a "TypeError: 'NoneType' -object is not callable" error when client disconnects. The example above returns -an empty Response() after the SSE connection ends to fix this. - -See SseServerTransport class documentation for more details. -""" - -import logging -from contextlib import asynccontextmanager -from typing import Any -from urllib.parse import quote -from uuid import UUID, uuid4 - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import ValidationError -from sse_starlette import EventSourceResponse -from starlette.requests import Request -from starlette.responses import Response -from starlette.types import Receive, Scope, Send - -import mcp.types as types -from mcp.server.transport_security import ( - TransportSecurityMiddleware, - TransportSecuritySettings, -) -from mcp.shared.message import ServerMessageMetadata, SessionMessage - -logger = logging.getLogger(__name__) - - -class SseServerTransport: - """ - SSE server transport for MCP. This class provides _two_ ASGI applications, - suitable to be used with a framework like Starlette and a server like Hypercorn: - - 1. 
connect_sse() is an ASGI application which receives incoming GET requests, - and sets up a new SSE stream to send server messages to the client. - 2. handle_post_message() is an ASGI application which receives incoming POST - requests, which should contain client messages that link to a - previously-established SSE session. - """ - - _endpoint: str - _read_stream_writers: dict[UUID, MemoryObjectSendStream[SessionMessage | Exception]] - _security: TransportSecurityMiddleware - - def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None: - """ - Creates a new SSE server transport, which will direct the client to POST - messages to the relative path given. - - Args: - endpoint: A relative path where messages should be posted - (e.g., "/messages/"). - security_settings: Optional security settings for DNS rebinding protection. - - Note: - We use relative paths instead of full URLs for several reasons: - 1. Security: Prevents cross-origin requests by ensuring clients only connect - to the same origin they established the SSE connection with - 2. Flexibility: The server can be mounted at any path without needing to - know its full URL - 3. Portability: The same endpoint configuration works across different - environments (development, staging, production) - - Raises: - ValueError: If the endpoint is a full URL instead of a relative path - """ - - super().__init__() - - # Validate that endpoint is a relative path and not a full URL - if "://" in endpoint or endpoint.startswith("//") or "?" in endpoint or "#" in endpoint: - raise ValueError( - f"Given endpoint: {endpoint} is not a relative path (e.g., '/messages/'), " - "expecting a relative path (e.g., '/messages/')." 
- ) - - # Ensure endpoint starts with a forward slash - if not endpoint.startswith("/"): - endpoint = "/" + endpoint - - self._endpoint = endpoint - self._read_stream_writers = {} - self._security = TransportSecurityMiddleware(security_settings) - logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}") - - @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): - if scope["type"] != "http": - logger.error("connect_sse received non-HTTP request") - raise ValueError("connect_sse can only handle HTTP requests") - - # Validate request headers for DNS rebinding protection - request = Request(scope, receive) - error_response = await self._security.validate_request(request, is_post=False) - if error_response: - await error_response(scope, receive, send) - raise ValueError("Request validation failed") - - logger.debug("Setting up SSE connection") - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - session_id = uuid4() - self._read_stream_writers[session_id] = read_stream_writer - logger.debug(f"Created new session with ID: {session_id}") - - # Determine the full path for the message endpoint to be sent to the client. - # scope['root_path'] is the prefix where the current Starlette app - # instance is mounted. - # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". - root_path = scope.get("root_path", "") - - # self._endpoint is the path *within* this app, e.g., "/messages". - # Concatenating them gives the full absolute path from the server root. 
- # e.g., "" + "/messages" -> "/messages" - # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" - full_message_path_for_client = root_path.rstrip("/") + self._endpoint - - # This is the URI (path + query) the client will use to POST messages. - client_post_uri_data = f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" - - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, Any]](0) - - async def sse_writer(): - logger.debug("Starting SSE writer") - async with sse_stream_writer, write_stream_reader: - await sse_stream_writer.send({"event": "endpoint", "data": client_post_uri_data}) - logger.debug(f"Sent endpoint event: {client_post_uri_data}") - - async for session_message in write_stream_reader: - logger.debug(f"Sending message via SSE: {session_message}") - await sse_stream_writer.send( - { - "event": "message", - "data": session_message.message.model_dump_json(by_alias=True, exclude_none=True), - } - ) - - async with anyio.create_task_group() as tg: - - async def response_wrapper(scope: Scope, receive: Receive, send: Send): - """ - The EventSourceResponse returning signals a client close / disconnect. - In this case we close our side of the streams to signal the client that - the connection has been closed. 
- """ - await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)( - scope, receive, send - ) - await read_stream_writer.aclose() - await write_stream_reader.aclose() - logging.debug(f"Client session disconnected {session_id}") - - logger.debug("Starting SSE response task") - tg.start_soon(response_wrapper, scope, receive, send) - - logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) - - async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: - logger.debug("Handling POST message") - request = Request(scope, receive) - - # Validate request headers for DNS rebinding protection - error_response = await self._security.validate_request(request, is_post=True) - if error_response: - return await error_response(scope, receive, send) - - session_id_param = request.query_params.get("session_id") - if session_id_param is None: - logger.warning("Received request without session_id") - response = Response("session_id is required", status_code=400) - return await response(scope, receive, send) - - try: - session_id = UUID(hex=session_id_param) - logger.debug(f"Parsed session ID: {session_id}") - except ValueError: - logger.warning(f"Received invalid session ID: {session_id_param}") - response = Response("Invalid session ID", status_code=400) - return await response(scope, receive, send) - - writer = self._read_stream_writers.get(session_id) - if not writer: - logger.warning(f"Could not find session for ID: {session_id}") - response = Response("Could not find session", status_code=404) - return await response(scope, receive, send) - - body = await request.body() - logger.debug(f"Received JSON: {body}") - - try: - message = types.JSONRPCMessage.model_validate_json(body) - logger.debug(f"Validated client message: {message}") - except ValidationError as err: - logger.exception("Failed to parse message") - response = Response("Could not parse message", status_code=400) - await 
response(scope, receive, send) - await writer.send(err) - return - - # Pass the ASGI scope for framework-agnostic access to request data - metadata = ServerMessageMetadata(request_context=request) - session_message = SessionMessage(message, metadata=metadata) - logger.debug(f"Sending session message to writer: {session_message}") - response = Response("Accepted", status_code=202) - await response(scope, receive, send) - await writer.send(session_message) diff --git a/src/mcp/mcp/server/stdio.py b/src/mcp/mcp/server/stdio.py deleted file mode 100644 index d1618a371..000000000 --- a/src/mcp/mcp/server/stdio.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Stdio Server Transport Module - -This module provides functionality for creating an stdio-based transport layer -that can be used to communicate with an MCP client through standard input/output -streams. - -Example usage: -``` - async def run_server(): - async with stdio_server() as (read_stream, write_stream): - # read_stream contains incoming JSONRPCMessages from stdin - # write_stream allows sending JSONRPCMessages to stdout - server = await create_my_server() - await server.run(read_stream, write_stream, init_options) - - anyio.run(run_server) -``` -""" - -import sys -from contextlib import asynccontextmanager -from io import TextIOWrapper - -import anyio -import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -import mcp.types as types -from mcp.shared.message import SessionMessage - - -@asynccontextmanager -async def stdio_server( - stdin: anyio.AsyncFile[str] | None = None, - stdout: anyio.AsyncFile[str] | None = None, -): - """ - Server transport for stdio: this communicates with an MCP client by reading - from the current process' stdin and writing to stdout. - """ - # Purposely not using context managers for these, as we don't want to close - # standard process handles. 
Encoding of stdin/stdout as text streams on - # python is platform-dependent (Windows is particularly problematic), so we - # re-wrap the underlying binary stream to ensure UTF-8. - if not stdin: - stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8")) - if not stdout: - stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8")) - - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async def stdin_reader(): - try: - async with read_stream_writer: - async for line in stdin: - try: - message = types.JSONRPCMessage.model_validate_json(line) - except Exception as exc: - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: - await anyio.lowlevel.checkpoint() - - async def stdout_writer(): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - json = session_message.message.model_dump_json(by_alias=True, exclude_none=True) - await stdout.write(json + "\n") - await stdout.flush() - except anyio.ClosedResourceError: - await anyio.lowlevel.checkpoint() - - async with anyio.create_task_group() as tg: - tg.start_soon(stdin_reader) - tg.start_soon(stdout_writer) - yield read_stream, write_stream diff --git a/src/mcp/mcp/server/streamable_http.py b/src/mcp/mcp/server/streamable_http.py deleted file mode 100644 index 802cb8680..000000000 --- a/src/mcp/mcp/server/streamable_http.py +++ /dev/null @@ -1,901 +0,0 @@ -""" -StreamableHTTP Server Transport Module - -This module implements an HTTP transport layer with Streamable 
HTTP. - -The transport handles bidirectional communication using HTTP requests and -responses, with streaming support for long-running operations. -""" - -import json -import logging -import re -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager -from dataclasses import dataclass -from http import HTTPStatus - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import ValidationError -from sse_starlette import EventSourceResponse -from starlette.requests import Request -from starlette.responses import Response -from starlette.types import Receive, Scope, Send - -from mcp.server.transport_security import ( - TransportSecurityMiddleware, - TransportSecuritySettings, -) -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.types import ( - DEFAULT_NEGOTIATED_VERSION, - INTERNAL_ERROR, - INVALID_PARAMS, - INVALID_REQUEST, - PARSE_ERROR, - ErrorData, - JSONRPCError, - JSONRPCMessage, - JSONRPCRequest, - JSONRPCResponse, - RequestId, -) - -logger = logging.getLogger(__name__) - - -# Header names -MCP_SESSION_ID_HEADER = "mcp-session-id" -MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version" -LAST_EVENT_ID_HEADER = "last-event-id" - -# Content types -CONTENT_TYPE_JSON = "application/json" -CONTENT_TYPE_SSE = "text/event-stream" - -# Special key for the standalone GET stream -GET_STREAM_KEY = "_GET_stream" - -# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) -# Pattern ensures entire string contains only valid characters by using ^ and $ anchors -SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") - -# Type aliases -StreamId = str -EventId = str - - -@dataclass -class EventMessage: - """ - A JSONRPCMessage with an optional event ID for stream resumability. 
- """ - - message: JSONRPCMessage - event_id: str | None = None - - -EventCallback = Callable[[EventMessage], Awaitable[None]] - - -class EventStore(ABC): - """ - Interface for resumability support via event storage. - """ - - @abstractmethod - async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId: - """ - Stores an event for later retrieval. - - Args: - stream_id: ID of the stream the event belongs to - message: The JSON-RPC message to store - - Returns: - The generated event ID for the stored event - """ - pass - - @abstractmethod - async def replay_events_after( - self, - last_event_id: EventId, - send_callback: EventCallback, - ) -> StreamId | None: - """ - Replays events that occurred after the specified event ID. - - Args: - last_event_id: The ID of the last event the client received - send_callback: A callback function to send events to the client - - Returns: - The stream ID of the replayed events - """ - pass - - -class StreamableHTTPServerTransport: - """ - HTTP server transport with event streaming support for MCP. - - Handles JSON-RPC messages in HTTP POST requests with SSE streaming. - Supports optional JSON responses and session management. - """ - - # Server notification streams for POST requests as well as standalone SSE stream - _read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] | None = None - _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] | None = None - _write_stream: MemoryObjectSendStream[SessionMessage] | None = None - _write_stream_reader: MemoryObjectReceiveStream[SessionMessage] | None = None - _security: TransportSecurityMiddleware - - def __init__( - self, - mcp_session_id: str | None, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - security_settings: TransportSecuritySettings | None = None, - ) -> None: - """ - Initialize a new StreamableHTTP server transport. 
- - Args: - mcp_session_id: Optional session identifier for this connection. - Must contain only visible ASCII characters (0x21-0x7E). - is_json_response_enabled: If True, return JSON responses for requests - instead of SSE streams. Default is False. - event_store: Event store for resumability support. If provided, - resumability will be enabled, allowing clients to - reconnect and resume messages. - security_settings: Optional security settings for DNS rebinding protection. - - Raises: - ValueError: If the session ID contains invalid characters. - """ - if mcp_session_id is not None and not SESSION_ID_PATTERN.fullmatch(mcp_session_id): - raise ValueError("Session ID must only contain visible ASCII characters (0x21-0x7E)") - - self.mcp_session_id = mcp_session_id - self.is_json_response_enabled = is_json_response_enabled - self._event_store = event_store - self._security = TransportSecurityMiddleware(security_settings) - self._request_streams: dict[ - RequestId, - tuple[ - MemoryObjectSendStream[EventMessage], - MemoryObjectReceiveStream[EventMessage], - ], - ] = {} - self._terminated = False - - @property - def is_terminated(self) -> bool: - """Check if this transport has been explicitly terminated.""" - return self._terminated - - def _create_error_response( - self, - error_message: str, - status_code: HTTPStatus, - error_code: int = INVALID_REQUEST, - headers: dict[str, str] | None = None, - ) -> Response: - """Create an error response with a simple string message.""" - response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: - response_headers.update(headers) - - if self.mcp_session_id: - response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - - # Return a properly formatted JSON error response - error_response = JSONRPCError( - jsonrpc="2.0", - id="server-error", # We don't have a request ID for general errors - error=ErrorData( - code=error_code, - message=error_message, - ), - ) - - return Response( - 
error_response.model_dump_json(by_alias=True, exclude_none=True), - status_code=status_code, - headers=response_headers, - ) - - def _create_json_response( - self, - response_message: JSONRPCMessage | None, - status_code: HTTPStatus = HTTPStatus.OK, - headers: dict[str, str] | None = None, - ) -> Response: - """Create a JSON response from a JSONRPCMessage""" - response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: - response_headers.update(headers) - - if self.mcp_session_id: - response_headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - - return Response( - response_message.model_dump_json(by_alias=True, exclude_none=True) if response_message else None, - status_code=status_code, - headers=response_headers, - ) - - def _get_session_id(self, request: Request) -> str | None: - """Extract the session ID from request headers.""" - return request.headers.get(MCP_SESSION_ID_HEADER) - - def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: - """Create event data dictionary from an EventMessage.""" - event_data = { - "event": "message", - "data": event_message.message.model_dump_json(by_alias=True, exclude_none=True), - } - - # If an event ID was provided, include it - if event_message.event_id: - event_data["id"] = event_message.event_id - - return event_data - - async def _clean_up_memory_streams(self, request_id: RequestId) -> None: - """Clean up memory streams for a given request ID.""" - if request_id in self._request_streams: - try: - # Close the request stream - await self._request_streams[request_id][0].aclose() - await self._request_streams[request_id][1].aclose() - except Exception: - # During cleanup, we catch all exceptions since streams might be in various states - logger.debug("Error closing memory streams - may already be closed") - finally: - # Remove the request stream from the mapping - self._request_streams.pop(request_id, None) - - async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> 
None: - """Application entry point that handles all HTTP requests""" - request = Request(scope, receive) - - # Validate request headers for DNS rebinding protection - is_post = request.method == "POST" - error_response = await self._security.validate_request(request, is_post=is_post) - if error_response: - await error_response(scope, receive, send) - return - - if self._terminated: - # If the session has been terminated, return 404 Not Found - response = self._create_error_response( - "Not Found: Session has been terminated", - HTTPStatus.NOT_FOUND, - ) - await response(scope, receive, send) - return - - if request.method == "POST": - await self._handle_post_request(scope, request, receive, send) - elif request.method == "GET": - await self._handle_get_request(request, send) - elif request.method == "DELETE": - await self._handle_delete_request(request, send) - else: - await self._handle_unsupported_request(request, send) - - def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: - """Check if the request accepts the required media types.""" - accept_header = request.headers.get("accept", "") - accept_types = [media_type.strip() for media_type in accept_header.split(",")] - - has_json = any(media_type.startswith(CONTENT_TYPE_JSON) for media_type in accept_types) - has_sse = any(media_type.startswith(CONTENT_TYPE_SSE) for media_type in accept_types) - - return has_json, has_sse - - def _check_content_type(self, request: Request) -> bool: - """Check if the request has the correct Content-Type.""" - content_type = request.headers.get("content-type", "") - content_type_parts = [part.strip() for part in content_type.split(";")[0].split(",")] - - return any(part == CONTENT_TYPE_JSON for part in content_type_parts) - - async def _handle_post_request(self, scope: Scope, request: Request, receive: Receive, send: Send) -> None: - """Handle POST requests containing JSON-RPC messages.""" - writer = self._read_stream_writer - if writer is None: - raise 
ValueError("No read stream writer available. Ensure connect() is called first.") - try: - # Check Accept headers - has_json, has_sse = self._check_accept_headers(request) - if not (has_json and has_sse): - response = self._create_error_response( - ("Not Acceptable: Client must accept both application/json and text/event-stream"), - HTTPStatus.NOT_ACCEPTABLE, - ) - await response(scope, receive, send) - return - - # Validate Content-Type - if not self._check_content_type(request): - response = self._create_error_response( - "Unsupported Media Type: Content-Type must be application/json", - HTTPStatus.UNSUPPORTED_MEDIA_TYPE, - ) - await response(scope, receive, send) - return - - # Parse the body - only read it once - body = await request.body() - - try: - raw_message = json.loads(body) - except json.JSONDecodeError as e: - response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) - await response(scope, receive, send) - return - - try: - message = JSONRPCMessage.model_validate(raw_message) - except ValidationError as e: - response = self._create_error_response( - f"Validation error: {str(e)}", - HTTPStatus.BAD_REQUEST, - INVALID_PARAMS, - ) - await response(scope, receive, send) - return - - # Check if this is an initialization request - is_initialization_request = isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" - - if is_initialization_request: - # Check if the server already has an established session - if self.mcp_session_id: - # Check if request has a session ID - request_session_id = self._get_session_id(request) - - # If request has a session ID but doesn't match, return 404 - if request_session_id and request_session_id != self.mcp_session_id: - response = self._create_error_response( - "Not Found: Invalid or expired session ID", - HTTPStatus.NOT_FOUND, - ) - await response(scope, receive, send) - return - elif not await self._validate_request_headers(request, send): - return - - # 
For notifications and responses only, return 202 Accepted - if not isinstance(message.root, JSONRPCRequest): - # Create response object and send it - response = self._create_json_response( - None, - HTTPStatus.ACCEPTED, - ) - await response(scope, receive, send) - - # Process the message after sending the response - metadata = ServerMessageMetadata(request_context=request) - session_message = SessionMessage(message, metadata=metadata) - await writer.send(session_message) - - return - - # Extract the request ID outside the try block for proper scope - request_id = str(message.root.id) - # Register this stream for the request ID - self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) - request_stream_reader = self._request_streams[request_id][1] - - if self.is_json_response_enabled: - # Process the message - metadata = ServerMessageMetadata(request_context=request) - session_message = SessionMessage(message, metadata=metadata) - await writer.send(session_message) - try: - # Process messages from the request-specific stream - # We need to collect all messages until we get a response - response_message = None - - # Use similar approach to SSE writer for consistency - async for event_message in request_stream_reader: - # If it's a response, this is what we're waiting for - if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError): - response_message = event_message.message - break - # For notifications and request, keep waiting - else: - logger.debug(f"received: {event_message.message.root.method}") - - # At this point we should have a response - if response_message: - # Create JSON response - response = self._create_json_response(response_message) - await response(scope, receive, send) - else: - # This shouldn't happen in normal operation - logger.error("No response message received before stream closed") - response = self._create_error_response( - "Error processing request: No response received", - 
HTTPStatus.INTERNAL_SERVER_ERROR, - ) - await response(scope, receive, send) - except Exception: - logger.exception("Error processing JSON response") - response = self._create_error_response( - "Error processing request", - HTTPStatus.INTERNAL_SERVER_ERROR, - INTERNAL_ERROR, - ) - await response(scope, receive, send) - finally: - await self._clean_up_memory_streams(request_id) - else: - # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - - async def sse_writer(): - # Get the request ID from the incoming request message - try: - async with sse_stream_writer, request_stream_reader: - # Process messages from the request-specific stream - async for event_message in request_stream_reader: - # Build the event data - event_data = self._create_event_data(event_message) - await sse_stream_writer.send(event_data) - - # If response, remove from pending streams and close - if isinstance( - event_message.message.root, - JSONRPCResponse | JSONRPCError, - ): - break - except Exception: - logger.exception("Error in SSE writer") - finally: - logger.debug("Closing SSE writer") - await self._clean_up_memory_streams(request_id) - - # Create and start EventSourceResponse - # SSE stream mode (original behavior) - # Set up headers - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - **({MCP_SESSION_ID_HEADER: self.mcp_session_id} if self.mcp_session_id else {}), - } - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=sse_writer, - headers=headers, - ) - - # Start the SSE response (this will send headers immediately) - try: - # First send the response to establish the SSE connection - async with anyio.create_task_group() as tg: - tg.start_soon(response, scope, receive, send) - # Then send the message to be processed by the server - metadata = ServerMessageMetadata(request_context=request) - session_message = 
SessionMessage(message, metadata=metadata) - await writer.send(session_message) - except Exception: - logger.exception("SSE response error") - await sse_stream_writer.aclose() - await sse_stream_reader.aclose() - await self._clean_up_memory_streams(request_id) - - except Exception as err: - logger.exception("Error handling POST request") - response = self._create_error_response( - f"Error handling POST request: {err}", - HTTPStatus.INTERNAL_SERVER_ERROR, - INTERNAL_ERROR, - ) - await response(scope, receive, send) - if writer: - await writer.send(Exception(err)) - return - - async def _handle_get_request(self, request: Request, send: Send) -> None: - """ - Handle GET request to establish SSE. - - This allows the server to communicate to the client without the client - first sending data via HTTP POST. The server can send JSON-RPC requests - and notifications on this stream. - """ - writer = self._read_stream_writer - if writer is None: - raise ValueError("No read stream writer available. 
Ensure connect() is called first.") - - # Validate Accept header - must include text/event-stream - _, has_sse = self._check_accept_headers(request) - - if not has_sse: - response = self._create_error_response( - "Not Acceptable: Client must accept text/event-stream", - HTTPStatus.NOT_ACCEPTABLE, - ) - await response(request.scope, request.receive, send) - return - - if not await self._validate_request_headers(request, send): - return - - # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): - await self._replay_events(last_event_id, request, send) - return - - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - } - - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - - # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: - response = self._create_error_response( - "Conflict: Only one SSE stream is allowed per session", - HTTPStatus.CONFLICT, - ) - await response(request.scope, request.receive, send) - return - - # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - - async def standalone_sse_writer(): - try: - # Create a standalone message stream for server-initiated messages - - self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0) - standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] - - async with sse_stream_writer, standalone_stream_reader: - # Process messages from the standalone stream - async for event_message in standalone_stream_reader: - # For the standalone stream, we handle: - # - JSONRPCNotification (server sends notifications to client) - # - JSONRPCRequest (server sends requests to client) - # We should NOT receive JSONRPCResponse - - # Send the message via SSE - event_data = self._create_event_data(event_message) - await 
sse_stream_writer.send(event_data) - except Exception: - logger.exception("Error in standalone SSE writer") - finally: - logger.debug("Closing standalone SSE writer") - await self._clean_up_memory_streams(GET_STREAM_KEY) - - # Create and start EventSourceResponse - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=standalone_sse_writer, - headers=headers, - ) - - try: - # This will send headers immediately and establish the SSE connection - await response(request.scope, request.receive, send) - except Exception: - logger.exception("Error in standalone SSE response") - await sse_stream_writer.aclose() - await sse_stream_reader.aclose() - await self._clean_up_memory_streams(GET_STREAM_KEY) - - async def _handle_delete_request(self, request: Request, send: Send) -> None: - """Handle DELETE requests for explicit session termination.""" - # Validate session ID - if not self.mcp_session_id: - # If no session ID set, return Method Not Allowed - response = self._create_error_response( - "Method Not Allowed: Session termination not supported", - HTTPStatus.METHOD_NOT_ALLOWED, - ) - await response(request.scope, request.receive, send) - return - - if not await self._validate_request_headers(request, send): - return - - await self.terminate() - - response = self._create_json_response( - None, - HTTPStatus.OK, - ) - await response(request.scope, request.receive, send) - - async def terminate(self) -> None: - """Terminate the current session, closing all streams. - - Once terminated, all requests with this session ID will receive 404 Not Found. 
- """ - - self._terminated = True - logger.info(f"Terminating session: {self.mcp_session_id}") - - # We need a copy of the keys to avoid modification during iteration - request_stream_keys = list(self._request_streams.keys()) - - # Close all request streams asynchronously - for key in request_stream_keys: - await self._clean_up_memory_streams(key) - - # Clear the request streams dictionary immediately - self._request_streams.clear() - try: - if self._read_stream_writer is not None: - await self._read_stream_writer.aclose() - if self._read_stream is not None: - await self._read_stream.aclose() - if self._write_stream_reader is not None: - await self._write_stream_reader.aclose() - if self._write_stream is not None: - await self._write_stream.aclose() - except Exception as e: - # During cleanup, we catch all exceptions since streams might be in various states - logger.debug(f"Error closing streams: {e}") - - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: - """Handle unsupported HTTP methods.""" - headers = { - "Content-Type": CONTENT_TYPE_JSON, - "Allow": "GET, POST, DELETE", - } - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - - response = self._create_error_response( - "Method Not Allowed", - HTTPStatus.METHOD_NOT_ALLOWED, - headers=headers, - ) - await response(request.scope, request.receive, send) - - async def _validate_request_headers(self, request: Request, send: Send) -> bool: - if not await self._validate_session(request, send): - return False - if not await self._validate_protocol_version(request, send): - return False - return True - - async def _validate_session(self, request: Request, send: Send) -> bool: - """Validate the session ID in the request.""" - if not self.mcp_session_id: - # If we're not using session IDs, return True - return True - - # Get the session ID from the request headers - request_session_id = self._get_session_id(request) - - # If no session ID provided but 
required, return error - if not request_session_id: - response = self._create_error_response( - "Bad Request: Missing session ID", - HTTPStatus.BAD_REQUEST, - ) - await response(request.scope, request.receive, send) - return False - - # If session ID doesn't match, return error - if request_session_id != self.mcp_session_id: - response = self._create_error_response( - "Not Found: Invalid or expired session ID", - HTTPStatus.NOT_FOUND, - ) - await response(request.scope, request.receive, send) - return False - - return True - - async def _validate_protocol_version(self, request: Request, send: Send) -> bool: - """Validate the protocol version header in the request.""" - # Get the protocol version from the request headers - protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) - - # If no protocol version provided, assume default version - if protocol_version is None: - protocol_version = DEFAULT_NEGOTIATED_VERSION - - # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: - supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) - response = self._create_error_response( - f"Bad Request: Unsupported protocol version: {protocol_version}. " - + f"Supported versions: {supported_versions}", - HTTPStatus.BAD_REQUEST, - ) - await response(request.scope, request.receive, send) - return False - - return True - - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: - """ - Replays events that would have been sent after the specified event ID. - Only used when resumability is enabled. 
- """ - event_store = self._event_store - if not event_store: - return - - try: - headers = { - "Cache-Control": "no-cache, no-transform", - "Connection": "keep-alive", - "Content-Type": CONTENT_TYPE_SSE, - } - - if self.mcp_session_id: - headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - - # Create SSE stream for replay - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) - - async def replay_sender(): - try: - async with sse_stream_writer: - # Define an async callback for sending events - async def send_event(event_message: EventMessage) -> None: - event_data = self._create_event_data(event_message) - await sse_stream_writer.send(event_data) - - # Replay past events and get the stream ID - stream_id = await event_store.replay_events_after(last_event_id, send_event) - - # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: - self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) - msg_reader = self._request_streams[stream_id][1] - - # Forward messages to SSE - async with msg_reader: - async for event_message in msg_reader: - event_data = self._create_event_data(event_message) - - await sse_stream_writer.send(event_data) - except Exception: - logger.exception("Error in replay sender") - - # Create and start EventSourceResponse - response = EventSourceResponse( - content=sse_stream_reader, - data_sender_callable=replay_sender, - headers=headers, - ) - - try: - await response(request.scope, request.receive, send) - except Exception: - logger.exception("Error in replay response") - finally: - await sse_stream_writer.aclose() - await sse_stream_reader.aclose() - - except Exception: - logger.exception("Error replaying events") - response = self._create_error_response( - "Error replaying events", - HTTPStatus.INTERNAL_SERVER_ERROR, - INTERNAL_ERROR, - ) - await response(request.scope, request.receive, send) - - @asynccontextmanager - async def 
connect( - self, - ) -> AsyncGenerator[ - tuple[ - MemoryObjectReceiveStream[SessionMessage | Exception], - MemoryObjectSendStream[SessionMessage], - ], - None, - ]: - """Context manager that provides read and write streams for a connection. - - Yields: - Tuple of (read_stream, write_stream) for bidirectional communication - """ - - # Create the memory streams for this connection - - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - - # Store the streams - self._read_stream_writer = read_stream_writer - self._read_stream = read_stream - self._write_stream_reader = write_stream_reader - self._write_stream = write_stream - - # Start a task group for message routing - async with anyio.create_task_group() as tg: - # Create a message router that distributes messages to request streams - async def message_router(): - try: - async for session_message in write_stream_reader: - # Determine which request stream(s) should receive this message - message = session_message.message - target_request_id = None - # Check if this is a response - if isinstance(message.root, JSONRPCResponse | JSONRPCError): - response_id = str(message.root.id) - # If this response is for an existing request stream, - # send it there - target_request_id = response_id - else: - # Extract related_request_id from meta if it exists - if ( - session_message.metadata is not None - and isinstance( - session_message.metadata, - ServerMessageMetadata, - ) - and session_message.metadata.related_request_id is not None - ): - target_request_id = str(session_message.metadata.related_request_id) - - request_stream_id = target_request_id if target_request_id is not None else GET_STREAM_KEY - - # Store the event if we have an event store, - # regardless of whether a client is connected - # messages will be replayed on the re-connect - event_id = None - if self._event_store: - 
event_id = await self._event_store.store_event(request_stream_id, message) - logger.debug(f"Stored {event_id} from {request_stream_id}") - - if request_stream_id in self._request_streams: - try: - # Send both the message and the event ID - await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except ( - anyio.BrokenResourceError, - anyio.ClosedResourceError, - ): - # Stream might be closed, remove from registry - self._request_streams.pop(request_stream_id, None) - else: - logging.debug( - f"""Request stream {request_stream_id} not found - for message. Still processing message as the client - might reconnect and replay.""" - ) - except Exception: - logger.exception("Error in message router") - - # Start the message router - tg.start_soon(message_router) - - try: - # Yield the streams for the caller to use - yield read_stream, write_stream - finally: - for stream_id in list(self._request_streams.keys()): - await self._clean_up_memory_streams(stream_id) - self._request_streams.clear() - - # Clean up the read and write streams - try: - await read_stream_writer.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() - await write_stream.aclose() - except Exception as e: - # During cleanup, we catch all exceptions since streams might be in various states - logger.debug(f"Error closing streams: {e}") diff --git a/src/mcp/mcp/server/streamable_http_manager.py b/src/mcp/mcp/server/streamable_http_manager.py deleted file mode 100644 index 53d542d21..000000000 --- a/src/mcp/mcp/server/streamable_http_manager.py +++ /dev/null @@ -1,279 +0,0 @@ -"""StreamableHTTP Session Manager for MCP servers.""" - -from __future__ import annotations - -import contextlib -import logging -from collections.abc import AsyncIterator -from http import HTTPStatus -from typing import Any -from uuid import uuid4 - -import anyio -from anyio.abc import TaskStatus -from starlette.requests import Request -from starlette.responses import Response 
-from starlette.types import Receive, Scope, Send - -from mcp.server.lowlevel.server import Server as MCPServer -from mcp.server.streamable_http import ( - MCP_SESSION_ID_HEADER, - EventStore, - StreamableHTTPServerTransport, -) -from mcp.server.transport_security import TransportSecuritySettings - -logger = logging.getLogger(__name__) - - -class StreamableHTTPSessionManager: - """ - Manages StreamableHTTP sessions with optional resumability via event store. - - This class abstracts away the complexity of session management, event storage, - and request handling for StreamableHTTP transports. It handles: - - 1. Session tracking for clients - 2. Resumability via an optional event store - 3. Connection management and lifecycle - 4. Request handling and transport setup - - Important: Only one StreamableHTTPSessionManager instance should be created - per application. The instance cannot be reused after its run() context has - completed. If you need to restart the manager, create a new instance. - - Args: - app: The MCP server instance - event_store: Optional event store for resumability support. - If provided, enables resumable connections where clients - can reconnect and receive missed events. - If None, sessions are still tracked but not resumable. - json_response: Whether to use JSON responses instead of SSE streams - stateless: If True, creates a completely fresh transport for each request - with no session tracking or state persistence between requests. 
- """ - - def __init__( - self, - app: MCPServer[Any, Any], - event_store: EventStore | None = None, - json_response: bool = False, - stateless: bool = False, - security_settings: TransportSecuritySettings | None = None, - ): - self.app = app - self.event_store = event_store - self.json_response = json_response - self.stateless = stateless - self.security_settings = security_settings - - # Session tracking (only used if not stateless) - self._session_creation_lock = anyio.Lock() - self._server_instances: dict[str, StreamableHTTPServerTransport] = {} - - # The task group will be set during lifespan - self._task_group = None - # Thread-safe tracking of run() calls - self._run_lock = anyio.Lock() - self._has_started = False - - @contextlib.asynccontextmanager - async def run(self) -> AsyncIterator[None]: - """ - Run the session manager with proper lifecycle management. - - This creates and manages the task group for all session operations. - - Important: This method can only be called once per instance. The same - StreamableHTTPSessionManager instance cannot be reused after this - context manager exits. Create a new instance if you need to restart. - - Use this in the lifespan context manager of your Starlette app: - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette) -> AsyncIterator[None]: - async with session_manager.run(): - yield - """ - # Thread-safe check to ensure run() is only called once - async with self._run_lock: - if self._has_started: - raise RuntimeError( - "StreamableHTTPSessionManager .run() can only be called " - "once per instance. Create a new instance if you need to run again." 
- ) - self._has_started = True - - async with anyio.create_task_group() as tg: - # Store the task group for later use - self._task_group = tg - logger.info("StreamableHTTP session manager started") - try: - yield # Let the application run - finally: - logger.info("StreamableHTTP session manager shutting down") - # Cancel task group to stop all spawned tasks - tg.cancel_scope.cancel() - self._task_group = None - # Clear any remaining server instances - self._server_instances.clear() - - async def handle_request( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - """ - Process ASGI request with proper session handling and transport setup. - - Dispatches to the appropriate handler based on stateless mode. - - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ - if self._task_group is None: - raise RuntimeError("Task group is not initialized. Make sure to use run().") - - # Dispatch to the appropriate handler - if self.stateless: - await self._handle_stateless_request(scope, receive, send) - else: - await self._handle_stateful_request(scope, receive, send) - - async def _handle_stateless_request( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - """ - Process request in stateless mode - creating a new transport for each request. 
- - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ - logger.debug("Stateless mode: Creating new transport for this request") - # No session ID needed in stateless mode - http_transport = StreamableHTTPServerTransport( - mcp_session_id=None, # No session tracking in stateless mode - is_json_response_enabled=self.json_response, - event_store=None, # No event store in stateless mode - security_settings=self.security_settings, - ) - - # Start server in a new task - async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED): - async with http_transport.connect() as streams: - read_stream, write_stream = streams - task_status.started() - try: - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=True, - ) - except Exception: - logger.exception("Stateless session crashed") - - # Assert task group is not None for type checking - assert self._task_group is not None - # Start the server task - await self._task_group.start(run_stateless_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - - # Terminate the transport after the request is handled - await http_transport.terminate() - - async def _handle_stateful_request( - self, - scope: Scope, - receive: Receive, - send: Send, - ) -> None: - """ - Process request in stateful mode - maintaining session state between requests. 
- - Args: - scope: ASGI scope - receive: ASGI receive function - send: ASGI send function - """ - request = Request(scope, receive) - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) - - # Existing session case - if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances: - transport = self._server_instances[request_mcp_session_id] - logger.debug("Session already exists, handling request directly") - await transport.handle_request(scope, receive, send) - return - - if request_mcp_session_id is None: - # New session case - logger.debug("Creating new transport") - async with self._session_creation_lock: - new_session_id = uuid4().hex - http_transport = StreamableHTTPServerTransport( - mcp_session_id=new_session_id, - is_json_response_enabled=self.json_response, - event_store=self.event_store, # May be None (no resumability) - security_settings=self.security_settings, - ) - - assert http_transport.mcp_session_id is not None - self._server_instances[http_transport.mcp_session_id] = http_transport - logger.info(f"Created new transport with session ID: {new_session_id}") - - # Define the server runner - async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: - async with http_transport.connect() as streams: - read_stream, write_stream = streams - task_status.started() - try: - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=False, # Stateful mode - ) - except Exception as e: - logger.error( - f"Session {http_transport.mcp_session_id} crashed: {e}", - exc_info=True, - ) - finally: - # Only remove from instances if not terminated - if ( - http_transport.mcp_session_id - and http_transport.mcp_session_id in self._server_instances - and not http_transport.is_terminated - ): - logger.info( - "Cleaning up crashed session " - f"{http_transport.mcp_session_id} from " - "active instances." 
- ) - del self._server_instances[http_transport.mcp_session_id] - - # Assert task group is not None for type checking - assert self._task_group is not None - # Start the server task - await self._task_group.start(run_server) - - # Handle the HTTP request and return the response - await http_transport.handle_request(scope, receive, send) - else: - # Invalid session ID - response = Response( - "Bad Request: No valid session ID provided", - status_code=HTTPStatus.BAD_REQUEST, - ) - await response(scope, receive, send) diff --git a/src/mcp/mcp/server/streaming_asgi_transport.py b/src/mcp/mcp/server/streaming_asgi_transport.py deleted file mode 100644 index a74751312..000000000 --- a/src/mcp/mcp/server/streaming_asgi_transport.py +++ /dev/null @@ -1,203 +0,0 @@ -""" -A modified version of httpx.ASGITransport that supports streaming responses. - -This transport runs the ASGI app as a separate anyio task, allowing it to -handle streaming responses like SSE where the app doesn't terminate until -the connection is closed. - -This is only intended for writing tests for the SSE transport. -""" - -import typing -from typing import Any, cast - -import anyio -import anyio.abc -import anyio.streams.memory -from httpx._models import Request, Response -from httpx._transports.base import AsyncBaseTransport -from httpx._types import AsyncByteStream -from starlette.types import ASGIApp, Receive, Scope, Send - - -class StreamingASGITransport(AsyncBaseTransport): - """ - A custom AsyncTransport that handles sending requests directly to an ASGI app - and supports streaming responses like SSE. - - Unlike the standard ASGITransport, this transport runs the ASGI app in a - separate anyio task, allowing it to handle responses from apps that don't - terminate immediately (like SSE endpoints). - - Arguments: - - * `app` - The ASGI application. - * `raise_app_exceptions` - Boolean indicating if exceptions in the application - should be raised. Default to `True`. 
Can be set to `False` for use cases - such as testing the content of a client 500 response. - * `root_path` - The root path on which the ASGI application should be mounted. - * `client` - A two-tuple indicating the client IP and port of incoming requests. - * `response_timeout` - Timeout in seconds to wait for the initial response. - Default is 10 seconds. - - TODO: https://github.com/encode/httpx/pull/3059 is adding something similar to - upstream httpx. When that merges, we should delete this & switch back to the - upstream implementation. - """ - - def __init__( - self, - app: ASGIApp, - task_group: anyio.abc.TaskGroup, - raise_app_exceptions: bool = True, - root_path: str = "", - client: tuple[str, int] = ("127.0.0.1", 123), - ) -> None: - self.app = app - self.raise_app_exceptions = raise_app_exceptions - self.root_path = root_path - self.client = client - self.task_group = task_group - - async def handle_async_request( - self, - request: Request, - ) -> Response: - assert isinstance(request.stream, AsyncByteStream) - - # ASGI scope. 
- scope = { - "type": "http", - "asgi": {"version": "3.0"}, - "http_version": "1.1", - "method": request.method, - "headers": [(k.lower(), v) for (k, v) in request.headers.raw], - "scheme": request.url.scheme, - "path": request.url.path, - "raw_path": request.url.raw_path.split(b"?")[0], - "query_string": request.url.query, - "server": (request.url.host, request.url.port), - "client": self.client, - "root_path": self.root_path, - } - - # Request body - request_body_chunks = request.stream.__aiter__() - request_complete = False - - # Response state - status_code = 499 - response_headers = None - response_started = False - response_complete = anyio.Event() - initial_response_ready = anyio.Event() - - # Synchronization for streaming response - asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[dict[str, Any]](100) - content_send_channel, content_receive_channel = anyio.create_memory_object_stream[bytes](100) - - # ASGI callables. - async def receive() -> dict[str, Any]: - nonlocal request_complete - - if request_complete: - await response_complete.wait() - return {"type": "http.disconnect"} - - try: - body = await request_body_chunks.__anext__() - except StopAsyncIteration: - request_complete = True - return {"type": "http.request", "body": b"", "more_body": False} - return {"type": "http.request", "body": body, "more_body": True} - - async def send(message: dict[str, Any]) -> None: - nonlocal status_code, response_headers, response_started - - await asgi_send_channel.send(message) - - # Start the ASGI application in a separate task - async def run_app() -> None: - try: - # Cast the receive and send functions to the ASGI types - await self.app(cast(Scope, scope), cast(Receive, receive), cast(Send, send)) - except Exception: - if self.raise_app_exceptions: - raise - - if not response_started: - await asgi_send_channel.send({"type": "http.response.start", "status": 500, "headers": []}) - - await asgi_send_channel.send({"type": 
"http.response.body", "body": b"", "more_body": False}) - finally: - await asgi_send_channel.aclose() - - # Process messages from the ASGI app - async def process_messages() -> None: - nonlocal status_code, response_headers, response_started - - try: - async with asgi_receive_channel: - async for message in asgi_receive_channel: - if message["type"] == "http.response.start": - assert not response_started - status_code = message["status"] - response_headers = message.get("headers", []) - response_started = True - - # As soon as we have headers, we can return a response - initial_response_ready.set() - - elif message["type"] == "http.response.body": - body = message.get("body", b"") - more_body = message.get("more_body", False) - - if body and request.method != "HEAD": - await content_send_channel.send(body) - - if not more_body: - response_complete.set() - await content_send_channel.aclose() - break - finally: - # Ensure events are set even if there's an error - initial_response_ready.set() - response_complete.set() - await content_send_channel.aclose() - - # Create tasks for running the app and processing messages - self.task_group.start_soon(run_app) - self.task_group.start_soon(process_messages) - - # Wait for the initial response or timeout - await initial_response_ready.wait() - - # Create a streaming response - return Response( - status_code, - headers=response_headers, - stream=StreamingASGIResponseStream(content_receive_channel), - ) - - -class StreamingASGIResponseStream(AsyncByteStream): - """ - A modified ASGIResponseStream that supports streaming responses. - - This class extends the standard ASGIResponseStream to handle cases where - the response body continues to be generated after the initial response - is returned. 
- """ - - def __init__( - self, - receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], - ) -> None: - self.receive_channel = receive_channel - - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - try: - async for chunk in self.receive_channel: - yield chunk - finally: - await self.receive_channel.aclose() diff --git a/src/mcp/mcp/server/transport_security.py b/src/mcp/mcp/server/transport_security.py deleted file mode 100644 index 3a884ee2b..000000000 --- a/src/mcp/mcp/server/transport_security.py +++ /dev/null @@ -1,127 +0,0 @@ -"""DNS rebinding protection for MCP server transports.""" - -import logging - -from pydantic import BaseModel, Field -from starlette.requests import Request -from starlette.responses import Response - -logger = logging.getLogger(__name__) - - -class TransportSecuritySettings(BaseModel): - """Settings for MCP transport security features. - - These settings help protect against DNS rebinding attacks by validating - incoming request headers. - """ - - enable_dns_rebinding_protection: bool = Field( - default=True, - description="Enable DNS rebinding protection (recommended for production)", - ) - - allowed_hosts: list[str] = Field( - default=[], - description="List of allowed Host header values. Only applies when " - + "enable_dns_rebinding_protection is True.", - ) - - allowed_origins: list[str] = Field( - default=[], - description="List of allowed Origin header values. 
Only applies when " - + "enable_dns_rebinding_protection is True.", - ) - - -class TransportSecurityMiddleware: - """Middleware to enforce DNS rebinding protection for MCP transport endpoints.""" - - def __init__(self, settings: TransportSecuritySettings | None = None): - # If not specified, disable DNS rebinding protection by default - # for backwards compatibility - self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - - def _validate_host(self, host: str | None) -> bool: - """Validate the Host header against allowed values.""" - if not host: - logger.warning("Missing Host header in request") - return False - - # Check exact match first - if host in self.settings.allowed_hosts: - return True - - # Check wildcard port patterns - for allowed in self.settings.allowed_hosts: - if allowed.endswith(":*"): - # Extract base host from pattern - base_host = allowed[:-2] - # Check if the actual host starts with base host and has a port - if host.startswith(base_host + ":"): - return True - - logger.warning(f"Invalid Host header: {host}") - return False - - def _validate_origin(self, origin: str | None) -> bool: - """Validate the Origin header against allowed values.""" - # Origin can be absent for same-origin requests - if not origin: - return True - - # Check exact match first - if origin in self.settings.allowed_origins: - return True - - # Check wildcard port patterns - for allowed in self.settings.allowed_origins: - if allowed.endswith(":*"): - # Extract base origin from pattern - base_origin = allowed[:-2] - # Check if the actual origin starts with base origin and has a port - if origin.startswith(base_origin + ":"): - return True - - logger.warning(f"Invalid Origin header: {origin}") - return False - - def _validate_content_type(self, content_type: str | None) -> bool: - """Validate the Content-Type header for POST requests.""" - if not content_type: - logger.warning("Missing Content-Type header in POST request") - return False 
- - # Content-Type must start with application/json - if not content_type.lower().startswith("application/json"): - logger.warning(f"Invalid Content-Type header: {content_type}") - return False - - return True - - async def validate_request(self, request: Request, is_post: bool = False) -> Response | None: - """Validate request headers for DNS rebinding protection. - - Returns None if validation passes, or an error Response if validation fails. - """ - # Always validate Content-Type for POST requests - if is_post: - content_type = request.headers.get("content-type") - if not self._validate_content_type(content_type): - return Response("Invalid Content-Type header", status_code=400) - - # Skip remaining validation if DNS rebinding protection is disabled - if not self.settings.enable_dns_rebinding_protection: - return None - - # Validate Host header - host = request.headers.get("host") - if not self._validate_host(host): - return Response("Invalid Host header", status_code=421) - - # Validate Origin header - origin = request.headers.get("origin") - if not self._validate_origin(origin): - return Response("Invalid Origin header", status_code=400) - - return None diff --git a/src/mcp/mcp/server/websocket.py b/src/mcp/mcp/server/websocket.py deleted file mode 100644 index 7c0d8789c..000000000 --- a/src/mcp/mcp/server/websocket.py +++ /dev/null @@ -1,62 +0,0 @@ -import logging -from contextlib import asynccontextmanager - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic_core import ValidationError -from starlette.types import Receive, Scope, Send -from starlette.websockets import WebSocket - -import mcp.types as types -from mcp.shared.message import SessionMessage - -logger = logging.getLogger(__name__) - - -@asynccontextmanager -async def websocket_server(scope: Scope, receive: Receive, send: Send): - """ - WebSocket server transport for MCP. 
This is an ASGI application, suitable to be - used with a framework like Starlette and a server like Hypercorn. - """ - - websocket = WebSocket(scope, receive, send) - await websocket.accept(subprotocol="mcp") - - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] - read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception] - - write_stream: MemoryObjectSendStream[SessionMessage] - write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async def ws_reader(): - try: - async with read_stream_writer: - async for msg in websocket.iter_text(): - try: - client_message = types.JSONRPCMessage.model_validate_json(msg) - except ValidationError as exc: - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage(client_message) - await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: - await websocket.close() - - async def ws_writer(): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - obj = session_message.message.model_dump_json(by_alias=True, exclude_none=True) - await websocket.send_text(obj) - except anyio.ClosedResourceError: - await websocket.close() - - async with anyio.create_task_group() as tg: - tg.start_soon(ws_reader) - tg.start_soon(ws_writer) - yield (read_stream, write_stream) diff --git a/src/mcp/mcp/shared/__init__.py b/src/mcp/mcp/shared/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/mcp/mcp/shared/_httpx_utils.py b/src/mcp/mcp/shared/_httpx_utils.py deleted file mode 100644 index e0611ce73..000000000 --- a/src/mcp/mcp/shared/_httpx_utils.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Utilities for creating standardized httpx AsyncClient instances.""" - -from typing import Any, Protocol - -import httpx - -__all__ = ["create_mcp_http_client"] - - -class 
McpHttpClientFactory(Protocol): - def __call__( - self, - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, - auth: httpx.Auth | None = None, - ) -> httpx.AsyncClient: ... - - -def create_mcp_http_client( - headers: dict[str, str] | None = None, - timeout: httpx.Timeout | None = None, - auth: httpx.Auth | None = None, -) -> httpx.AsyncClient: - """Create a standardized httpx AsyncClient with MCP defaults. - - This function provides common defaults used throughout the MCP codebase: - - follow_redirects=True (always enabled) - - Default timeout of 30 seconds if not specified - - Args: - headers: Optional headers to include with all requests. - timeout: Request timeout as httpx.Timeout object. - Defaults to 30 seconds if not specified. - auth: Optional authentication handler. - - Returns: - Configured httpx.AsyncClient instance with MCP defaults. - - Note: - The returned AsyncClient must be used as a context manager to ensure - proper cleanup of connections. - - Examples: - # Basic usage with MCP defaults - async with create_mcp_http_client() as client: - response = await client.get("https://api.example.com") - - # With custom headers - headers = {"Authorization": "Bearer token"} - async with create_mcp_http_client(headers) as client: - response = await client.get("/endpoint") - - # With both custom headers and timeout - timeout = httpx.Timeout(60.0, read=300.0) - async with create_mcp_http_client(headers, timeout) as client: - response = await client.get("/long-request") - - # With authentication - from httpx import BasicAuth - auth = BasicAuth(username="user", password="pass") - async with create_mcp_http_client(headers, timeout, auth) as client: - response = await client.get("/protected-endpoint") - """ - # Set MCP defaults - kwargs: dict[str, Any] = { - "follow_redirects": True, - } - - # Handle timeout - if timeout is None: - kwargs["timeout"] = httpx.Timeout(30.0) - else: - kwargs["timeout"] = timeout - - # Handle headers - if 
headers is not None: - kwargs["headers"] = headers - - # Handle authentication - if auth is not None: - kwargs["auth"] = auth - - return httpx.AsyncClient(**kwargs) diff --git a/src/mcp/mcp/shared/auth.py b/src/mcp/mcp/shared/auth.py deleted file mode 100644 index 6bf15b531..000000000 --- a/src/mcp/mcp/shared/auth.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Any, Literal - -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, field_validator - - -class OAuthToken(BaseModel): - """ - See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 - """ - - access_token: str - token_type: Literal["Bearer"] = "Bearer" - expires_in: int | None = None - scope: str | None = None - refresh_token: str | None = None - - @field_validator("token_type", mode="before") - @classmethod - def normalize_token_type(cls, v: str | None) -> str | None: - if isinstance(v, str): - # Bearer is title-cased in the spec, so we normalize it - # https://datatracker.ietf.org/doc/html/rfc6750#section-4 - return v.title() - return v - - -class InvalidScopeError(Exception): - def __init__(self, message: str): - self.message = message - - -class InvalidRedirectUriError(Exception): - def __init__(self, message: str): - self.message = message - - -class OAuthClientMetadata(BaseModel): - """ - RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. - See https://datatracker.ietf.org/doc/html/rfc7591#section-2 - for the full specification. 
- """ - - redirect_uris: list[AnyUrl] = Field(..., min_length=1) - # token_endpoint_auth_method: this implementation only supports none & - # client_secret_post; - # ie: we do not support client_secret_basic - token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"]] = [ - "authorization_code", - "refresh_token", - ] - # this implementation only supports code; ie: it does not support implicit grants - response_types: list[Literal["code"]] = ["code"] - scope: str | None = None - - # these fields are currently unused, but we support & store them for potential - # future use - client_name: str | None = None - client_uri: AnyHttpUrl | None = None - logo_uri: AnyHttpUrl | None = None - contacts: list[str] | None = None - tos_uri: AnyHttpUrl | None = None - policy_uri: AnyHttpUrl | None = None - jwks_uri: AnyHttpUrl | None = None - jwks: Any | None = None - software_id: str | None = None - software_version: str | None = None - - def validate_scope(self, requested_scope: str | None) -> list[str] | None: - if requested_scope is None: - return None - requested_scopes = requested_scope.split(" ") - allowed_scopes = [] if self.scope is None else self.scope.split(" ") - for scope in requested_scopes: - if scope not in allowed_scopes: - raise InvalidScopeError(f"Client was not registered with scope {scope}") - return requested_scopes - - def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: - if redirect_uri is not None: - # Validate redirect_uri against client's registered redirect URIs - if redirect_uri not in self.redirect_uris: - raise InvalidRedirectUriError(f"Redirect URI '{redirect_uri}' not registered for client") - return redirect_uri - elif len(self.redirect_uris) == 1: - return self.redirect_uris[0] - else: - raise InvalidRedirectUriError("redirect_uri must be 
specified when client has multiple registered URIs") - - -class OAuthClientInformationFull(OAuthClientMetadata): - """ - RFC 7591 OAuth 2.0 Dynamic Client Registration full response - (client information plus metadata). - """ - - client_id: str - client_secret: str | None = None - client_id_issued_at: int | None = None - client_secret_expires_at: int | None = None - - -class OAuthMetadata(BaseModel): - """ - RFC 8414 OAuth 2.0 Authorization Server Metadata. - See https://datatracker.ietf.org/doc/html/rfc8414#section-2 - """ - - issuer: AnyHttpUrl - authorization_endpoint: AnyHttpUrl - token_endpoint: AnyHttpUrl - registration_endpoint: AnyHttpUrl | None = None - scopes_supported: list[str] | None = None - response_types_supported: list[str] = ["code"] - response_modes_supported: list[str] | None = None - grant_types_supported: list[str] | None = None - token_endpoint_auth_methods_supported: list[str] | None = None - token_endpoint_auth_signing_alg_values_supported: list[str] | None = None - service_documentation: AnyHttpUrl | None = None - ui_locales_supported: list[str] | None = None - op_policy_uri: AnyHttpUrl | None = None - op_tos_uri: AnyHttpUrl | None = None - revocation_endpoint: AnyHttpUrl | None = None - revocation_endpoint_auth_methods_supported: list[str] | None = None - revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None - introspection_endpoint: AnyHttpUrl | None = None - introspection_endpoint_auth_methods_supported: list[str] | None = None - introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None - code_challenge_methods_supported: list[str] | None = None - - -class ProtectedResourceMetadata(BaseModel): - """ - RFC 9728 OAuth 2.0 Protected Resource Metadata. 
- See https://datatracker.ietf.org/doc/html/rfc9728#section-2 - """ - - resource: AnyHttpUrl - authorization_servers: list[AnyHttpUrl] = Field(..., min_length=1) - jwks_uri: AnyHttpUrl | None = None - scopes_supported: list[str] | None = None - bearer_methods_supported: list[str] | None = Field(default=["header"]) # MCP only supports header method - resource_signing_alg_values_supported: list[str] | None = None - resource_name: str | None = None - resource_documentation: AnyHttpUrl | None = None - resource_policy_uri: AnyHttpUrl | None = None - resource_tos_uri: AnyHttpUrl | None = None - # tls_client_certificate_bound_access_tokens default is False, but ommited here for clarity - tls_client_certificate_bound_access_tokens: bool | None = None - authorization_details_types_supported: list[str] | None = None - dpop_signing_alg_values_supported: list[str] | None = None - # dpop_bound_access_tokens_required default is False, but ommited here for clarity - dpop_bound_access_tokens_required: bool | None = None diff --git a/src/mcp/mcp/shared/auth_utils.py b/src/mcp/mcp/shared/auth_utils.py deleted file mode 100644 index 6d6300c9c..000000000 --- a/src/mcp/mcp/shared/auth_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707).""" - -from urllib.parse import urlparse, urlsplit, urlunsplit - -from pydantic import AnyUrl, HttpUrl - - -def resource_url_from_server_url(url: str | HttpUrl | AnyUrl) -> str: - """Convert server URL to canonical resource URL per RFC 8707. - - RFC 8707 section 2 states that resource URIs "MUST NOT include a fragment component". - Returns absolute URI with lowercase scheme/host for canonical form. 
- - Args: - url: Server URL to convert - - Returns: - Canonical resource URL string - """ - # Convert to string if needed - url_str = str(url) - - # Parse the URL and remove fragment, create canonical form - parsed = urlsplit(url_str) - canonical = urlunsplit(parsed._replace(scheme=parsed.scheme.lower(), netloc=parsed.netloc.lower(), fragment="")) - - return canonical - - -def check_resource_allowed(requested_resource: str, configured_resource: str) -> bool: - """Check if a requested resource URL matches a configured resource URL. - - A requested resource matches if it has the same scheme, domain, port, - and its path starts with the configured resource's path. This allows - hierarchical matching where a token for a parent resource can be used - for child resources. - - Args: - requested_resource: The resource URL being requested - configured_resource: The resource URL that has been configured - - Returns: - True if the requested resource matches the configured resource - """ - # Parse both URLs - requested = urlparse(requested_resource) - configured = urlparse(configured_resource) - - # Compare scheme, host, and port (origin) - if requested.scheme.lower() != configured.scheme.lower() or requested.netloc.lower() != configured.netloc.lower(): - return False - - # Handle cases like requested=/foo and configured=/foo/ - requested_path = requested.path - configured_path = configured.path - - # If requested path is shorter, it cannot be a child - if len(requested_path) < len(configured_path): - return False - - # Check if the requested path starts with the configured path - # Ensure both paths end with / for proper comparison - # This ensures that paths like "/api123" don't incorrectly match "/api" - if not requested_path.endswith("/"): - requested_path += "/" - if not configured_path.endswith("/"): - configured_path += "/" - - return requested_path.startswith(configured_path) diff --git a/src/mcp/mcp/shared/context.py b/src/mcp/mcp/shared/context.py deleted file mode 
100644 index f3006e7d5..000000000 --- a/src/mcp/mcp/shared/context.py +++ /dev/null @@ -1,20 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Generic - -from typing_extensions import TypeVar - -from mcp.shared.session import BaseSession -from mcp.types import RequestId, RequestParams - -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) -LifespanContextT = TypeVar("LifespanContextT") -RequestT = TypeVar("RequestT", default=Any) - - -@dataclass -class RequestContext(Generic[SessionT, LifespanContextT, RequestT]): - request_id: RequestId - meta: RequestParams.Meta | None - session: SessionT - lifespan_context: LifespanContextT - request: RequestT | None = None diff --git a/src/mcp/mcp/shared/exceptions.py b/src/mcp/mcp/shared/exceptions.py deleted file mode 100644 index 97a1c09a9..000000000 --- a/src/mcp/mcp/shared/exceptions.py +++ /dev/null @@ -1,14 +0,0 @@ -from mcp.types import ErrorData - - -class McpError(Exception): - """ - Exception type raised when an error arrives over an MCP connection. 
- """ - - error: ErrorData - - def __init__(self, error: ErrorData): - """Initialize McpError.""" - super().__init__(error.message) - self.error = error diff --git a/src/mcp/mcp/shared/memory.py b/src/mcp/mcp/shared/memory.py deleted file mode 100644 index c94e5e6ac..000000000 --- a/src/mcp/mcp/shared/memory.py +++ /dev/null @@ -1,99 +0,0 @@ -""" -In-memory transports -""" - -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from datetime import timedelta -from typing import Any - -import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream - -import mcp.types as types -from mcp.client.session import ( - ClientSession, - ElicitationFnT, - ListRootsFnT, - LoggingFnT, - MessageHandlerFnT, - SamplingFnT, -) -from mcp.server import Server -from mcp.shared.message import SessionMessage - -MessageStream = tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]] - - -@asynccontextmanager -async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageStream, MessageStream], None]: - """ - Creates a pair of bidirectional memory streams for client-server communication. 
- - Returns: - A tuple of (client_streams, server_streams) where each is a tuple of - (read_stream, write_stream) - """ - # Create streams for both directions - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - client_streams = (server_to_client_receive, client_to_server_send) - server_streams = (client_to_server_receive, server_to_client_send) - - async with ( - server_to_client_receive, - client_to_server_send, - client_to_server_receive, - server_to_client_send, - ): - yield client_streams, server_streams - - -@asynccontextmanager -async def create_connected_server_and_client_session( - server: Server[Any], - read_timeout_seconds: timedelta | None = None, - sampling_callback: SamplingFnT | None = None, - list_roots_callback: ListRootsFnT | None = None, - logging_callback: LoggingFnT | None = None, - message_handler: MessageHandlerFnT | None = None, - client_info: types.Implementation | None = None, - raise_exceptions: bool = False, - elicitation_callback: ElicitationFnT | None = None, -) -> AsyncGenerator[ClientSession, None]: - """Creates a ClientSession that is connected to a running MCP server.""" - async with create_client_server_memory_streams() as ( - client_streams, - server_streams, - ): - client_read, client_write = client_streams - server_read, server_write = server_streams - - # Create a cancel scope for the server task - async with anyio.create_task_group() as tg: - tg.start_soon( - lambda: server.run( - server_read, - server_write, - server.create_initialization_options(), - raise_exceptions=raise_exceptions, - ) - ) - - try: - async with ClientSession( - read_stream=client_read, - write_stream=client_write, - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - list_roots_callback=list_roots_callback, - 
logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - elicitation_callback=elicitation_callback, - ) as client_session: - await client_session.initialize() - yield client_session - finally: - tg.cancel_scope.cancel() diff --git a/src/mcp/mcp/shared/message.py b/src/mcp/mcp/shared/message.py deleted file mode 100644 index 4b6df23eb..000000000 --- a/src/mcp/mcp/shared/message.py +++ /dev/null @@ -1,43 +0,0 @@ -""" -Message wrapper with metadata support. - -This module defines a wrapper type that combines JSONRPCMessage with metadata -to support transport-specific features like resumability. -""" - -from collections.abc import Awaitable, Callable -from dataclasses import dataclass - -from mcp.types import JSONRPCMessage, RequestId - -ResumptionToken = str - -ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]] - - -@dataclass -class ClientMessageMetadata: - """Metadata specific to client messages.""" - - resumption_token: ResumptionToken | None = None - on_resumption_token_update: Callable[[ResumptionToken], Awaitable[None]] | None = None - - -@dataclass -class ServerMessageMetadata: - """Metadata specific to server messages.""" - - related_request_id: RequestId | None = None - # Request-specific context (e.g., headers, auth info) - request_context: object | None = None - - -MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None - - -@dataclass -class SessionMessage: - """A message with specific metadata for transport-specific features.""" - - message: JSONRPCMessage - metadata: MessageMetadata = None diff --git a/src/mcp/mcp/shared/metadata_utils.py b/src/mcp/mcp/shared/metadata_utils.py deleted file mode 100644 index e3f49daf4..000000000 --- a/src/mcp/mcp/shared/metadata_utils.py +++ /dev/null @@ -1,45 +0,0 @@ -"""Utility functions for working with metadata in MCP types. 
- -These utilities are primarily intended for client-side usage to properly display -human-readable names in user interfaces in a spec compliant way. -""" - -from mcp.types import Implementation, Prompt, Resource, ResourceTemplate, Tool - - -def get_display_name(obj: Tool | Resource | Prompt | ResourceTemplate | Implementation) -> str: - """ - Get the display name for an MCP object with proper precedence. - - This is a client-side utility function designed to help MCP clients display - human-readable names in their user interfaces. When servers provide a 'title' - field, it should be preferred over the programmatic 'name' field for display. - - For tools: title > annotations.title > name - For other objects: title > name - - Example: - # In a client displaying available tools - tools = await session.list_tools() - for tool in tools.tools: - display_name = get_display_name(tool) - print(f"Available tool: {display_name}") - - Args: - obj: An MCP object with name and optional title fields - - Returns: - The display name to use for UI presentation - """ - if isinstance(obj, Tool): - # Tools have special precedence: title > annotations.title > name - if hasattr(obj, "title") and obj.title is not None: - return obj.title - if obj.annotations and hasattr(obj.annotations, "title") and obj.annotations.title is not None: - return obj.annotations.title - return obj.name - else: - # All other objects: title > name - if hasattr(obj, "title") and obj.title is not None: - return obj.title - return obj.name diff --git a/src/mcp/mcp/shared/progress.py b/src/mcp/mcp/shared/progress.py deleted file mode 100644 index 1ad81a779..000000000 --- a/src/mcp/mcp/shared/progress.py +++ /dev/null @@ -1,58 +0,0 @@ -from collections.abc import Generator -from contextlib import contextmanager -from dataclasses import dataclass, field -from typing import Generic - -from pydantic import BaseModel - -from mcp.shared.context import LifespanContextT, RequestContext -from mcp.shared.session import ( - 
BaseSession, - ReceiveNotificationT, - ReceiveRequestT, - SendNotificationT, - SendRequestT, - SendResultT, -) -from mcp.types import ProgressToken - - -class Progress(BaseModel): - progress: float - total: float | None - - -@dataclass -class ProgressContext(Generic[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT]): - session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT] - progress_token: ProgressToken - total: float | None - current: float = field(default=0.0, init=False) - - async def progress(self, amount: float, message: str | None = None) -> None: - self.current += amount - - await self.session.send_progress_notification( - self.progress_token, self.current, total=self.total, message=message - ) - - -@contextmanager -def progress( - ctx: RequestContext[ - BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - LifespanContextT, - ], - total: float | None = None, -) -> Generator[ - ProgressContext[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - None, -]: - if ctx.meta is None or ctx.meta.progressToken is None: - raise ValueError("No progress token provided") - - progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total) - try: - yield progress_ctx - finally: - pass diff --git a/src/mcp/mcp/shared/session.py b/src/mcp/mcp/shared/session.py deleted file mode 100644 index b2f49fc8b..000000000 --- a/src/mcp/mcp/shared/session.py +++ /dev/null @@ -1,470 +0,0 @@ -import logging -from collections.abc import Callable -from contextlib import AsyncExitStack -from datetime import timedelta -from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar - -import anyio -import httpx -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from pydantic import BaseModel -from typing_extensions import Self - -from mcp.shared.exceptions import 
McpError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage -from mcp.types import ( - CONNECTION_CLOSED, - INVALID_PARAMS, - CancelledNotification, - ClientNotification, - ClientRequest, - ClientResult, - ErrorData, - JSONRPCError, - JSONRPCMessage, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - ProgressNotification, - RequestParams, - ServerNotification, - ServerRequest, - ServerResult, -) - -SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) -SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) -SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) -ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) -ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) -ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) - -RequestId = str | int - - -class ProgressFnT(Protocol): - """Protocol for progress notification callbacks.""" - - async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... - - -class RequestResponder(Generic[ReceiveRequestT, SendResultT]): - """Handles responding to MCP requests and manages request lifecycle. - - This class MUST be used as a context manager to ensure proper cleanup and - cancellation handling: - - Example: - with request_responder as resp: - await resp.respond(result) - - The context manager ensures: - 1. Proper cancellation scope setup and cleanup - 2. Request completion tracking - 3. 
Cleanup of in-flight requests - """ - - def __init__( - self, - request_id: RequestId, - request_meta: RequestParams.Meta | None, - request: ReceiveRequestT, - session: """BaseSession[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT - ]""", - on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], - message_metadata: MessageMetadata = None, - ) -> None: - self.request_id = request_id - self.request_meta = request_meta - self.request = request - self.message_metadata = message_metadata - self._session = session - self._completed = False - self._cancel_scope = anyio.CancelScope() - self._on_complete = on_complete - self._entered = False # Track if we're in a context manager - - def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": - """Enter the context manager, enabling request cancellation tracking.""" - self._entered = True - self._cancel_scope = anyio.CancelScope() - self._cancel_scope.__enter__() - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - """Exit the context manager, performing cleanup and notifying completion.""" - try: - if self._completed: - self._on_complete(self) - finally: - self._entered = False - if not self._cancel_scope: - raise RuntimeError("No active cancel scope") - self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) - - async def respond(self, response: SendResultT | ErrorData) -> None: - """Send a response for this request. - - Must be called within a context manager block. 
- Raises: - RuntimeError: If not used within a context manager - AssertionError: If request was already responded to - """ - if not self._entered: - raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" - - if not self.cancelled: - self._completed = True - - await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, response=response - ) - - async def cancel(self) -> None: - """Cancel this request and mark it as completed.""" - if not self._entered: - raise RuntimeError("RequestResponder must be used as a context manager") - if not self._cancel_scope: - raise RuntimeError("No active cancel scope") - - self._cancel_scope.cancel() - self._completed = True # Mark as completed so it's removed from in_flight - # Send an error response to indicate cancellation - await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, - response=ErrorData(code=0, message="Request cancelled", data=None), - ) - - @property - def in_flight(self) -> bool: - return not self._completed and not self.cancelled - - @property - def cancelled(self) -> bool: - return self._cancel_scope.cancel_called - - -class BaseSession( - Generic[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], -): - """ - Implements an MCP "session" on top of read/write streams, including features - like request/response linking, notifications, and progress. - - This class is an async context manager that automatically starts processing - messages when entered. 
- """ - - _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] - _request_id: int - _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] - _progress_callbacks: dict[RequestId, ProgressFnT] - - def __init__( - self, - read_stream: MemoryObjectReceiveStream[SessionMessage | Exception], - write_stream: MemoryObjectSendStream[SessionMessage], - receive_request_type: type[ReceiveRequestT], - receive_notification_type: type[ReceiveNotificationT], - # If none, reading will never time out - read_timeout_seconds: timedelta | None = None, - ) -> None: - self._read_stream = read_stream - self._write_stream = write_stream - self._response_streams = {} - self._request_id = 0 - self._receive_request_type = receive_request_type - self._receive_notification_type = receive_notification_type - self._session_read_timeout_seconds = read_timeout_seconds - self._in_flight = {} - self._progress_callbacks = {} - self._exit_stack = AsyncExitStack() - - async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - self._task_group.start_soon(self._receive_loop) - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - await self._exit_stack.aclose() - # Using BaseSession as a context manager should not block on exit (this - # would be very surprising behavior), so make sure to cancel the tasks - # in the task group. - self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - - async def send_request( - self, - request: SendRequestT, - result_type: type[ReceiveResultT], - request_read_timeout_seconds: timedelta | None = None, - metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, - ) -> ReceiveResultT: - """ - Sends a request and wait for a response. 
Raises an McpError if the - response contains an error. If a request read timeout is provided, it - will take precedence over the session read timeout. - - Do not use this method to emit notifications! Use send_notification() - instead. - """ - request_id = self._request_id - self._request_id = request_id + 1 - - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = response_stream - - # Set up progress token if progress callback is provided - request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token - if "params" not in request_data: - request_data["params"] = {} - if "_meta" not in request_data["params"]: - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback - - try: - jsonrpc_request = JSONRPCRequest( - jsonrpc="2.0", - id=request_id, - **request_data, - ) - - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = None - if request_read_timeout_seconds is not None: - timeout = request_read_timeout_seconds.total_seconds() - elif self._session_read_timeout_seconds is not None: - timeout = self._session_read_timeout_seconds.total_seconds() - - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - raise McpError( - ErrorData( - code=httpx.codes.REQUEST_TIMEOUT, - message=( - f"Timed out while waiting for response to " - f"{request.__class__.__name__}. Waited " - f"{timeout} seconds." 
- ), - ) - ) - - if isinstance(response_or_error, JSONRPCError): - raise McpError(response_or_error.error) - else: - return result_type.model_validate(response_or_error.result) - - finally: - self._response_streams.pop(request_id, None) - self._progress_callbacks.pop(request_id, None) - await response_stream.aclose() - await response_stream_reader.aclose() - - async def send_notification( - self, - notification: SendNotificationT, - related_request_id: RequestId | None = None, - ) -> None: - """ - Emits a notification, which is a one-way message that does not expect - a response. - """ - # Some transport implementations may need to set the related_request_id - # to attribute to the notifications to the request that triggered them. - jsonrpc_notification = JSONRPCNotification( - jsonrpc="2.0", - **notification.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage( - message=JSONRPCMessage(jsonrpc_notification), - metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, - ) - await self._write_stream.send(session_message) - - async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: - if isinstance(response, ErrorData): - jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) - await self._write_stream.send(session_message) - else: - jsonrpc_response = JSONRPCResponse( - jsonrpc="2.0", - id=request_id, - result=response.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) - await self._write_stream.send(session_message) - - async def _receive_loop(self) -> None: - async with ( - self._read_stream, - self._write_stream, - ): - try: - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - elif 
isinstance(message.message.root, JSONRPCRequest): - try: - validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - responder = RequestResponder( - request_id=message.message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, - request=validated_request, - session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), - message_metadata=message.metadata, - ) - self._in_flight[responder.request_id] = responder - await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) - except Exception as e: - # For request validation errors, send a proper JSON-RPC error - # response instead of crashing the server - logging.warning(f"Failed to validate request: {e}") - logging.debug(f"Message that failed validation: {message.message.root}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.root.id, - error=ErrorData( - code=INVALID_PARAMS, - message="Invalid request parameters", - data="", - ), - ) - session_message = SessionMessage(message=JSONRPCMessage(error_response)) - await self._write_stream.send(session_message) - - elif isinstance(message.message.root, JSONRPCNotification): - try: - notification = self._receive_notification_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True) - ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() - else: - # Handle progress notifications callback - if isinstance(notification.root, ProgressNotification): - progress_token = notification.root.params.progressToken - # If there is a progress callback for this token, - # call it with the 
progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - await callback( - notification.root.params.progress, - notification.root.params.total, - notification.root.params.message, - ) - await self._received_notification(notification) - await self._handle_incoming(notification) - except Exception as e: - # For other validation errors, log and continue - logging.warning( - f"Failed to validate notification: {e}. Message was: {message.message.root}" - ) - else: # Response or error - stream = self._response_streams.pop(message.message.root.id, None) - if stream: - await stream.send(message.message.root) - else: - await self._handle_incoming( - RuntimeError(f"Received response with an unknown request ID: {message}") - ) - - except anyio.ClosedResourceError: - # This is expected when the client disconnects abruptly. - # Without this handler, the exception would propagate up and - # crash the server's task group. - logging.debug("Read stream closed by client") - except Exception as e: - # Other exceptions are not expected and should be logged. We purposefully - # catch all exceptions here to avoid crashing the server. - logging.exception(f"Unhandled exception in receive loop: {e}") - finally: - # after the read stream is closed, we need to send errors - # to any pending requests - for id, stream in self._response_streams.items(): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: - # Stream might already be closed - pass - self._response_streams.clear() - - async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: - """ - Can be overridden by subclasses to handle a request without needing to - listen on the message stream. 
- - If the request is responded to within this method, it will not be - forwarded on to the message stream. - """ - - async def _received_notification(self, notification: ReceiveNotificationT) -> None: - """ - Can be overridden by subclasses to handle a notification without needing - to listen on the message stream. - """ - - async def send_progress_notification( - self, - progress_token: str | int, - progress: float, - total: float | None = None, - message: str | None = None, - ) -> None: - """ - Sends a progress notification for a request that is currently being - processed. - """ - - async def _handle_incoming( - self, - req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception, - ) -> None: - """A generic handler for incoming messages. Overwritten by subclasses.""" - pass diff --git a/src/mcp/mcp/shared/version.py b/src/mcp/mcp/shared/version.py deleted file mode 100644 index 23c46d04b..000000000 --- a/src/mcp/mcp/shared/version.py +++ /dev/null @@ -1,3 +0,0 @@ -from mcp.types import LATEST_PROTOCOL_VERSION - -SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", "2025-03-26", LATEST_PROTOCOL_VERSION] diff --git a/src/mcp/mcp/types.py b/src/mcp/mcp/types.py deleted file mode 100644 index 98fefa080..000000000 --- a/src/mcp/mcp/types.py +++ /dev/null @@ -1,1320 +0,0 @@ -from collections.abc import Callable -from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar - -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel -from pydantic.networks import AnyUrl, UrlConstraints -from typing_extensions import deprecated - -""" -Model Context Protocol bindings for Python - -These bindings were generated from https://github.com/modelcontextprotocol/specification, -using Claude, with a prompt something like the following: - -Generate idiomatic Python bindings for this schema for MCP, or the "Model Context -Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version -for reference. 
- -* For the bindings, let's use Pydantic V2 models. -* Each model should allow extra fields everywhere, by specifying `model_config = - ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class. -* Union types should be represented with a Pydantic `RootModel`. -* Define additional model classes instead of using dictionaries. Do this even if they're - not separate types in the schema. -""" - -LATEST_PROTOCOL_VERSION = "2025-06-18" - -""" -The default negotiated version of the Model Context Protocol when no version is specified. -We need this to satisfy the MCP specification, which requires the server to assume a -specific version if none is provided by the client. See section "Protocol Version Header" at -https://modelcontextprotocol.io/specification -""" -DEFAULT_NEGOTIATED_VERSION = "2025-03-26" - -ProgressToken = str | int -Cursor = str -Role = Literal["user", "assistant"] -RequestId = Annotated[int, Field(strict=True)] | str -AnyFunction: TypeAlias = Callable[..., Any] - - -class RequestParams(BaseModel): - class Meta(BaseModel): - progressToken: ProgressToken | None = None - """ - If specified, the caller requests out-of-band progress notifications for - this request (as represented by notifications/progress). The value of this - parameter is an opaque token that will be attached to any subsequent - notifications. The receiver is not obligated to provide these notifications. - """ - - model_config = ConfigDict(extra="allow") - - meta: Meta | None = Field(alias="_meta", default=None) - - -class PaginatedRequestParams(RequestParams): - cursor: Cursor | None = None - """ - An opaque token representing the current pagination position. - If provided, the server should return results starting after this cursor. 
- """ - - -class NotificationParams(BaseModel): - class Meta(BaseModel): - model_config = ConfigDict(extra="allow") - - meta: Meta | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - - -RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None) -NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None) -MethodT = TypeVar("MethodT", bound=str) - - -class Request(BaseModel, Generic[RequestParamsT, MethodT]): - """Base class for JSON-RPC requests.""" - - method: MethodT - params: RequestParamsT - model_config = ConfigDict(extra="allow") - - -class PaginatedRequest(Request[PaginatedRequestParams | None, MethodT], Generic[MethodT]): - """Base class for paginated requests, - matching the schema's PaginatedRequest interface.""" - - params: PaginatedRequestParams | None = None - - -class Notification(BaseModel, Generic[NotificationParamsT, MethodT]): - """Base class for JSON-RPC notifications.""" - - method: MethodT - params: NotificationParamsT - model_config = ConfigDict(extra="allow") - - -class Result(BaseModel): - """Base class for JSON-RPC results.""" - - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class PaginatedResult(Result): - nextCursor: Cursor | None = None - """ - An opaque token representing the pagination position after the last returned result. - If present, there may be more results available. 
- """ - - -class JSONRPCRequest(Request[dict[str, Any] | None, str]): - """A request that expects a response.""" - - jsonrpc: Literal["2.0"] - id: RequestId - method: str - params: dict[str, Any] | None = None - - -class JSONRPCNotification(Notification[dict[str, Any] | None, str]): - """A notification which does not expect a response.""" - - jsonrpc: Literal["2.0"] - params: dict[str, Any] | None = None - - -class JSONRPCResponse(BaseModel): - """A successful (non-error) response to a request.""" - - jsonrpc: Literal["2.0"] - id: RequestId - result: dict[str, Any] - model_config = ConfigDict(extra="allow") - - -# SDK error codes -CONNECTION_CLOSED = -32000 -# REQUEST_TIMEOUT = -32001 # the typescript sdk uses this - -# Standard JSON-RPC error codes -PARSE_ERROR = -32700 -INVALID_REQUEST = -32600 -METHOD_NOT_FOUND = -32601 -INVALID_PARAMS = -32602 -INTERNAL_ERROR = -32603 - - -class ErrorData(BaseModel): - """Error information for JSON-RPC error responses.""" - - code: int - """The error type that occurred.""" - - message: str - """ - A short description of the error. The message SHOULD be limited to a concise single - sentence. - """ - - data: Any | None = None - """ - Additional information about the error. The value of this member is defined by the - sender (e.g. detailed error information, nested errors etc.). 
- """ - - model_config = ConfigDict(extra="allow") - - -class JSONRPCError(BaseModel): - """A response to a request that indicates an error occurred.""" - - jsonrpc: Literal["2.0"] - id: str | int - error: ErrorData - model_config = ConfigDict(extra="allow") - - -class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): - pass - - -class EmptyResult(Result): - """A response that indicates success but carries no data.""" - - -class BaseMetadata(BaseModel): - """Base class for entities with name and optional title fields.""" - - name: str - """The programmatic name of the entity.""" - - title: str | None = None - """ - Intended for UI and end-user contexts — optimized to be human-readable and easily understood, - even by those unfamiliar with domain-specific terminology. - - If not provided, the name should be used for display (except for Tool, - where `annotations.title` should be given precedence over using `name`, - if present). - """ - - -class Implementation(BaseMetadata): - """Describes the name and version of an MCP implementation.""" - - version: str - model_config = ConfigDict(extra="allow") - - -class RootsCapability(BaseModel): - """Capability for root operations.""" - - listChanged: bool | None = None - """Whether the client supports notifications for changes to the roots list.""" - model_config = ConfigDict(extra="allow") - - -class SamplingCapability(BaseModel): - """Capability for sampling operations.""" - - model_config = ConfigDict(extra="allow") - - -class ElicitationCapability(BaseModel): - """Capability for elicitation operations.""" - - model_config = ConfigDict(extra="allow") - - -class ClientCapabilities(BaseModel): - """Capabilities a client may support.""" - - experimental: dict[str, dict[str, Any]] | None = None - """Experimental, non-standard capabilities that the client supports.""" - sampling: SamplingCapability | None = None - """Present if the client supports sampling from an LLM.""" - 
elicitation: ElicitationCapability | None = None - """Present if the client supports elicitation from the user.""" - roots: RootsCapability | None = None - """Present if the client supports listing roots.""" - model_config = ConfigDict(extra="allow") - - -class PromptsCapability(BaseModel): - """Capability for prompts operations.""" - - listChanged: bool | None = None - """Whether this server supports notifications for changes to the prompt list.""" - model_config = ConfigDict(extra="allow") - - -class ResourcesCapability(BaseModel): - """Capability for resources operations.""" - - subscribe: bool | None = None - """Whether this server supports subscribing to resource updates.""" - listChanged: bool | None = None - """Whether this server supports notifications for changes to the resource list.""" - model_config = ConfigDict(extra="allow") - - -class ToolsCapability(BaseModel): - """Capability for tools operations.""" - - listChanged: bool | None = None - """Whether this server supports notifications for changes to the tool list.""" - model_config = ConfigDict(extra="allow") - - -class LoggingCapability(BaseModel): - """Capability for logging operations.""" - - model_config = ConfigDict(extra="allow") - - -class CompletionsCapability(BaseModel): - """Capability for completions operations.""" - - model_config = ConfigDict(extra="allow") - - -class ServerCapabilities(BaseModel): - """Capabilities that a server may support.""" - - experimental: dict[str, dict[str, Any]] | None = None - """Experimental, non-standard capabilities that the server supports.""" - logging: LoggingCapability | None = None - """Present if the server supports sending log messages to the client.""" - prompts: PromptsCapability | None = None - """Present if the server offers any prompt templates.""" - resources: ResourcesCapability | None = None - """Present if the server offers any resources to read.""" - tools: ToolsCapability | None = None - """Present if the server offers any tools to 
call.""" - completions: CompletionsCapability | None = None - """Present if the server offers autocompletion suggestions for prompts and resources.""" - model_config = ConfigDict(extra="allow") - - -class InitializeRequestParams(RequestParams): - """Parameters for the initialize request.""" - - protocolVersion: str | int - """The latest version of the Model Context Protocol that the client supports.""" - capabilities: ClientCapabilities - clientInfo: Implementation - model_config = ConfigDict(extra="allow") - - -class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]): - """ - This request is sent from the client to the server when it first connects, asking it - to begin initialization. - """ - - method: Literal["initialize"] - params: InitializeRequestParams - - -class InitializeResult(Result): - """After receiving an initialize request from the client, the server sends this.""" - - protocolVersion: str | int - """The version of the Model Context Protocol that the server wants to use.""" - capabilities: ServerCapabilities - serverInfo: Implementation - instructions: str | None = None - """Instructions describing how to use the server and its features.""" - - -class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]): - """ - This notification is sent from the client to the server after initialization has - finished. - """ - - method: Literal["notifications/initialized"] - params: NotificationParams | None = None - - -class PingRequest(Request[RequestParams | None, Literal["ping"]]): - """ - A ping, issued by either the server or the client, to check that the other party is - still alive. 
- """ - - method: Literal["ping"] - params: RequestParams | None = None - - -class ProgressNotificationParams(NotificationParams): - """Parameters for progress notifications.""" - - progressToken: ProgressToken - """ - The progress token which was given in the initial request, used to associate this - notification with the request that is proceeding. - """ - progress: float - """ - The progress thus far. This should increase every time progress is made, even if the - total is unknown. - """ - total: float | None = None - """Total number of items to process (or total progress required), if known.""" - message: str | None = None - """ - Message related to progress. This should provide relevant human readable - progress information. - """ - model_config = ConfigDict(extra="allow") - - -class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]): - """ - An out-of-band notification used to inform the receiver of a progress update for a - long-running request. - """ - - method: Literal["notifications/progress"] - params: ProgressNotificationParams - - -class ListResourcesRequest(PaginatedRequest[Literal["resources/list"]]): - """Sent from the client to request a list of resources the server has.""" - - method: Literal["resources/list"] - - -class Annotations(BaseModel): - audience: list[Role] | None = None - priority: Annotated[float, Field(ge=0.0, le=1.0)] | None = None - model_config = ConfigDict(extra="allow") - - -class Resource(BaseMetadata): - """A known resource that the server is capable of reading.""" - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] - """The URI of this resource.""" - description: str | None = None - """A description of what this resource represents.""" - mimeType: str | None = None - """The MIME type of this resource, if known.""" - size: int | None = None - """ - The size of the raw resource content, in bytes (i.e., before base64 encoding - or any tokenization), if known. 
- - This can be used by Hosts to display file sizes and estimate context window usage. - """ - annotations: Annotations | None = None - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class ResourceTemplate(BaseMetadata): - """A template description for resources available on the server.""" - - uriTemplate: str - """ - A URI template (according to RFC 6570) that can be used to construct resource - URIs. - """ - description: str | None = None - """A human-readable description of what this template is for.""" - mimeType: str | None = None - """ - The MIME type for all resources that match this template. This should only be - included if all resources matching this template have the same type. - """ - annotations: Annotations | None = None - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. 
- """ - model_config = ConfigDict(extra="allow") - - -class ListResourcesResult(PaginatedResult): - """The server's response to a resources/list request from the client.""" - - resources: list[Resource] - - -class ListResourceTemplatesRequest(PaginatedRequest[Literal["resources/templates/list"]]): - """Sent from the client to request a list of resource templates the server has.""" - - method: Literal["resources/templates/list"] - - -class ListResourceTemplatesResult(PaginatedResult): - """The server's response to a resources/templates/list request from the client.""" - - resourceTemplates: list[ResourceTemplate] - - -class ReadResourceRequestParams(RequestParams): - """Parameters for reading a resource.""" - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] - """ - The URI of the resource to read. The URI can use any protocol; it is up to the - server how to interpret it. - """ - model_config = ConfigDict(extra="allow") - - -class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]): - """Sent from the client to the server, to read a specific resource URI.""" - - method: Literal["resources/read"] - params: ReadResourceRequestParams - - -class ResourceContents(BaseModel): - """The contents of a specific resource or sub-resource.""" - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] - """The URI of this resource.""" - mimeType: str | None = None - """The MIME type of this resource, if known.""" - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class TextResourceContents(ResourceContents): - """Text contents of a resource.""" - - text: str - """ - The text of the item. 
This must only be set if the item can actually be represented - as text (not binary data). - """ - - -class BlobResourceContents(ResourceContents): - """Binary contents of a resource.""" - - blob: str - """A base64-encoded string representing the binary data of the item.""" - - -class ReadResourceResult(Result): - """The server's response to a resources/read request from the client.""" - - contents: list[TextResourceContents | BlobResourceContents] - - -class ResourceListChangedNotification( - Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]] -): - """ - An optional notification from the server to the client, informing it that the list - of resources it can read from has changed. - """ - - method: Literal["notifications/resources/list_changed"] - params: NotificationParams | None = None - - -class SubscribeRequestParams(RequestParams): - """Parameters for subscribing to a resource.""" - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] - """ - The URI of the resource to subscribe to. The URI can use any protocol; it is up to - the server how to interpret it. - """ - model_config = ConfigDict(extra="allow") - - -class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]): - """ - Sent from the client to request resources/updated notifications from the server - whenever a particular resource changes. - """ - - method: Literal["resources/subscribe"] - params: SubscribeRequestParams - - -class UnsubscribeRequestParams(RequestParams): - """Parameters for unsubscribing from a resource.""" - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] - """The URI of the resource to unsubscribe from.""" - model_config = ConfigDict(extra="allow") - - -class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]): - """ - Sent from the client to request cancellation of resources/updated notifications from - the server. 
- """ - - method: Literal["resources/unsubscribe"] - params: UnsubscribeRequestParams - - -class ResourceUpdatedNotificationParams(NotificationParams): - """Parameters for resource update notifications.""" - - uri: Annotated[AnyUrl, UrlConstraints(host_required=False)] - """ - The URI of the resource that has been updated. This might be a sub-resource of the - one that the client actually subscribed to. - """ - model_config = ConfigDict(extra="allow") - - -class ResourceUpdatedNotification( - Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]] -): - """ - A notification from the server to the client, informing it that a resource has - changed and may need to be read again. - """ - - method: Literal["notifications/resources/updated"] - params: ResourceUpdatedNotificationParams - - -class ListPromptsRequest(PaginatedRequest[Literal["prompts/list"]]): - """Sent from the client to request a list of prompts and prompt templates.""" - - method: Literal["prompts/list"] - - -class PromptArgument(BaseModel): - """An argument for a prompt template.""" - - name: str - """The name of the argument.""" - description: str | None = None - """A human-readable description of the argument.""" - required: bool | None = None - """Whether this argument must be provided.""" - model_config = ConfigDict(extra="allow") - - -class Prompt(BaseMetadata): - """A prompt or prompt template that the server offers.""" - - description: str | None = None - """An optional description of what this prompt provides.""" - arguments: list[PromptArgument] | None = None - """A list of arguments to use for templating the prompt.""" - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. 
- """ - model_config = ConfigDict(extra="allow") - - -class ListPromptsResult(PaginatedResult): - """The server's response to a prompts/list request from the client.""" - - prompts: list[Prompt] - - -class GetPromptRequestParams(RequestParams): - """Parameters for getting a prompt.""" - - name: str - """The name of the prompt or prompt template.""" - arguments: dict[str, str] | None = None - """Arguments to use for templating the prompt.""" - model_config = ConfigDict(extra="allow") - - -class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]): - """Used by the client to get a prompt provided by the server.""" - - method: Literal["prompts/get"] - params: GetPromptRequestParams - - -class TextContent(BaseModel): - """Text content for a message.""" - - type: Literal["text"] - text: str - """The text content of the message.""" - annotations: Annotations | None = None - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class ImageContent(BaseModel): - """Image content for a message.""" - - type: Literal["image"] - data: str - """The base64-encoded image data.""" - mimeType: str - """ - The MIME type of the image. Different providers may support different - image types. - """ - annotations: Annotations | None = None - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. 
- """ - model_config = ConfigDict(extra="allow") - - -class AudioContent(BaseModel): - """Audio content for a message.""" - - type: Literal["audio"] - data: str - """The base64-encoded audio data.""" - mimeType: str - """ - The MIME type of the audio. Different providers may support different - audio types. - """ - annotations: Annotations | None = None - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class SamplingMessage(BaseModel): - """Describes a message issued to or received from an LLM API.""" - - role: Role - content: TextContent | ImageContent | AudioContent - model_config = ConfigDict(extra="allow") - - -class EmbeddedResource(BaseModel): - """ - The contents of a resource, embedded into a prompt or tool call result. - - It is up to the client how best to render embedded resources for the benefit - of the LLM and/or the user. - """ - - type: Literal["resource"] - resource: TextResourceContents | BlobResourceContents - annotations: Annotations | None = None - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class ResourceLink(Resource): - """ - A resource that the server is capable of reading, included in a prompt or tool call result. - - Note: resource links returned by tools are not guaranteed to appear in the results of `resources/list` requests. 
- """ - - type: Literal["resource_link"] - - -ContentBlock = TextContent | ImageContent | AudioContent | ResourceLink | EmbeddedResource -"""A content block that can be used in prompts and tool results.""" - -Content: TypeAlias = ContentBlock -# """DEPRECATED: Content is deprecated, you should use ContentBlock directly.""" - - -class PromptMessage(BaseModel): - """Describes a message returned as part of a prompt.""" - - role: Role - content: ContentBlock - model_config = ConfigDict(extra="allow") - - -class GetPromptResult(Result): - """The server's response to a prompts/get request from the client.""" - - description: str | None = None - """An optional description for the prompt.""" - messages: list[PromptMessage] - - -class PromptListChangedNotification( - Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]] -): - """ - An optional notification from the server to the client, informing it that the list - of prompts it offers has changed. - """ - - method: Literal["notifications/prompts/list_changed"] - params: NotificationParams | None = None - - -class ListToolsRequest(PaginatedRequest[Literal["tools/list"]]): - """Sent from the client to request a list of tools the server has.""" - - method: Literal["tools/list"] - - -class ToolAnnotations(BaseModel): - """ - Additional properties describing a Tool to clients. - - NOTE: all properties in ToolAnnotations are **hints**. - They are not guaranteed to provide a faithful description of - tool behavior (including descriptive properties like `title`). - - Clients should never make tool use decisions based on ToolAnnotations - received from untrusted servers. - """ - - title: str | None = None - """A human-readable title for the tool.""" - - readOnlyHint: bool | None = None - """ - If true, the tool does not modify its environment. - Default: false - """ - - destructiveHint: bool | None = None - """ - If true, the tool may perform destructive updates to its environment. 
- If false, the tool performs only additive updates. - (This property is meaningful only when `readOnlyHint == false`) - Default: true - """ - - idempotentHint: bool | None = None - """ - If true, calling the tool repeatedly with the same arguments - will have no additional effect on the its environment. - (This property is meaningful only when `readOnlyHint == false`) - Default: false - """ - - openWorldHint: bool | None = None - """ - If true, this tool may interact with an "open world" of external - entities. If false, the tool's domain of interaction is closed. - For example, the world of a web search tool is open, whereas that - of a memory tool is not. - Default: true - """ - model_config = ConfigDict(extra="allow") - - -class Tool(BaseMetadata): - """Definition for a tool the client can call.""" - - description: str | None = None - """A human-readable description of the tool.""" - inputSchema: dict[str, Any] - """A JSON Schema object defining the expected parameters for the tool.""" - outputSchema: dict[str, Any] | None = None - """ - An optional JSON Schema object defining the structure of the tool's output - returned in the structuredContent field of a CallToolResult. - """ - annotations: ToolAnnotations | None = None - """Optional additional tool information.""" - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. 
- """ - model_config = ConfigDict(extra="allow") - - -class ListToolsResult(PaginatedResult): - """The server's response to a tools/list request from the client.""" - - tools: list[Tool] - - -class CallToolRequestParams(RequestParams): - """Parameters for calling a tool.""" - - name: str - arguments: dict[str, Any] | None = None - model_config = ConfigDict(extra="allow") - - -class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]): - """Used by the client to invoke a tool provided by the server.""" - - method: Literal["tools/call"] - params: CallToolRequestParams - - -class CallToolResult(Result): - """The server's response to a tool call.""" - - content: list[ContentBlock] - structuredContent: dict[str, Any] | None = None - """An optional JSON object that represents the structured result of the tool call.""" - isError: bool = False - - -class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]): - """ - An optional notification from the server to the client, informing it that the list - of tools it offers has changed. 
- """ - - method: Literal["notifications/tools/list_changed"] - params: NotificationParams | None = None - - -LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"] - - -class SetLevelRequestParams(RequestParams): - """Parameters for setting the logging level.""" - - level: LoggingLevel - """The level of logging that the client wants to receive from the server.""" - model_config = ConfigDict(extra="allow") - - -class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]): - """A request from the client to the server, to enable or adjust logging.""" - - method: Literal["logging/setLevel"] - params: SetLevelRequestParams - - -class LoggingMessageNotificationParams(NotificationParams): - """Parameters for logging message notifications.""" - - level: LoggingLevel - """The severity of this log message.""" - logger: str | None = None - """An optional name of the logger issuing this message.""" - data: Any - """ - The data to be logged, such as a string message or an object. Any JSON serializable - type is allowed here. - """ - model_config = ConfigDict(extra="allow") - - -class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]): - """Notification of a log message passed from server to client.""" - - method: Literal["notifications/message"] - params: LoggingMessageNotificationParams - - -IncludeContext = Literal["none", "thisServer", "allServers"] - - -class ModelHint(BaseModel): - """Hints to use for model selection.""" - - name: str | None = None - """A hint for a model name.""" - - model_config = ConfigDict(extra="allow") - - -class ModelPreferences(BaseModel): - """ - The server's preferences for model selection, requested by the client during - sampling. - - Because LLMs can vary along multiple dimensions, choosing the "best" model is - rarely straightforward. 
Different models excel in different areas—some are - faster but less capable, others are more capable but more expensive, and so - on. This interface allows servers to express their priorities across multiple - dimensions to help clients make an appropriate selection for their use case. - - These preferences are always advisory. The client MAY ignore them. It is also - up to the client to decide how to interpret these preferences and how to - balance them against other considerations. - """ - - hints: list[ModelHint] | None = None - """ - Optional hints to use for model selection. - - If multiple hints are specified, the client MUST evaluate them in order - (such that the first match is taken). - - The client SHOULD prioritize these hints over the numeric priorities, but - MAY still use the priorities to select from ambiguous matches. - """ - - costPriority: float | None = None - """ - How much to prioritize cost when selecting a model. A value of 0 means cost - is not important, while a value of 1 means cost is the most important - factor. - """ - - speedPriority: float | None = None - """ - How much to prioritize sampling speed (latency) when selecting a model. A - value of 0 means speed is not important, while a value of 1 means speed is - the most important factor. - """ - - intelligencePriority: float | None = None - """ - How much to prioritize intelligence and capabilities when selecting a - model. A value of 0 means intelligence is not important, while a value of 1 - means intelligence is the most important factor. - """ - - model_config = ConfigDict(extra="allow") - - -class CreateMessageRequestParams(RequestParams): - """Parameters for creating a message.""" - - messages: list[SamplingMessage] - modelPreferences: ModelPreferences | None = None - """ - The server's preferences for which model to select. The client MAY ignore - these preferences. 
- """ - systemPrompt: str | None = None - """An optional system prompt the server wants to use for sampling.""" - includeContext: IncludeContext | None = None - """ - A request to include context from one or more MCP servers (including the caller), to - be attached to the prompt. - """ - temperature: float | None = None - maxTokens: int - """The maximum number of tokens to sample, as requested by the server.""" - stopSequences: list[str] | None = None - metadata: dict[str, Any] | None = None - """Optional metadata to pass through to the LLM provider.""" - model_config = ConfigDict(extra="allow") - - -class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]): - """A request from the server to sample an LLM via the client.""" - - method: Literal["sampling/createMessage"] - params: CreateMessageRequestParams - - -StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str - - -class CreateMessageResult(Result): - """The client's response to a sampling/create_message request from the server.""" - - role: Role - content: TextContent | ImageContent | AudioContent - model: str - """The name of the model that generated the message.""" - stopReason: StopReason | None = None - """The reason why sampling stopped, if known.""" - - -class ResourceTemplateReference(BaseModel): - """A reference to a resource or resource template definition.""" - - type: Literal["ref/resource"] - uri: str - """The URI or URI template of the resource.""" - model_config = ConfigDict(extra="allow") - - -@deprecated("`ResourceReference` is deprecated, you should use `ResourceTemplateReference`.") -class ResourceReference(ResourceTemplateReference): - pass - - -class PromptReference(BaseModel): - """Identifies a prompt.""" - - type: Literal["ref/prompt"] - name: str - """The name of the prompt or prompt template""" - model_config = ConfigDict(extra="allow") - - -class CompletionArgument(BaseModel): - """The argument's information for completion 
requests.""" - - name: str - """The name of the argument""" - value: str - """The value of the argument to use for completion matching.""" - model_config = ConfigDict(extra="allow") - - -class CompletionContext(BaseModel): - """Additional, optional context for completions.""" - - arguments: dict[str, str] | None = None - """Previously-resolved variables in a URI template or prompt.""" - model_config = ConfigDict(extra="allow") - - -class CompleteRequestParams(RequestParams): - """Parameters for completion requests.""" - - ref: ResourceTemplateReference | PromptReference - argument: CompletionArgument - context: CompletionContext | None = None - """Additional, optional context for completions""" - model_config = ConfigDict(extra="allow") - - -class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]): - """A request from the client to the server, to ask for completion options.""" - - method: Literal["completion/complete"] - params: CompleteRequestParams - - -class Completion(BaseModel): - """Completion information.""" - - values: list[str] - """An array of completion values. Must not exceed 100 items.""" - total: int | None = None - """ - The total number of completion options available. This can exceed the number of - values actually sent in the response. - """ - hasMore: bool | None = None - """ - Indicates whether there are additional completion options beyond those provided in - the current response, even if the exact total is unknown. - """ - model_config = ConfigDict(extra="allow") - - -class CompleteResult(Result): - """The server's response to a completion/complete request""" - - completion: Completion - - -class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]): - """ - Sent from the server to request a list of root URIs from the client. Roots allow - servers to ask for specific directories or files to operate on. 
A common example - for roots is providing a set of repositories or directories a server should operate - on. - - This request is typically used when the server needs to understand the file system - structure or access specific locations that the client has permission to read from. - """ - - method: Literal["roots/list"] - params: RequestParams | None = None - - -class Root(BaseModel): - """Represents a root directory or file that the server can operate on.""" - - uri: FileUrl - """ - The URI identifying the root. This *must* start with file:// for now. - This restriction may be relaxed in future versions of the protocol to allow - other URI schemes. - """ - name: str | None = None - """ - An optional name for the root. This can be used to provide a human-readable - identifier for the root, which may be useful for display purposes or for - referencing the root in other parts of the application. - """ - meta: dict[str, Any] | None = Field(alias="_meta", default=None) - """ - See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - for notes on _meta usage. - """ - model_config = ConfigDict(extra="allow") - - -class ListRootsResult(Result): - """ - The client's response to a roots/list request from the server. - This result contains an array of Root objects, each representing a root directory - or file that the server can operate on. - """ - - roots: list[Root] - - -class RootsListChangedNotification( - Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]] -): - """ - A notification from the client to the server, informing it that the list of - roots has changed. - - This notification should be sent whenever the client adds, removes, or - modifies any root. The server should then request an updated list of roots - using the ListRootsRequest. 
- """ - - method: Literal["notifications/roots/list_changed"] - params: NotificationParams | None = None - - -class CancelledNotificationParams(NotificationParams): - """Parameters for cancellation notifications.""" - - requestId: RequestId - """The ID of the request to cancel.""" - reason: str | None = None - """An optional string describing the reason for the cancellation.""" - model_config = ConfigDict(extra="allow") - - -class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]): - """ - This notification can be sent by either side to indicate that it is canceling a - previously-issued request. - """ - - method: Literal["notifications/cancelled"] - params: CancelledNotificationParams - - -class ClientRequest( - RootModel[ - PingRequest - | InitializeRequest - | CompleteRequest - | SetLevelRequest - | GetPromptRequest - | ListPromptsRequest - | ListResourcesRequest - | ListResourceTemplatesRequest - | ReadResourceRequest - | SubscribeRequest - | UnsubscribeRequest - | CallToolRequest - | ListToolsRequest - ] -): - pass - - -class ClientNotification( - RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification] -): - pass - - -# Type for elicitation schema - a JSON Schema dict -ElicitRequestedSchema: TypeAlias = dict[str, Any] -"""Schema for elicitation requests.""" - - -class ElicitRequestParams(RequestParams): - """Parameters for elicitation requests.""" - - message: str - requestedSchema: ElicitRequestedSchema - model_config = ConfigDict(extra="allow") - - -class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): - """A request from the server to elicit information from the client.""" - - method: Literal["elicitation/create"] - params: ElicitRequestParams - - -class ElicitResult(Result): - """The client's response to an elicitation request.""" - - action: Literal["accept", "decline", "cancel"] - """ - The user action in response to the 
elicitation. - - "accept": User submitted the form/confirmed the action - - "decline": User explicitly declined the action - - "cancel": User dismissed without making an explicit choice - """ - - content: dict[str, str | int | float | bool | None] | None = None - """ - The submitted form data, only present when action is "accept". - Contains values matching the requested schema. - """ - - -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]): - pass - - -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]): - pass - - -class ServerNotification( - RootModel[ - CancelledNotification - | ProgressNotification - | LoggingMessageNotification - | ResourceUpdatedNotification - | ResourceListChangedNotification - | ToolListChangedNotification - | PromptListChangedNotification - ] -): - pass - - -class ServerResult( - RootModel[ - EmptyResult - | InitializeResult - | CompleteResult - | GetPromptResult - | ListPromptsResult - | ListResourcesResult - | ListResourceTemplatesResult - | ReadResourceResult - | CallToolResult - | ListToolsResult - ] -): - pass diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index b84db89a2..a7b108602 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -281,7 +281,7 @@ async def revoke_token( def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: parsed_uri = urlparse(redirect_uri_base) - query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] + query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs] for k, v in params.items(): if v is not None: query_params.append((k, v)) diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index e4db806e7..bce32df52 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -190,6 +190,8 @@ def create_protected_resource_routes( resource_url: 
AnyHttpUrl, authorization_servers: list[AnyHttpUrl], scopes_supported: list[str] | None = None, + resource_name: str | None = None, + resource_documentation: AnyHttpUrl | None = None, ) -> list[Route]: """ Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728). @@ -209,6 +211,8 @@ def create_protected_resource_routes( resource=resource_url, authorization_servers=authorization_servers, scopes_supported=scopes_supported, + resource_name=resource_name, + resource_documentation=resource_documentation, # bearer_methods_supported defaults to ["header"] in the model ) diff --git a/src/mcp/server/fastmcp/__init__.py b/src/mcp/server/fastmcp/__init__.py index 84b052078..f8f9c1c4c 100644 --- a/src/mcp/server/fastmcp/__init__.py +++ b/src/mcp/server/fastmcp/__init__.py @@ -3,7 +3,7 @@ from importlib.metadata import version from .server import Context, FastMCP -from .utilities.types import Image +from .utilities.types import Audio, Image __version__ = version("mcp") -__all__ = ["FastMCP", "Context", "Image"] +__all__ = ["FastMCP", "Context", "Image", "Audio"] diff --git a/src/mcp/server/fastmcp/prompts/prompt_manager.py b/src/mcp/server/fastmcp/prompts/prompt_manager.py deleted file mode 100644 index 389e89624..000000000 --- a/src/mcp/server/fastmcp/prompts/prompt_manager.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Prompt management functionality.""" - -from mcp.server.fastmcp.prompts.base import Prompt -from mcp.server.fastmcp.utilities.logging import get_logger - -logger = get_logger(__name__) - - -class PromptManager: - """Manages FastMCP prompts.""" - - def __init__(self, warn_on_duplicate_prompts: bool = True): - self._prompts: dict[str, Prompt] = {} - self.warn_on_duplicate_prompts = warn_on_duplicate_prompts - - def add_prompt(self, prompt: Prompt) -> Prompt: - """Add a prompt to the manager.""" - logger.debug(f"Adding prompt: {prompt.name}") - existing = self._prompts.get(prompt.name) - if existing: - if self.warn_on_duplicate_prompts: - logger.warning(f"Prompt 
already exists: {prompt.name}") - return existing - self._prompts[prompt.name] = prompt - return prompt - - def get_prompt(self, name: str) -> Prompt | None: - """Get prompt by name.""" - return self._prompts.get(name) - - def list_prompts(self) -> list[Prompt]: - """List all registered prompts.""" - return list(self._prompts.values()) diff --git a/src/mcp/mcp/server/fastmcp/secure/__init__.py b/src/mcp/server/fastmcp/secure/__init__.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/__init__.py rename to src/mcp/server/fastmcp/secure/__init__.py diff --git a/src/mcp/mcp/server/fastmcp/secure/annotations.py b/src/mcp/server/fastmcp/secure/annotations.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/annotations.py rename to src/mcp/server/fastmcp/secure/annotations.py diff --git a/src/mcp/mcp/server/fastmcp/secure/identity.py b/src/mcp/server/fastmcp/secure/identity.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/identity.py rename to src/mcp/server/fastmcp/secure/identity.py diff --git a/src/mcp/mcp/server/fastmcp/secure/prompt.py b/src/mcp/server/fastmcp/secure/prompt.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/prompt.py rename to src/mcp/server/fastmcp/secure/prompt.py diff --git a/src/mcp/mcp/server/fastmcp/secure/resource.py b/src/mcp/server/fastmcp/secure/resource.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/resource.py rename to src/mcp/server/fastmcp/secure/resource.py diff --git a/src/mcp/mcp/server/fastmcp/secure/session.py b/src/mcp/server/fastmcp/secure/session.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/session.py rename to src/mcp/server/fastmcp/secure/session.py diff --git a/src/mcp/mcp/server/fastmcp/secure/tool.py b/src/mcp/server/fastmcp/secure/tool.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/tool.py rename to src/mcp/server/fastmcp/secure/tool.py diff --git 
a/src/mcp/mcp/server/fastmcp/secure/utils.py b/src/mcp/server/fastmcp/secure/utils.py similarity index 100% rename from src/mcp/mcp/server/fastmcp/secure/utils.py rename to src/mcp/server/fastmcp/secure/utils.py diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 2fe7c1224..924baaa9b 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -4,16 +4,13 @@ import inspect import re -from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence -from contextlib import ( - AbstractAsyncContextManager, - asynccontextmanager, -) +from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence +from contextlib import AbstractAsyncContextManager, asynccontextmanager from typing import Any, Generic, Literal import anyio import pydantic_core -from pydantic import BaseModel, Field +from pydantic import BaseModel from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.applications import Starlette @@ -25,10 +22,7 @@ from starlette.types import Receive, Scope, Send from mcp.server.auth.middleware.auth_context import AuthContextMiddleware -from mcp.server.auth.middleware.bearer_auth import ( - BearerAuthBackend, - RequireAuthMiddleware, -) +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier from mcp.server.auth.settings import AuthSettings from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation @@ -48,12 +42,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.context import LifespanContextT, RequestContext, RequestT -from mcp.types import ( - AnyFunction, - ContentBlock, - GetPromptResult, - ToolAnnotations, 
-) +from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument from mcp.types import Resource as MCPResource @@ -79,58 +68,57 @@ class Settings(BaseSettings, Generic[LifespanResultT]): ) # Server settings - debug: bool = False - log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO" + debug: bool + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] # HTTP settings - host: str = "127.0.0.1" - port: int = 8000 - mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path) - sse_path: str = "/sse" - message_path: str = "/messages/" - streamable_http_path: str = "/mcp" + host: str + port: int + mount_path: str + sse_path: str + message_path: str + streamable_http_path: str # StreamableHTTP settings - json_response: bool = False - stateless_http: bool = False # If True, uses true stateless mode (new transport per request) + json_response: bool + stateless_http: bool + """Define if the server should create a new transport per request.""" # resource settings - warn_on_duplicate_resources: bool = True + warn_on_duplicate_resources: bool # tool settings - warn_on_duplicate_tools: bool = True + warn_on_duplicate_tools: bool # prompt settings - warn_on_duplicate_prompts: bool = True + warn_on_duplicate_prompts: bool - dependencies: list[str] = Field( - default_factory=list, - description="List of dependencies to install in the server environment", - ) + # TODO(Marcelo): Investigate if this is used. If it is, it's probably a good idea to remove it. 
+ dependencies: list[str] + """A list of dependencies to install in the server environment.""" - lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field( - None, description="Lifespan context manager" - ) + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None + """An async context manager that will be called when the server is started.""" - auth: AuthSettings | None = None + auth: AuthSettings | None # Transport security settings (DNS rebinding protection) - transport_security: TransportSecuritySettings | None = None + transport_security: TransportSecuritySettings | None def lifespan_wrapper( - app: FastMCP, - lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]], -) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]: + app: FastMCP[LifespanResultT], + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]], +) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]: @asynccontextmanager - async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]: + async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]: async with lifespan(app) as context: yield context return wrap -class FastMCP: +class FastMCP(Generic[LifespanResultT]): def __init__( self, name: str | None = None, @@ -140,14 +128,50 @@ def __init__( event_store: EventStore | None = None, *, tools: list[Tool] | None = None, - **settings: Any, + debug: bool = False, + log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO", + host: str = "127.0.0.1", + port: int = 8000, + mount_path: str = "/", + sse_path: str = "/sse", + message_path: str = "/messages/", + streamable_http_path: str = "/mcp", + json_response: bool = False, + stateless_http: bool = False, + warn_on_duplicate_resources: bool = True, + warn_on_duplicate_tools: 
bool = True, + warn_on_duplicate_prompts: bool = True, + dependencies: Collection[str] = (), + lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None, + auth: AuthSettings | None = None, + transport_security: TransportSecuritySettings | None = None, ): - self.settings = Settings(**settings) + self.settings = Settings( + debug=debug, + log_level=log_level, + host=host, + port=port, + mount_path=mount_path, + sse_path=sse_path, + message_path=message_path, + streamable_http_path=streamable_http_path, + json_response=json_response, + stateless_http=stateless_http, + warn_on_duplicate_resources=warn_on_duplicate_resources, + warn_on_duplicate_tools=warn_on_duplicate_tools, + warn_on_duplicate_prompts=warn_on_duplicate_prompts, + dependencies=list(dependencies), + lifespan=lifespan, + auth=auth, + transport_security=transport_security, + ) self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, - lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), + # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. + # We need to create a Lifespan type that is a generic on the server type, like Starlette does. + lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore ) self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) @@ -257,7 +281,7 @@ async def list_tools(self) -> list[MCPTool]: for info in tools ] - def get_context(self) -> Context[ServerSession, object, Request]: + def get_context(self) -> Context[ServerSession, LifespanResultT, Request]: """ Returns a Context object. 
Note that the context will only be valid during a request; outside a request, most methods will error. diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index 3d8827ed7..f50126081 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -73,7 +73,7 @@ def from_function( skip_names=[context_kwarg] if context_kwarg is not None else [], structured_output=structured_output, ) - parameters = func_arg_metadata.arg_model.model_json_schema() + parameters = func_arg_metadata.arg_model.model_json_schema(by_alias=True) return cls( fn=fn, diff --git a/src/mcp/server/fastmcp/utilities/func_metadata.py b/src/mcp/server/fastmcp/utilities/func_metadata.py index 3c9506fba..a4cb8ac5b 100644 --- a/src/mcp/server/fastmcp/utilities/func_metadata.py +++ b/src/mcp/server/fastmcp/utilities/func_metadata.py @@ -21,7 +21,7 @@ from mcp.server.fastmcp.exceptions import InvalidSignature from mcp.server.fastmcp.utilities.logging import get_logger -from mcp.server.fastmcp.utilities.types import Image +from mcp.server.fastmcp.utilities.types import Audio, Image from mcp.types import ContentBlock, TextContent logger = get_logger(__name__) @@ -47,8 +47,11 @@ def model_dump_one_level(self) -> dict[str, Any]: That is, sub-models etc are not dumped - they are kept as pydantic models. """ kwargs: dict[str, Any] = {} - for field_name in self.__class__.model_fields.keys(): - kwargs[field_name] = getattr(self, field_name) + for field_name, field_info in self.__class__.model_fields.items(): + value = getattr(self, field_name) + # Use the alias if it exists, otherwise use the field name + output_name = field_info.alias if field_info.alias else field_name + kwargs[output_name] = value return kwargs model_config = ConfigDict( @@ -127,12 +130,23 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: dicts (JSON objects) as JSON strings, which can be pre-parsed here. 
""" new_data = data.copy() # Shallow copy - for field_name in self.arg_model.model_fields.keys(): - if field_name not in data.keys(): + + # Build a mapping from input keys (including aliases) to field info + key_to_field_info: dict[str, FieldInfo] = {} + for field_name, field_info in self.arg_model.model_fields.items(): + # Map both the field name and its alias (if any) to the field info + key_to_field_info[field_name] = field_info + if field_info.alias: + key_to_field_info[field_info.alias] = field_info + + for data_key in data.keys(): + if data_key not in key_to_field_info: continue - if isinstance(data[field_name], str) and self.arg_model.model_fields[field_name].annotation is not str: + + field_info = key_to_field_info[data_key] + if isinstance(data[data_key], str) and field_info.annotation is not str: try: - pre_parsed = json.loads(data[field_name]) + pre_parsed = json.loads(data[data_key]) except json.JSONDecodeError: continue # Not JSON - skip if isinstance(pre_parsed, str | int | float): @@ -140,7 +154,7 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]: # Should really be parsed as '"hello"' in Python - but if we parse # it as JSON it'll turn into just 'hello'. So we skip it. 
continue - new_data[field_name] = pre_parsed + new_data[data_key] = pre_parsed assert new_data.keys() == data.keys() return new_data @@ -222,7 +236,19 @@ def func_metadata( _get_typed_annotation(annotation, globalns), param.default if param.default is not inspect.Parameter.empty else PydanticUndefined, ) - dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) + + # Check if the parameter name conflicts with BaseModel attributes + # This is necessary because Pydantic warns about shadowing parent attributes + if hasattr(BaseModel, param.name) and callable(getattr(BaseModel, param.name)): + # Use an alias to avoid the shadowing warning + field_info.alias = param.name + field_info.validation_alias = param.name + field_info.serialization_alias = param.name + # Use a prefixed internal name + internal_name = f"field_{param.name}" + dynamic_pydantic_model_params[internal_name] = (field_info.annotation, field_info) + else: + dynamic_pydantic_model_params[param.name] = (field_info.annotation, field_info) continue arguments_model = create_model( @@ -480,6 +506,9 @@ def _convert_to_content( if isinstance(result, Image): return [result.to_image_content()] + if isinstance(result, Audio): + return [result.to_audio_content()] + if isinstance(result, list | tuple): return list( chain.from_iterable( diff --git a/src/mcp/server/fastmcp/utilities/types.py b/src/mcp/server/fastmcp/utilities/types.py index ccaa3d69a..1be6f8274 100644 --- a/src/mcp/server/fastmcp/utilities/types.py +++ b/src/mcp/server/fastmcp/utilities/types.py @@ -3,7 +3,7 @@ import base64 from pathlib import Path -from mcp.types import ImageContent +from mcp.types import AudioContent, ImageContent class Image: @@ -52,3 +52,50 @@ def to_image_content(self) -> ImageContent: raise ValueError("No image data available") return ImageContent(type="image", data=data, mimeType=self._mime_type) + + +class Audio: + """Helper class for returning audio from tools.""" + + def __init__( + self, + path: str 
| Path | None = None, + data: bytes | None = None, + format: str | None = None, + ): + if not bool(path) ^ bool(data): + raise ValueError("Either path or data can be provided") + + self.path = Path(path) if path else None + self.data = data + self._format = format + self._mime_type = self._get_mime_type() + + def _get_mime_type(self) -> str: + """Get MIME type from format or guess from file extension.""" + if self._format: + return f"audio/{self._format.lower()}" + + if self.path: + suffix = self.path.suffix.lower() + return { + ".wav": "audio/wav", + ".mp3": "audio/mpeg", + ".ogg": "audio/ogg", + ".flac": "audio/flac", + ".aac": "audio/aac", + ".m4a": "audio/mp4", + }.get(suffix, "application/octet-stream") + return "audio/wav" # default for raw binary data + + def to_audio_content(self) -> AudioContent: + """Convert to MCP AudioContent.""" + if self.path: + with open(self.path, "rb") as f: + data = base64.b64encode(f.read()).decode() + elif self.data is not None: + data = base64.b64encode(self.data).decode() + else: + raise ValueError("No audio data available") + + return AudioContent(type="audio", data=data, mimeType=self._mime_type) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 562de31b7..8c459383c 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -93,7 +93,7 @@ async def main(): logger = logging.getLogger(__name__) -LifespanResultT = TypeVar("LifespanResultT") +LifespanResultT = TypeVar("LifespanResultT", default=Any) RequestT = TypeVar("RequestT", default=Any) # type aliases for tool call results @@ -118,7 +118,7 @@ def __init__( @asynccontextmanager -async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]: +async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. 
Args: @@ -149,7 +149,6 @@ def __init__( types.PingRequest: _ping_handler, } self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} - self.notification_options = NotificationOptions() self._tool_cache: dict[str, types.Tool] = {} logger.debug("Initializing server %r", name) @@ -647,6 +646,12 @@ async def _handle_request( response = await handler(req) except McpError as err: response = err.error + except anyio.get_cancelled_exc_class(): + logger.info( + "Request %s cancelled - duplicate response suppressed", + message.request_id, + ) + return except Exception as err: if raise_exceptions: raise err diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 10a388b6c..802cb8680 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -612,7 +612,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: if not await self._validate_request_headers(request, send): return - await self._terminate_session() + await self.terminate() response = self._create_json_response( None, @@ -620,7 +620,7 @@ async def _handle_delete_request(self, request: Request, send: Send) -> None: ) await response(request.scope, request.receive, send) - async def _terminate_session(self) -> None: + async def terminate(self) -> None: """Terminate the current session, closing all streams. Once terminated, all requests with this session ID will receive 404 Not Found. 
@@ -837,9 +837,7 @@ async def message_router(): response_id = str(message.root.id) # If this response is for an existing request stream, # send it there - if response_id in self._request_streams: - target_request_id = response_id - + target_request_id = response_id else: # Extract related_request_id from meta if it exists if ( diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index e953ca39f..53d542d21 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -4,7 +4,6 @@ import contextlib import logging -import threading from collections.abc import AsyncIterator from http import HTTPStatus from typing import Any @@ -75,7 +74,7 @@ def __init__( # The task group will be set during lifespan self._task_group = None # Thread-safe tracking of run() calls - self._run_lock = threading.Lock() + self._run_lock = anyio.Lock() self._has_started = False @contextlib.asynccontextmanager @@ -97,7 +96,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: yield """ # Thread-safe check to ensure run() is only called once - with self._run_lock: + async with self._run_lock: if self._has_started: raise RuntimeError( "StreamableHTTPSessionManager .run() can only be called " @@ -190,6 +189,9 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA # Handle the HTTP request and return the response await http_transport.handle_request(scope, receive, send) + # Terminate the transport after the request is handled + await http_transport.terminate() + async def _handle_stateful_request( self, scope: Scope, diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 33878ee15..6bf15b531 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -114,20 +114,20 @@ class OAuthMetadata(BaseModel): registration_endpoint: AnyHttpUrl | None = None scopes_supported: list[str] | None = None response_types_supported: list[str] = ["code"] - 
response_modes_supported: list[Literal["query", "fragment", "form_post"]] | None = None + response_modes_supported: list[str] | None = None grant_types_supported: list[str] | None = None token_endpoint_auth_methods_supported: list[str] | None = None - token_endpoint_auth_signing_alg_values_supported: None = None + token_endpoint_auth_signing_alg_values_supported: list[str] | None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None op_policy_uri: AnyHttpUrl | None = None op_tos_uri: AnyHttpUrl | None = None revocation_endpoint: AnyHttpUrl | None = None revocation_endpoint_auth_methods_supported: list[str] | None = None - revocation_endpoint_auth_signing_alg_values_supported: None = None + revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None introspection_endpoint: AnyHttpUrl | None = None introspection_endpoint_auth_methods_supported: list[str] | None = None - introspection_endpoint_auth_signing_alg_values_supported: None = None + introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None code_challenge_methods_supported: list[str] | None = None @@ -139,6 +139,17 @@ class ProtectedResourceMetadata(BaseModel): resource: AnyHttpUrl authorization_servers: list[AnyHttpUrl] = Field(..., min_length=1) + jwks_uri: AnyHttpUrl | None = None scopes_supported: list[str] | None = None bearer_methods_supported: list[str] | None = Field(default=["header"]) # MCP only supports header method + resource_signing_alg_values_supported: list[str] | None = None + resource_name: str | None = None resource_documentation: AnyHttpUrl | None = None + resource_policy_uri: AnyHttpUrl | None = None + resource_tos_uri: AnyHttpUrl | None = None + # tls_client_certificate_bound_access_tokens default is False, but omitted here for clarity + tls_client_certificate_bound_access_tokens: bool | None = None + authorization_details_types_supported: list[str] | None = None + dpop_signing_alg_values_supported:
list[str] | None = None + # dpop_bound_access_tokens_required default is False, but omitted here for clarity + dpop_bound_access_tokens_required: bool | None = None diff --git a/src/mcp/types.py b/src/mcp/types.py index 91432d69c..98fefa080 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -36,7 +36,7 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -RequestId = Annotated[int | str, Field(union_mode="left_to_right")] +RequestId = Annotated[int, Field(strict=True)] | str AnyFunction: TypeAlias = Callable[..., Any] @@ -849,7 +849,7 @@ class Tool(BaseMetadata): """A JSON Schema object defining the expected parameters for the tool.""" outputSchema: dict[str, Any] | None = None """ - An optional JSON Schema object defining the structure of the tool's output + An optional JSON Schema object defining the structure of the tool's output returned in the structuredContent field of a CallToolResult. """ annotations: ToolAnnotations | None = None