diff --git a/.gitignore b/.gitignore index 54006f93f..2754db9d9 100644 --- a/.gitignore +++ b/.gitignore @@ -165,4 +165,5 @@ cython_debug/ #.idea/ # vscode -.vscode/ \ No newline at end of file +.vscode/ +.windsurfrules diff --git a/CLAUDE.md b/CLAUDE.md index e95b75cd5..619f3bb44 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -19,7 +19,7 @@ This document contains critical information about working with this codebase. Fo - Line length: 88 chars maximum 3. Testing Requirements - - Framework: `uv run pytest` + - Framework: `uv run --frozen pytest` - Async testing: use anyio, not asyncio - Coverage: test edge cases and errors - New features require tests @@ -54,9 +54,9 @@ This document contains critical information about working with this codebase. Fo ## Code Formatting 1. Ruff - - Format: `uv run ruff format .` - - Check: `uv run ruff check .` - - Fix: `uv run ruff check . --fix` + - Format: `uv run --frozen ruff format .` + - Check: `uv run --frozen ruff check .` + - Fix: `uv run --frozen ruff check . --fix` - Critical issues: - Line length (88 chars) - Import sorting (I001) @@ -67,7 +67,7 @@ This document contains critical information about working with this codebase. Fo - Imports: split into multiple lines 2. Type Checking - - Tool: `uv run pyright` + - Tool: `uv run --frozen pyright` - Requirements: - Explicit None checks for Optional - Type narrowing for strings @@ -104,6 +104,10 @@ This document contains critical information about working with this codebase. Fo - Add None checks - Narrow string types - Match existing patterns + - Pytest: + - If the tests aren't finding the anyio pytest mark, try adding PYTEST_DISABLE_PLUGIN_AUTOLOAD="" + to the start of the pytest run command eg: + `PYTEST_DISABLE_PLUGIN_AUTOLOAD="" uv run --frozen pytest` 3. Best Practices - Check git status before commits diff --git a/README.md b/README.md index 60b5a7261..011039c44 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,33 @@ async def long_task(files: list[str], ctx: Context) -> str: return "Processing complete" ``` +### Authentication + +Authentication can be used by servers that want to expose tools accessing protected resources. + +`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by +providing an implementation of the `OAuthServerProvider` protocol. + +``` +mcp = FastMCP("My App", + auth_provider=MyOAuthServerProvider(), + auth=AuthSettings( + issuer_url="https://myapp.com", + revocation_options=RevocationOptions( + enabled=True, + ), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["myscope", "myotherscope"], + default_scopes=["myscope"], + ), + required_scopes=["myscope"], + ), +) +``` + +See [OAuthServerProvider](mcp/server/auth/provider.py) for more details. + ## Running Your Server ### Development Mode diff --git a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py index 30bca7229..7d73e9876 100644 --- a/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py +++ b/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py @@ -322,8 +322,7 @@ async def process_llm_response(self, llm_response: str) -> str: total = result["total"] percentage = (progress / total) * 100 logging.info( - f"Progress: {progress}/{total} " - f"({percentage:.1f}%)" + f"Progress: {progress}/{total} ({percentage:.1f}%)" ) return f"Tool execution result: {result}" diff --git a/pyproject.toml b/pyproject.toml index e400ad7d8..ee77806af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,10 @@ mcp = "mcp.cli:app [cli]" [tool.uv] resolution = "lowest-direct" dev-dependencies = [ - "pyright>=1.1.391", + "pyright>=1.1.396", "pytest>=8.3.4", "ruff>=0.8.5", "trio>=0.26.2", - "pytest-flakefinder>=1.1.0", "pytest-xdist>=3.6.1", "pytest-examples>=0.0.14", ] @@ -101,8 +100,12 @@ mcp = { workspace = true } xfail_strict = true filterwarnings = [ "error", + # this is a long-standing issue with fastmcp, which is just now being exercised by tests + "ignore:Unclosed:ResourceWarning", # This should be fixed on Uvicorn's side. "ignore::DeprecationWarning:websockets", "ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning", - "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel" + "ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel", + # this is a problem in starlette + "ignore:Please use `import python_multipart` instead.:PendingDeprecationWarning", ] diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mcp/client/auth/oauth.py b/src/mcp/client/auth/oauth.py new file mode 100644 index 000000000..a43a461db --- /dev/null +++ b/src/mcp/client/auth/oauth.py @@ -0,0 +1,584 @@ +""" +Authentication functionality for MCP client. + +This module provides authentication mechanisms for the MCP client to authenticate +with an MCP server. It implements the authentication flow as specified in the MCP +authorization specification. +""" + +from __future__ import annotations as _annotations + +import base64 +import hashlib +import json +import logging +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any, Protocol +from urllib.parse import urlencode, urlparse + +import httpx +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, ConfigDict, Field + +logger = logging.getLogger(__name__) + + +class AccessToken(BaseModel): + """ + Represents an OAuth 2.0 access token with its associated metadata. + """ + + access_token: str + token_type: str = Field(default="Bearer") + expires_in: timedelta | None = None + refresh_token: str | None = None + scope: str | None = None + + created_at: datetime = Field(default=datetime.now(), exclude=True) + + model_config = ConfigDict(extra="allow") + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return ( + self.expires_in is not None + and datetime.now() >= self.created_at + self.expires_in + ) + + @property + def scopes(self) -> list[str]: + """Convert scope string to list of scopes.""" + if isinstance(self.scope, list): + return self.scope + return self.scope.split() if self.scope else [] + + def to_auth_header(self) -> dict[str, str]: + """Convert token to Authorization header.""" + + return {"Authorization": f"{self.token_type} {self.access_token}"} + + +class ClientMetadata(BaseModel): + """ + OAuth 2.0 Dynamic Client Registration Metadata. + + This model represents the client metadata used when registering a client + with an OAuth 2.0 server using the Dynamic Client Registration protocol + as defined in RFC 7591 Section 2. + """ + + redirect_uris: list[AnyHttpUrl] = Field(default_factory=list) + token_endpoint_auth_method: str | None = None + grant_types: list[str] | None = None + response_types: list[str] | None = None + client_name: str | None = None + client_uri: AnyHttpUrl | None = None + logo_uri: AnyHttpUrl | None = None + scope: str | None = None + contacts: list[str] | None = None + tos_uri: AnyHttpUrl | None = None + policy_uri: AnyHttpUrl | None = None + jwks_uri: AnyHttpUrl | None = None + jwks: dict[str, Any] | None = None + software_id: str | None = None + software_version: str | None = None + + model_config = ConfigDict(extra="allow") + + +class DynamicClientRegistration(ClientMetadata): + """ + Response from OAuth 2.0 Dynamic Client Registration. + + This model represents the response received after registering a client + with an OAuth 2.0 server using the Dynamic Client Registration protocol + as defined in RFC 7591. + + Note that we inherit from ClientMetadata, which contains the client metadata, + since all values sent during the request are also returned in the response, + as per https://datatracker.ietf.org/doc/html/rfc7591#section-3.2.1 + """ + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + model_config = ConfigDict(extra="allow") + + +class ServerMetadataDiscovery(BaseModel): + """ + OAuth 2.0 Authorization Server Metadata Discovery Response. + + This model represents the response received from an OAuth 2.0 server's + metadata discovery endpoint as defined in RFC 8414. + """ + + issuer: AnyHttpUrl + authorization_endpoint: AnyHttpUrl + token_endpoint: AnyHttpUrl + registration_endpoint: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + response_types_supported: list[str] + response_modes_supported: list[str] | None = None + grant_types_supported: list[str] | None = None + token_endpoint_auth_methods_supported: list[str] | None = None + token_endpoint_auth_signing_alg_values_supported: list[str] | None = None + service_documentation: AnyHttpUrl | None = None + revocation_endpoint: AnyHttpUrl | None = None + revocation_endpoint_auth_methods_supported: list[str] | None = None + revocation_endpoint_auth_signing_alg_values_supported: list[str] | None = None + introspection_endpoint: AnyHttpUrl | None = None + introspection_endpoint_auth_methods_supported: list[str] | None = None + introspection_endpoint_auth_signing_alg_values_supported: list[str] | None = None + code_challenge_methods_supported: list[str] | None = None + + model_config = ConfigDict(extra="allow") + + +class OAuthClientProvider(Protocol): + @property + def client_metadata(self) -> ClientMetadata: ... + + @property + def redirect_url(self) -> AnyHttpUrl: ... + + async def open_user_agent(self, url: AnyHttpUrl) -> None: + """ + Opens the user agent to the given URL. + """ + ... + + async def client_registration( + self, issuer: AnyHttpUrl + ) -> DynamicClientRegistration | None: + """ + Loads the client registration for the given endpoint. + """ + ... + + async def store_client_registration( + self, issuer: AnyHttpUrl, metadata: DynamicClientRegistration + ) -> None: + """ + Stores the client registration to be retreived for the next session + """ + ... + + async def store_metadata( + self, issuer: AnyHttpUrl, metadata: ServerMetadataDiscovery + ) -> None: + """ + Stores the metadata for the given issuer + """ + ... + + async def metadata(self, issuer: AnyHttpUrl) -> ServerMetadataDiscovery | None: + """ + Loads the metadata for the given issuer + """ + ... + + def code_verifier(self) -> str: + """ + Loads the PKCE code verifier for the current session. + See https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1 + """ + ... + + async def token(self) -> AccessToken | None: + """ + Loads the token for the current session. + """ + ... + + async def store_token(self, token: AccessToken) -> None: + """ + Stores the token to be retreived for the next session + """ + ... + + +class NotFoundError(Exception): + """Exception raised when a resource or endpoint is not found.""" + + pass + + +class RegistrationFailedError(Exception): + """Exception raised when client registration fails.""" + + pass + + +class GrantNotSupported(Exception): + """Exception raised when a grant type is not supported.""" + + pass + + +class OAuthClient: + WELL_KNOWN = "/.well-known/oauth-authorization-server" + GRANT_TYPE: str = "authorization_code" + + @dataclass + class State: + metadata: ServerMetadataDiscovery | None = None + registeration: DynamicClientRegistration | None = None + + def __init__( + self, + server_url: AnyHttpUrl, + provider: OAuthClientProvider, + scope: str | None = None, + ): + self.http_client = httpx.AsyncClient() + self.server_url = server_url + self.provider = provider + self.scope = scope + self.state = self.State() + + @property + def is_authenticated(self) -> bool: + """Check if client has a valid, non-expired token.""" + return self.token is not None and not self.token.is_expired() + + @property + def discovery_url(self) -> AnyHttpUrl: + base_url = str(self.server_url).rstrip("/") + parsed_url = urlparse(base_url) + + # HTTPS is required by RFC 8414 + discovery_url = f"https://{parsed_url.netloc}{self.WELL_KNOWN}" + return AnyUrl(discovery_url) + + async def _obtain_metadata(self) -> ServerMetadataDiscovery: + if metadata := await self.provider.metadata(self.discovery_url): + return metadata + if metadata := await self.discover_auth_metadata(self.discovery_url): + await self.provider.store_metadata(self.discovery_url, metadata) + return metadata + return self.default_metadata() + + async def metadata(self) -> ServerMetadataDiscovery: + if self.state.metadata is not None: + return self.state.metadata + + self.state.metadata = await self._obtain_metadata() + return self.state.metadata + + async def _obtain_client( + self, metadata: ServerMetadataDiscovery + ) -> DynamicClientRegistration: + """ + Obtain a client by either reading it from the OAuthProvider or registering it. + """ + if metadata.registration_endpoint is None: + raise NotFoundError("Registration endpoint not found") + + if registration := await self.provider.client_registration(metadata.issuer): + return registration + else: + registration = await self.dynamic_client_registration( + self.provider.client_metadata, metadata.registration_endpoint + ) + if registration is None: + raise RegistrationFailedError( + f"Registration at {metadata.registration_endpoint} failed" + ) + + await self.provider.store_client_registration(metadata.issuer, registration) + return registration + + async def client_metadata( + self, metadata: ServerMetadataDiscovery + ) -> DynamicClientRegistration: + if self.state.registeration is not None: + return self.state.registeration + else: + return await self._obtain_client(metadata) + + def default_metadata(self) -> ServerMetadataDiscovery: + """ + Returns default endpoints as specified in + https://spec.modelcontextprotocol.io/specification/draft/basic/authorization/ + for the server. + """ + base_url = AnyUrl(str(self.server_url).rstrip("/")) + return ServerMetadataDiscovery( + issuer=base_url, + authorization_endpoint=AnyUrl(f"{base_url}/authorize"), + token_endpoint=AnyUrl(f"{base_url}/token"), + registration_endpoint=AnyUrl(f"{base_url}/register"), + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token"], + token_endpoint_auth_methods_supported=["client_secret_post"], + ) + + async def discover_auth_metadata( + self, discovery_url: AnyHttpUrl + ) -> ServerMetadataDiscovery | None: + """ + Use RFC 8414 to discover the authorization server metadata. + """ + try: + response = await self.http_client.get(str(discovery_url)) + if response.status_code == 404: + return None + response.raise_for_status() + json_data = await response.aread() + return ServerMetadataDiscovery.model_validate_json(json_data) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP status: {e}") + raise + except Exception as e: + logger.error(f"Error during auth metadata discovery: {e}") + raise + + async def dynamic_client_registration( + self, client_metadata: ClientMetadata, registration_endpoint: AnyHttpUrl + ) -> DynamicClientRegistration | None: + """ + Register a client dynamically with an OAuth 2.0 authorization server + following RFC 7591. + """ + headers = {"Content-Type": "application/json", "Accept": "application/json"} + + try: + response = await self.http_client.post( + str(registration_endpoint), + json=client_metadata.model_dump(exclude_none=True), + headers=headers, + ) + if response.status_code == 404: + logger.error( + f"Registration endpoint not found at {registration_endpoint}" + ) + return None + response.raise_for_status() + client_data = await response.aread() + return DynamicClientRegistration.model_validate_json(client_data) + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error in client registration: {e.response.status_code}") + if e.response.content: + try: + error_data = json.loads(e.response.content) + logger.error(f"Error details: {error_data}") + except json.JSONDecodeError: + logger.error(f"Error content: {e.response.content}") + except Exception as e: + logger.error(f"Unexpected error during registration: {e}") + + return None + + async def start_auth(self) -> AnyHttpUrl: + """ + Start the OAuth 2.1 authorization flow by redirecting the user to the + authorization server. + + Returns: + AnyHttpUrl: The authorization URL to redirect the user to + """ + metadata = await self.metadata() + registration = await self.client_metadata(metadata) + + # Generate PKCE code verifier + code_verifier = self.provider.code_verifier() + + # Build authorization URL + authorization_url = get_authorization_url( + metadata.authorization_endpoint, + self.provider.redirect_url, + registration.client_id, + code_verifier, + self.scope, + ) + + # Open the URL in the user's browser + await self.provider.open_user_agent(authorization_url) + + return authorization_url + + async def finalize_auth(self, authorization_code: str) -> AccessToken: + """ + Complete the OAuth 2.1 authorization flow by exchanging authorization code + for tokens. + + Args: + authorization_code: The authorization code received from the authorization + server + + Returns: + AccessToken: The resulting access token + """ + # Get metadata and registration info + metadata = await self.metadata() + registration = await self.client_metadata(metadata) + code_verifier = self.provider.code_verifier() + + # Exchange the code for a token + token = await self.exchange_authorization( + metadata, + registration, + self.provider.redirect_url, + code_verifier, + authorization_code, + ) + + # Cache the token and store it for future use + self.token = token + await self.provider.store_token(token) + + return token + + async def refresh_if_needed(self) -> AccessToken | None: + """ + Get the current token from the underlying provider + """ + # Return cached token if it's valid + metadata = await self.metadata() + registration = await self.client_metadata(metadata) + + if token := await self.provider.token(): + if not token.is_expired(): + return token + + token = await self.refresh_token( + token, + metadata.token_endpoint, + registration.client_id, + registration.client_secret, + ) + + if token is not None: + return token + + return None + + async def refresh_token( + self, + token: AccessToken, + token_endpoint: AnyHttpUrl, + client_id: str, + client_secret: str | None = None, + ) -> AccessToken: + """ + Refresh the access token using a refresh token. + """ + data = { + "grant_type": "refresh_token", + "refresh_token": token.refresh_token, + "client_id": client_id, + } + + if client_secret: + data["client_secret"] = client_secret + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + try: + response = await self.http_client.post( + str(token_endpoint), data=data, headers=headers + ) + response.raise_for_status() + token_data = response.json() + return AccessToken(**token_data) + except Exception as e: + logger.error(f"Error refreshing token: {e}") + raise + + async def exchange_authorization( + self, + metadata: ServerMetadataDiscovery, + registration: DynamicClientRegistration, + redirect_uri: AnyHttpUrl, + code_verifier: str, + authorization_code: str, + grant_type: str = "authorization_code", + ) -> AccessToken: + """ + Exchange an authorization code for an access token using OAuth 2.1 with PKCE. + """ + if grant_type not in (registration.grant_types or []): + raise GrantNotSupported(f"Grant type {grant_type} not supported") + + # Get token endpoint from server metadata or use default + token_endpoint = str(metadata.token_endpoint) + + # Prepare token request parameters + data = { + "grant_type": grant_type, + "code": authorization_code, + "redirect_uri": str(redirect_uri), + "client_id": registration.client_id, + "code_verifier": code_verifier, + } + + # Add client secret if available (optional in OAuth 2.1) + if registration.client_secret: + data["client_secret"] = registration.client_secret + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + + try: + response = await self.http_client.post( + token_endpoint, data=data, headers=headers + ) + response.raise_for_status() + token_data = response.json() + + # Create and return the token + return AccessToken(**token_data) + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP error during token exchange: {e.response.status_code}") + if e.response.content: + logger.error(f"Error content: {e.response.content}") + raise + except Exception as e: + logger.error(f"Unexpected error during token exchange: {e}") + raise + + +def get_authorization_url( + authorization_endpoint: AnyHttpUrl, + redirect_uri: AnyHttpUrl, + client_id: str, + code_verifier: str, + scope: str | None = None, +) -> AnyHttpUrl: + """Generate an OAuth 2.1 authorization URL for the user agent. + + This method generates a URL that the user agent (browser) should visit to + authenticate the user and authorize the application. It includes PKCE + (Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1. + """ + # Generate code challenge from verifier using SHA-256 + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + + # Build authorization URL with necessary parameters + params = { + "response_type": "code", + "client_id": client_id, + "redirect_uri": str(redirect_uri), + "code_challenge": code_challenge, + "code_challenge_method": "S256", + } + + # Add scope if provided or use the one from registration + if scope: + params["scope"] = scope + + # Construct the full authorization URL + return AnyUrl(f"{authorization_endpoint}?{urlencode(params)}") diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 4f6241a72..0812876fc 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -10,6 +10,8 @@ from httpx_sse import aconnect_sse import mcp.types as types +from mcp.client.auth import http as auth_http +from mcp.client.auth.oauth import AuthSession, OAuthClient logger = logging.getLogger(__name__) @@ -24,6 +26,7 @@ async def sse_client( headers: dict[str, Any] | None = None, timeout: float = 5, sse_read_timeout: float = 60 * 5, + auth: AuthSession | OAuthClient | None = None, ): """ Client transport for SSE. @@ -43,7 +46,33 @@ async def sse_client( async with anyio.create_task_group() as tg: try: logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: + + # Set up headers and auth if needed + if headers is None: + headers = {} + + if auth is not None: + await auth_http.add_auth_headers(headers, auth) + + # Set up event hooks for auth if auth is provided + event_hooks = {} + if auth is not None: + # Create a response hook for authentication + async def auth_hook(response): + if isinstance(auth, AuthSession): + return await auth_http.auth_response_hook( + response, auth_session=auth + ) + else: + return await auth_http.auth_response_hook( + response, oauth_client=auth + ) + + event_hooks["response"] = [auth_hook] + + async with httpx.AsyncClient( + headers=headers, event_hooks=event_hooks + ) as client: async with aconnect_sse( client, "GET", @@ -121,6 +150,7 @@ async def post_writer(endpoint_url: str): exclude_none=True, ), ) + # Handle 401 responses through the auth hook response.raise_for_status() logger.debug( "Client message sent successfully: " diff --git a/src/mcp/server/auth/__init__.py b/src/mcp/server/auth/__init__.py new file mode 100644 index 000000000..6888ffe8d --- /dev/null +++ b/src/mcp/server/auth/__init__.py @@ -0,0 +1,3 @@ +""" +MCP OAuth server authorization components. +""" diff --git a/src/mcp/server/auth/errors.py b/src/mcp/server/auth/errors.py new file mode 100644 index 000000000..935328598 --- /dev/null +++ b/src/mcp/server/auth/errors.py @@ -0,0 +1,35 @@ +from typing import Literal + +from pydantic import BaseModel, ValidationError + +ErrorCode = Literal["invalid_request", "invalid_client"] + + +class ErrorResponse(BaseModel): + error: ErrorCode + error_description: str + + +class OAuthError(Exception): + """ + Base class for all OAuth errors. + """ + + error_code: ErrorCode + + def __init__(self, error_description: str): + super().__init__(error_description) + self.error_description = error_description + + def error_response(self) -> ErrorResponse: + return ErrorResponse( + error=self.error_code, + error_description=self.error_description, + ) + + +def stringify_pydantic_error(validation_error: ValidationError) -> str: + return "\n".join( + f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" + for e in validation_error.errors() + ) diff --git a/src/mcp/server/auth/handlers/__init__.py b/src/mcp/server/auth/handlers/__init__.py new file mode 100644 index 000000000..e99a62de1 --- /dev/null +++ b/src/mcp/server/auth/handlers/__init__.py @@ -0,0 +1,3 @@ +""" +Request handlers for MCP authorization endpoints. +""" diff --git a/src/mcp/server/auth/handlers/authorize.py b/src/mcp/server/auth/handlers/authorize.py new file mode 100644 index 000000000..b6079da97 --- /dev/null +++ b/src/mcp/server/auth/handlers/authorize.py @@ -0,0 +1,257 @@ +import logging +from dataclasses import dataclass +from typing import Any, Literal +from urllib.parse import urlencode, urlparse, urlunparse + +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError +from starlette.datastructures import FormData, QueryParams +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response + +from mcp.server.auth.errors import ( + OAuthError, + stringify_pydantic_error, +) +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.server.auth.provider import ( + AuthorizationErrorCode, + AuthorizationParams, + AuthorizeError, + OAuthServerProvider, + construct_redirect_uri, +) +from mcp.shared.auth import ( + InvalidRedirectUriError, + InvalidScopeError, +) + +logger = logging.getLogger(__name__) + + +class AuthorizationRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 + client_id: str = Field(..., description="The client ID") + redirect_uri: AnyHttpUrl | None = Field( + None, description="URL to redirect to after authorization" + ) + + # see OAuthClientMetadata; we only support `code` + response_type: Literal["code"] = Field( + ..., description="Must be 'code' for authorization code flow" + ) + code_challenge: str = Field(..., description="PKCE code challenge") + code_challenge_method: Literal["S256"] = Field( + "S256", description="PKCE code challenge method, must be S256" + ) + state: str | None = Field(None, description="Optional state parameter") + scope: str | None = Field( + None, + description="Optional scope; if specified, should be " + "a space-separated list of scope strings", + ) + + +class AuthorizationErrorResponse(BaseModel): + error: AuthorizationErrorCode + error_description: str | None + error_uri: AnyUrl | None = None + # must be set if provided in the request + state: str | None = None + + +def best_effort_extract_string( + key: str, params: None | FormData | QueryParams +) -> str | None: + if params is None: + return None + value = params.get(key) + if isinstance(value, str): + return value + return None + + +class AnyHttpUrlModel(RootModel[AnyHttpUrl]): + root: AnyHttpUrl + + +@dataclass +class AuthorizationHandler: + provider: OAuthServerProvider[Any, Any, Any] + + async def handle(self, request: Request) -> Response: + # implements authorization requests for grant_type=code; + # see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1 + + state = None + redirect_uri = None + client = None + params = None + + async def error_response( + error: AuthorizationErrorCode, + error_description: str | None, + attempt_load_client: bool = True, + ): + nonlocal client, redirect_uri, state + if client is None and attempt_load_client: + # make last-ditch attempt to load the client + client_id = best_effort_extract_string("client_id", params) + client = client_id and await self.provider.get_client(client_id) + if redirect_uri is None and client: + # make last-ditch effort to load the redirect uri + if params is not None and "redirect_uri" not in params: + raw_redirect_uri = None + else: + raw_redirect_uri = AnyHttpUrlModel.model_validate( + best_effort_extract_string("redirect_uri", params) + ).root + try: + redirect_uri = client.validate_redirect_uri(raw_redirect_uri) + except (ValidationError, InvalidRedirectUriError): + pass + if state is None: + # make last-ditch effort to load state + state = best_effort_extract_string("state", params) + + error_resp = AuthorizationErrorResponse( + error=error, + error_description=error_description, + state=state, + ) + + if redirect_uri and client: + return RedirectResponse( + url=construct_redirect_uri( + str(redirect_uri), **error_resp.model_dump(exclude_none=True) + ), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + else: + return PydanticJSONResponse( + status_code=400, + content=error_resp, + headers={"Cache-Control": "no-store"}, + ) + + try: + # Parse request parameters + if request.method == "GET": + # Convert query_params to dict for pydantic validation + params = request.query_params + else: + # Parse form data for POST requests + params = await request.form() + + # Save state if it exists, even before validation + state = best_effort_extract_string("state", params) + + try: + auth_request = AuthorizationRequest.model_validate(params) + state = auth_request.state # Update with validated state + except ValidationError as validation_error: + error: AuthorizationErrorCode = "invalid_request" + for e in validation_error.errors(): + if e["loc"] == ("response_type",) and e["type"] == "literal_error": + error = "unsupported_response_type" + break + return await error_response( + error, stringify_pydantic_error(validation_error) + ) + + # Get client information + client = await self.provider.get_client( + auth_request.client_id, + ) + if not client: + # For client_id validation errors, return direct error (no redirect) + return await error_response( + error="invalid_request", + error_description=f"Client ID '{auth_request.client_id}' not found", + attempt_load_client=False, + ) + + # Validate redirect_uri against client's registered URIs + try: + redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri) + except InvalidRedirectUriError as validation_error: + # For redirect_uri validation errors, return direct error (no redirect) + return await error_response( + error="invalid_request", + error_description=validation_error.message, + ) + + # Validate scope - for scope errors, we can redirect + try: + scopes = client.validate_scope(auth_request.scope) + except InvalidScopeError as validation_error: + # For scope errors, redirect with error parameters + return await error_response( + error="invalid_scope", + error_description=validation_error.message, + ) + + # Setup authorization parameters + auth_params = AuthorizationParams( + state=state, + scopes=scopes, + code_challenge=auth_request.code_challenge, + redirect_uri=redirect_uri, + ) + + try: + # Let the provider pick the next URI to redirect to + return RedirectResponse( + url=await self.provider.authorize( + client, + auth_params, + ), + status_code=302, + headers={"Cache-Control": "no-store"}, + ) + except AuthorizeError as e: + # Handle authorization errors as defined in RFC 6749 Section 4.1.2.1 + return await error_response( + error=e.error, + error_description=e.error_description, + ) + + except Exception as validation_error: + # Catch-all for unexpected errors + logger.exception( + "Unexpected error in authorization_handler", exc_info=validation_error + ) + return await error_response( + error="server_error", error_description="An unexpected error occurred" + ) + + +def create_error_redirect( + redirect_uri: AnyUrl, error: Exception | AuthorizationErrorResponse +) -> str: + parsed_uri = urlparse(str(redirect_uri)) + + if isinstance(error, AuthorizationErrorResponse): + # Convert ErrorResponse to dict + error_dict = error.model_dump(exclude_none=True) + query_params = {} + for key, value in error_dict.items(): + if value is not None: + if key == "error_uri" and hasattr(value, "__str__"): + query_params[key] = str(value) + else: + query_params[key] = value + + elif isinstance(error, OAuthError): + query_params = {"error": error.error_code, "error_description": str(error)} + else: + query_params = { + "error": "server_error", + "error_description": "An unknown error occurred", + } + + new_query = urlencode(query_params) + if parsed_uri.query: + new_query = f"{parsed_uri.query}&{new_query}" + + return urlunparse(parsed_uri._replace(query=new_query)) diff --git a/src/mcp/server/auth/handlers/metadata.py b/src/mcp/server/auth/handlers/metadata.py new file mode 100644 index 000000000..e37e5d311 --- /dev/null +++ b/src/mcp/server/auth/handlers/metadata.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +from starlette.requests import Request +from starlette.responses import Response + +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.shared.auth import OAuthMetadata + + +@dataclass +class MetadataHandler: + metadata: OAuthMetadata + + async def handle(self, request: Request) -> Response: + return PydanticJSONResponse( + content=self.metadata, + headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour + ) diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py new file mode 100644 index 000000000..29f97319a --- /dev/null +++ b/src/mcp/server/auth/handlers/register.py @@ -0,0 +1,131 @@ +import secrets +import time +from dataclasses import dataclass +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, RootModel, ValidationError +from starlette.requests import Request +from starlette.responses import Response + +from mcp.server.auth.errors import stringify_pydantic_error +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.server.auth.provider import ( + OAuthServerProvider, + RegistrationError, + RegistrationErrorCode, +) +from mcp.server.auth.settings import ClientRegistrationOptions +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + + +class RegistrationRequest(RootModel[OAuthClientMetadata]): + # this wrapper is a no-op; it's just to separate out the types exposed to the + # provider from what we use in the HTTP handler + root: OAuthClientMetadata + + +class RegistrationErrorResponse(BaseModel): + error: RegistrationErrorCode + error_description: str | None + + +@dataclass +class RegistrationHandler: + provider: OAuthServerProvider[Any, Any, Any] + options: ClientRegistrationOptions + + async def handle(self, request: Request) -> Response: + # Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1 + try: + # Parse request body as JSON + body = await request.json() + client_metadata = OAuthClientMetadata.model_validate(body) + + # Scope validation is handled below + except ValidationError as validation_error: + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error="invalid_client_metadata", + error_description=stringify_pydantic_error(validation_error), + ), + status_code=400, + ) + + client_id = str(uuid4()) + client_secret = None + if client_metadata.token_endpoint_auth_method != "none": + # cryptographically secure random 32-byte hex string + client_secret = secrets.token_hex(32) + + if client_metadata.scope is None and self.options.default_scopes is not None: + client_metadata.scope = " ".join(self.options.default_scopes) + elif ( + client_metadata.scope is not None and self.options.valid_scopes is not None + ): + requested_scopes = set(client_metadata.scope.split()) + valid_scopes = set(self.options.valid_scopes) + if not requested_scopes.issubset(valid_scopes): + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error="invalid_client_metadata", + error_description="Requested scopes are not valid: " + f"{', '.join(requested_scopes - valid_scopes)}", + ), + status_code=400, + ) + if set(client_metadata.grant_types) != set( + ["authorization_code", "refresh_token"] + ): + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error="invalid_client_metadata", + error_description="grant_types must be authorization_code " + "and refresh_token", + ), + status_code=400, + ) + + client_id_issued_at = int(time.time()) + client_secret_expires_at = ( + client_id_issued_at + self.options.client_secret_expiry_seconds + if self.options.client_secret_expiry_seconds is not None + else None + ) + + client_info = OAuthClientInformationFull( + client_id=client_id, + client_id_issued_at=client_id_issued_at, + client_secret=client_secret, + client_secret_expires_at=client_secret_expires_at, + # passthrough information from the client request + redirect_uris=client_metadata.redirect_uris, + token_endpoint_auth_method=client_metadata.token_endpoint_auth_method, + grant_types=client_metadata.grant_types, + response_types=client_metadata.response_types, + client_name=client_metadata.client_name, + client_uri=client_metadata.client_uri, + logo_uri=client_metadata.logo_uri, + scope=client_metadata.scope, + contacts=client_metadata.contacts, + tos_uri=client_metadata.tos_uri, + policy_uri=client_metadata.policy_uri, + jwks_uri=client_metadata.jwks_uri, + jwks=client_metadata.jwks, + software_id=client_metadata.software_id, + software_version=client_metadata.software_version, + ) + try: + # Register client + await self.provider.register_client(client_info) + + # Return client information + return PydanticJSONResponse(content=client_info, status_code=201) + except RegistrationError as e: + # Handle registration errors as defined in RFC 7591 Section 3.2.2 + return PydanticJSONResponse( + content=RegistrationErrorResponse( + error=e.error, error_description=e.error_description + ), + status_code=400, + ) diff --git a/src/mcp/server/auth/handlers/revoke.py b/src/mcp/server/auth/handlers/revoke.py new file mode 100644 index 000000000..37883cd70 --- /dev/null +++ b/src/mcp/server/auth/handlers/revoke.py @@ -0,0 +1,97 @@ +from dataclasses import dataclass +from functools import partial +from typing import Any, Literal + +from pydantic import BaseModel, ValidationError +from starlette.requests import Request +from starlette.responses import Response + +from mcp.server.auth.errors import ( + stringify_pydantic_error, +) +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.server.auth.middleware.client_auth import ( + AuthenticationError, + ClientAuthenticator, +) +from mcp.server.auth.provider import AccessToken, OAuthServerProvider, RefreshToken + + +class RevocationRequest(BaseModel): + """ + # See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1 + """ + + token: str + token_type_hint: Literal["access_token", "refresh_token"] | None = None + client_id: str + client_secret: str | None + + +class RevocationErrorResponse(BaseModel): + error: Literal["invalid_request", "unauthorized_client"] + error_description: str | None = None + + +@dataclass +class RevocationHandler: + provider: OAuthServerProvider[Any, Any, Any] + client_authenticator: ClientAuthenticator + + async def handle(self, request: Request) -> Response: + """ + Handler for the OAuth 2.0 Token Revocation endpoint. + """ + try: + form_data = await request.form() + revocation_request = RevocationRequest.model_validate(dict(form_data)) + except ValidationError as e: + return PydanticJSONResponse( + status_code=400, + content=RevocationErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(e), + ), + ) + + # Authenticate client + try: + client = await self.client_authenticator.authenticate( + revocation_request.client_id, revocation_request.client_secret + ) + except AuthenticationError as e: + return PydanticJSONResponse( + status_code=401, + content=RevocationErrorResponse( + error="unauthorized_client", + error_description=e.message, + ), + ) + + loaders = [ + self.provider.load_access_token, + partial(self.provider.load_refresh_token, client), + ] + if revocation_request.token_type_hint == "refresh_token": + loaders = reversed(loaders) + + token: None | AccessToken | RefreshToken = None + for loader in loaders: + token = await loader(revocation_request.token) + if token is not None: + break + + # if token is not found, just return HTTP 200 per the RFC + if token and token.client_id == client.client_id: + # Revoke token; provider is not meant to be able to do validation + # at this point that would result in an error + await self.provider.revoke_token(token) + + # Return successful empty response + return Response( + status_code=200, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + }, + ) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py new file mode 100644 index 000000000..aa1ce934e --- /dev/null +++ b/src/mcp/server/auth/handlers/token.py @@ -0,0 +1,258 @@ +import base64 +import hashlib +import time +from dataclasses import dataclass +from typing import Annotated, Any, Literal + +from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError +from starlette.requests import Request + +from mcp.server.auth.errors import ( + ErrorResponse, + stringify_pydantic_error, +) +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.server.auth.middleware.client_auth import ( + AuthenticationError, + ClientAuthenticator, +) +from mcp.server.auth.provider import OAuthServerProvider, TokenError, TokenErrorCode +from mcp.shared.auth import OAuthToken + + +class AuthorizationCodeRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3 + grant_type: Literal["authorization_code"] + code: str = Field(..., description="The authorization code") + redirect_uri: AnyHttpUrl | None = Field( + default=None, + description="Must be the same as redirect URI provided in /authorize", + ) + client_id: str + # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 + client_secret: str | None = None + # See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5 + code_verifier: str = Field(..., description="PKCE code verifier") + + +class RefreshTokenRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-6 + grant_type: Literal["refresh_token"] + refresh_token: str = Field(..., description="The refresh token") + scope: str | None = Field(None, description="Optional scope parameter") + client_id: str + # we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1 + client_secret: str | None = None + + +class TokenRequest( + RootModel[ + Annotated[ + AuthorizationCodeRequest | RefreshTokenRequest, + Field(discriminator="grant_type"), + ] + ] +): + root: Annotated[ + AuthorizationCodeRequest | RefreshTokenRequest, + Field(discriminator="grant_type"), + ] + + +class TokenErrorResponse(BaseModel): + """ + See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 + """ + + error: TokenErrorCode + error_description: str | None = None + error_uri: AnyHttpUrl | None = None + + +class TokenSuccessResponse(RootModel[OAuthToken]): + # this is just a wrapper over OAuthToken; the only reason we do this + # is to have some separation between the HTTP response type, and the + # type returned by the provider + root: OAuthToken + + +@dataclass +class TokenHandler: + provider: OAuthServerProvider[Any, Any, Any] + client_authenticator: ClientAuthenticator + + def response(self, obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse): + status_code = 200 + if isinstance(obj, TokenErrorResponse): + status_code = 400 + + return PydanticJSONResponse( + content=obj, + status_code=status_code, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + }, + ) + + async def handle(self, request: Request): + try: + form_data = await request.form() + token_request = TokenRequest.model_validate(dict(form_data)).root + except ValidationError as validation_error: + return self.response( + TokenErrorResponse( + error="invalid_request", + error_description=stringify_pydantic_error(validation_error), + ) + ) + + try: + client_info = await self.client_authenticator.authenticate( + client_id=token_request.client_id, + client_secret=token_request.client_secret, + ) + except AuthenticationError as e: + return self.response( + TokenErrorResponse( + error="unauthorized_client", + error_description=e.message, + ) + ) + + if token_request.grant_type not in client_info.grant_types: + return self.response( + TokenErrorResponse( + error="unsupported_grant_type", + error_description=( + f"Unsupported grant type (supported grant types are " + f"{client_info.grant_types})" + ), + ) + ) + + tokens: OAuthToken + + match token_request: + case AuthorizationCodeRequest(): + auth_code = await self.provider.load_authorization_code( + client_info, token_request.code + ) + if auth_code is None or auth_code.client_id != token_request.client_id: + # if code belongs to different client, pretend it doesn't exist + return self.response( + TokenErrorResponse( + error="invalid_grant", + error_description="authorization code does not exist", + ) + ) + + # make auth codes expire after a deadline + # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5 + if auth_code.expires_at < time.time(): + return self.response( + TokenErrorResponse( + error="invalid_grant", + error_description="authorization code has expired", + ) + ) + + # # verify redirect_uri doesn't change between /authorize and /tokens + # # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6 + # if token_request.redirect_uri != auth_code.redirect_uri: + # return self.response( + # TokenErrorResponse( + # error="invalid_request", + # error_description=( + # "redirect_uri did not match the one " + # "used when creating auth code" + # ), + # ) + # ) + + # Verify PKCE code verifier + sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest() + hashed_code_verifier = ( + base64.urlsafe_b64encode(sha256).decode().rstrip("=") + ) + + if hashed_code_verifier != auth_code.code_challenge: + # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6 + return self.response( + TokenErrorResponse( + error="invalid_grant", + error_description="incorrect code_verifier", + ) + ) + + try: + # Exchange authorization code for tokens + tokens = await self.provider.exchange_authorization_code( + client_info, auth_code + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + + case RefreshTokenRequest(): + refresh_token = await self.provider.load_refresh_token( + client_info, token_request.refresh_token + ) + if ( + refresh_token is None + or refresh_token.client_id != token_request.client_id + ): + # if token belongs to different client, pretend it doesn't exist + return self.response( + TokenErrorResponse( + error="invalid_grant", + error_description="refresh token does not exist", + ) + ) + + if refresh_token.expires_at and refresh_token.expires_at < time.time(): + # if the refresh token has expired, pretend it doesn't exist + return self.response( + TokenErrorResponse( + error="invalid_grant", + error_description="refresh token has expired", + ) + ) + + # Parse scopes if provided + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else refresh_token.scopes + ) + + for scope in scopes: + if scope not in refresh_token.scopes: + return self.response( + TokenErrorResponse( + error="invalid_scope", + error_description=( + f"cannot request scope `{scope}` " + "not provided by refresh token" + ), + ) + ) + + try: + # Exchange refresh token for new tokens + tokens = await self.provider.exchange_refresh_token( + client_info, refresh_token, scopes + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + + return self.response(TokenSuccessResponse(root=tokens)) diff --git a/src/mcp/server/auth/json_response.py b/src/mcp/server/auth/json_response.py new file mode 100644 index 000000000..bd95bd693 --- /dev/null +++ b/src/mcp/server/auth/json_response.py @@ -0,0 +1,10 @@ +from typing import Any + +from starlette.responses import JSONResponse + + +class PydanticJSONResponse(JSONResponse): + # use pydantic json serialization instead of the stock `json.dumps`, + # so that we can handle serializing pydantic models like AnyHttpUrl + def render(self, content: Any) -> bytes: + return content.model_dump_json(exclude_none=True).encode("utf-8") diff --git a/src/mcp/server/auth/middleware/__init__.py b/src/mcp/server/auth/middleware/__init__.py new file mode 100644 index 000000000..ba3ff63c3 --- /dev/null +++ b/src/mcp/server/auth/middleware/__init__.py @@ -0,0 +1,3 @@ +""" +Middleware for MCP authorization. +""" diff --git a/src/mcp/server/auth/middleware/auth_context.py b/src/mcp/server/auth/middleware/auth_context.py new file mode 100644 index 000000000..1073c07ad --- /dev/null +++ b/src/mcp/server/auth/middleware/auth_context.py @@ -0,0 +1,50 @@ +import contextvars + +from starlette.types import ASGIApp, Receive, Scope, Send + +from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser +from mcp.server.auth.provider import AccessToken + +# Create a contextvar to store the authenticated user +# The default is None, indicating no authenticated user is present +auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]( + "auth_context", default=None +) + + +def get_access_token() -> AccessToken | None: + """ + Get the access token from the current context. + + Returns: + The access token if an authenticated user is available, None otherwise. + """ + auth_user = auth_context_var.get() + return auth_user.access_token if auth_user else None + + +class AuthContextMiddleware: + """ + Middleware that extracts the authenticated user from the request + and sets it in a contextvar for easy access throughout the request lifecycle. + + This middleware should be added after the AuthenticationMiddleware in the + middleware stack to ensure that the user is properly authenticated before + being stored in the context. + """ + + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send): + user = scope.get("user") + if isinstance(user, AuthenticatedUser): + # Set the authenticated user in the contextvar + token = auth_context_var.set(user) + try: + await self.app(scope, receive, send) + finally: + auth_context_var.reset(token) + else: + # No authenticated user, just process the request + await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/bearer_auth.py b/src/mcp/server/auth/middleware/bearer_auth.py new file mode 100644 index 000000000..15e6f2fc5 --- /dev/null +++ b/src/mcp/server/auth/middleware/bearer_auth.py @@ -0,0 +1,89 @@ +import time +from typing import Any + +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + SimpleUser, +) +from starlette.exceptions import HTTPException +from starlette.requests import HTTPConnection +from starlette.types import Receive, Scope, Send + +from mcp.server.auth.provider import AccessToken, OAuthServerProvider + + +class AuthenticatedUser(SimpleUser): + """User with authentication info.""" + + def __init__(self, auth_info: AccessToken): + super().__init__(auth_info.client_id) + self.access_token = auth_info + self.scopes = auth_info.scopes + + +class BearerAuthBackend(AuthenticationBackend): + """ + Authentication backend that validates Bearer tokens. + """ + + def __init__( + self, + provider: OAuthServerProvider[Any, Any, Any], + ): + self.provider = provider + + async def authenticate(self, conn: HTTPConnection): + auth_header = conn.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + return None + + token = auth_header[7:] # Remove "Bearer " prefix + + # Validate the token with the provider + auth_info = await self.provider.load_access_token(token) + + if not auth_info: + return None + + if auth_info.expires_at and auth_info.expires_at < int(time.time()): + return None + + return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info) + + +class RequireAuthMiddleware: + """ + Middleware that requires a valid Bearer token in the Authorization header. + + This will validate the token with the auth provider and store the resulting + auth info in the request state. + """ + + def __init__(self, app: Any, required_scopes: list[str]): + """ + Initialize the middleware. + + Args: + app: ASGI application + provider: Authentication provider to validate tokens + required_scopes: Optional list of scopes that the token must have + """ + self.app = app + self.required_scopes = required_scopes + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + auth_user = scope.get("user") + if not isinstance(auth_user, AuthenticatedUser): + raise HTTPException(status_code=401, detail="Unauthorized") + auth_credentials = scope.get("auth") + + for required_scope in self.required_scopes: + # auth_credentials should always be provided; this is just paranoia + if ( + auth_credentials is None + or required_scope not in auth_credentials.scopes + ): + raise HTTPException(status_code=403, detail="Insufficient scope") + + await self.app(scope, receive, send) diff --git a/src/mcp/server/auth/middleware/client_auth.py b/src/mcp/server/auth/middleware/client_auth.py new file mode 100644 index 000000000..da0ab0369 --- /dev/null +++ b/src/mcp/server/auth/middleware/client_auth.py @@ -0,0 +1,56 @@ +import time +from typing import Any + +from mcp.server.auth.provider import OAuthServerProvider +from mcp.shared.auth import OAuthClientInformationFull + + +class AuthenticationError(Exception): + def __init__(self, message: str): + self.message = message + + +class ClientAuthenticator: + """ + ClientAuthenticator is a callable which validates requests from a client + application, used to verify /token calls. + If, during registration, the client requested to be issued a secret, the + authenticator asserts that /token calls must be authenticated with + that same token. + NOTE: clients can opt for no authentication during registration, in which case this + logic is skipped. + """ + + def __init__(self, provider: OAuthServerProvider[Any, Any, Any]): + """ + Initialize the dependency. + + Args: + provider: Provider to look up client information + """ + self.provider = provider + + async def authenticate( + self, client_id: str, client_secret: str | None + ) -> OAuthClientInformationFull: + # Look up client information + client = await self.provider.get_client(client_id) + if not client: + raise AuthenticationError("Invalid client_id") + + # If client from the store expects a secret, validate that the request provides + # that secret + if client.client_secret: + if not client_secret: + raise AuthenticationError("Client secret is required") + + if client.client_secret != client_secret: + raise AuthenticationError("Invalid client_secret") + + if ( + client.client_secret_expires_at + and client.client_secret_expires_at < int(time.time()) + ): + raise AuthenticationError("Client secret has expired") + + return client diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py new file mode 100644 index 000000000..a6d5c0cf0 --- /dev/null +++ b/src/mcp/server/auth/provider.py @@ -0,0 +1,287 @@ +from dataclasses import dataclass +from typing import Generic, Literal, Protocol, TypeVar +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse + +from pydantic import AnyHttpUrl, BaseModel + +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthToken, +) + + +class AuthorizationParams(BaseModel): + state: str | None = None + scopes: list[str] | None = None + code_challenge: str + redirect_uri: AnyHttpUrl + + +class AuthorizationCode(BaseModel): + code: str + scopes: list[str] + expires_at: float + client_id: str + code_challenge: str + redirect_uri: AnyHttpUrl + + +class RefreshToken(BaseModel): + token: str + client_id: str + scopes: list[str] + expires_at: int | None = None + + +class AccessToken(BaseModel): + token: str + client_id: str + scopes: list[str] + expires_at: int | None = None + + +RegistrationErrorCode = Literal[ + "invalid_redirect_uri", + "invalid_client_metadata", + "invalid_software_statement", + "unapproved_software_statement", +] + + +@dataclass(frozen=True) +class RegistrationError(Exception): + error: RegistrationErrorCode + error_description: str | None = None + + +AuthorizationErrorCode = Literal[ + "invalid_request", + "unauthorized_client", + "access_denied", + "unsupported_response_type", + "invalid_scope", + "server_error", + "temporarily_unavailable", +] + + +@dataclass(frozen=True) +class AuthorizeError(Exception): + error: AuthorizationErrorCode + error_description: str | None = None + + +TokenErrorCode = Literal[ + "invalid_request", + "invalid_client", + "invalid_grant", + "unauthorized_client", + "unsupported_grant_type", + "invalid_scope", +] + + +@dataclass(frozen=True) +class TokenError(Exception): + error: TokenErrorCode + error_description: str | None = None + + +# NOTE: FastMCP doesn't render any of these types in the user response, so it's +# OK to add fields to subclasses which should not be exposed externally. +AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode) +RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken) +AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken) + + +class OAuthServerProvider( + Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT] +): + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + """ + Retrieves client information by client ID. + + Implementors MAY raise NotImplementedError if dynamic client registration is + disabled in ClientRegistrationOptions. + + Args: + client_id: The ID of the client to retrieve. + + Returns: + The client information, or None if the client does not exist. + """ + ... + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + """ + Saves client information as part of registering it. + + Implementors MAY raise NotImplementedError if dynamic client registration is + disabled in ClientRegistrationOptions. + + Args: + client_info: The client metadata to register. + + Raises: + RegistrationError: If the client metadata is invalid. + """ + ... + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """ + Called as part of the /authorize endpoint, and returns a URL that the client + will be redirected to. + Many MCP implementations will redirect to a third-party provider to perform + a second OAuth exchange with that provider. In this sort of setup, the client + has an OAuth connection with the MCP server, and the MCP server has an OAuth + connection with the 3rd-party provider. At the end of this flow, the client + should be redirected to the redirect_uri from params.redirect_uri. + + +--------+ +------------+ +-------------------+ + | | | | | | + | Client | --> | MCP Server | --> | 3rd Party OAuth | + | | | | | Server | + +--------+ +------------+ +-------------------+ + | ^ | + +------------+ | | | + | | | | Redirect | + |redirect_uri|<-----+ +------------------+ + | | + +------------+ + + Implementations will need to define another handler on the MCP server return + flow to perform the second redirect, and generates and stores an authorization + code as part of completing the OAuth authorization step. + + Implementations SHOULD generate an authorization code with at least 160 bits of + entropy, + and MUST generate an authorization code with at least 128 bits of entropy. + See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10. + + Args: + client: The client requesting authorization. + params: The parameters of the authorization request. + + Returns: + A URL to redirect the client to for authorization. + + Raises: + AuthorizeError: If the authorization request is invalid. + """ + ... + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCodeT | None: + """ + Loads an AuthorizationCode by its code. + + Args: + client: The client that requested the authorization code. + authorization_code: The authorization code to get the challenge for. + + Returns: + The AuthorizationCode, or None if not found + """ + ... + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT + ) -> OAuthToken: + """ + Exchanges an authorization code for an access token and refresh token. + + Args: + client: The client exchanging the authorization code. + authorization_code: The authorization code to exchange. + + Returns: + The OAuth token, containing access and refresh tokens. + + Raises: + TokenError: If the request is invalid + """ + ... + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshTokenT | None: + """ + Loads a RefreshToken by its token string. + + Args: + client: The client that is requesting to load the refresh token. + refresh_token: The refresh token string to load. + + Returns: + The RefreshToken object if found, or None if not found. + """ + + ... + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshTokenT, + scopes: list[str], + ) -> OAuthToken: + """ + Exchanges a refresh token for an access token and refresh token. + + Implementations SHOULD rotate both the access token and refresh token. + + Args: + client: The client exchanging the refresh token. + refresh_token: The refresh token to exchange. + scopes: Optional scopes to request with the new access token. + + Returns: + The OAuth token, containing access and refresh tokens. + + Raises: + TokenError: If the request is invalid + """ + ... + + async def load_access_token(self, token: str) -> AccessTokenT | None: + """ + Loads an access token by its token. + + Args: + token: The access token to verify. + + Returns: + The AuthInfo, or None if the token is invalid. + """ + ... + + async def revoke_token( + self, + token: AccessTokenT | RefreshTokenT, + ) -> None: + """ + Revokes an access or refresh token. + + If the given token is invalid or already revoked, this method should do nothing. + + Implementations SHOULD revoke both the access token and its corresponding + refresh token, regardless of which of the access token or refresh token is + provided. + + Args: + token: the token to revoke + """ + ... + + +def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str: + parsed_uri = urlparse(redirect_uri_base) + query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs] + for k, v in params.items(): + if v is not None: + query_params.append((k, v)) + + redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params))) + return redirect_uri diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py new file mode 100644 index 000000000..865a61d6d --- /dev/null +++ b/src/mcp/server/auth/routes.py @@ -0,0 +1,172 @@ +from collections.abc import Callable +from typing import Any + +from pydantic import AnyHttpUrl +from starlette.middleware.cors import CORSMiddleware +from starlette.routing import Route +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from mcp.server.auth.handlers.authorize import AuthorizationHandler +from mcp.server.auth.handlers.metadata import MetadataHandler +from mcp.server.auth.handlers.register import RegistrationHandler +from mcp.server.auth.handlers.revoke import RevocationHandler +from mcp.server.auth.handlers.token import TokenHandler +from mcp.server.auth.middleware.client_auth import ClientAuthenticator +from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions +from mcp.shared.auth import OAuthMetadata + + +def validate_issuer_url(url: AnyHttpUrl): + """ + Validate that the issuer URL meets OAuth 2.0 requirements. + + Args: + url: The issuer URL to validate + + Raises: + ValueError: If the issuer URL is invalid + """ + + # RFC 8414 requires HTTPS, but we allow localhost HTTP for testing + if ( + url.scheme != "https" + and url.host != "localhost" + and not url.host.startswith("127.0.0.1") + ): + raise ValueError("Issuer URL must be HTTPS") + + # No fragments or query parameters allowed + if url.fragment: + raise ValueError("Issuer URL must not have a fragment") + if url.query: + raise ValueError("Issuer URL must not have a query string") + + +AUTHORIZATION_PATH = "/authorize" +TOKEN_PATH = "/token" +REGISTRATION_PATH = "/register" +REVOCATION_PATH = "/revoke" + + +def create_auth_routes( + provider: OAuthServerProvider[Any, Any, Any], + issuer_url: AnyHttpUrl, + service_documentation_url: AnyHttpUrl | None = None, + client_registration_options: ClientRegistrationOptions | None = None, + revocation_options: RevocationOptions | None = None, +) -> list[Route]: + validate_issuer_url(issuer_url) + + client_registration_options = ( + client_registration_options or ClientRegistrationOptions() + ) + revocation_options = revocation_options or RevocationOptions() + metadata = build_metadata( + issuer_url, + service_documentation_url, + client_registration_options, + revocation_options, + ) + client_authenticator = ClientAuthenticator(provider) + + # Create routes + routes = [ + Route( + "/.well-known/oauth-authorization-server", + endpoint=MetadataHandler(metadata).handle, + methods=["GET", "OPTIONS"], + ), + Route( + AUTHORIZATION_PATH, + endpoint=AuthorizationHandler(provider).handle, + methods=["GET", "POST", "OPTIONS"], + ), + Route( + TOKEN_PATH, + endpoint=TokenHandler(provider, client_authenticator).handle, + methods=["POST", "OPTIONS"], + ), + ] + + if client_registration_options.enabled: + registration_handler = RegistrationHandler( + provider, + options=client_registration_options, + ) + routes.append( + Route( + REGISTRATION_PATH, + endpoint=registration_handler.handle, + methods=["POST"], + ) + ) + + if revocation_options.enabled: + revocation_handler = RevocationHandler(provider, client_authenticator) + routes.append( + Route(REVOCATION_PATH, endpoint=revocation_handler.handle, methods=["POST"]) + ) + + return routes + + +def modify_url_path(url: AnyHttpUrl, path_mapper: Callable[[str], str]) -> AnyHttpUrl: + return AnyHttpUrl.build( + scheme=url.scheme, + username=url.username, + password=url.password, + host=url.host, + port=url.port, + path=path_mapper(url.path or ""), + query=url.query, + fragment=url.fragment, + ) + + +def build_metadata( + issuer_url: AnyHttpUrl, + service_documentation_url: AnyHttpUrl | None, + client_registration_options: ClientRegistrationOptions, + revocation_options: RevocationOptions, +) -> OAuthMetadata: + authorization_url = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + AUTHORIZATION_PATH.lstrip("/") + ) + token_url = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + TOKEN_PATH.lstrip("/") + ) + # Create metadata + metadata = OAuthMetadata( + issuer=issuer_url, + authorization_endpoint=authorization_url, + token_endpoint=token_url, + scopes_supported=None, + response_types_supported=["code"], + response_modes_supported=None, + grant_types_supported=["authorization_code", "refresh_token"], + token_endpoint_auth_methods_supported=["client_secret_post"], + token_endpoint_auth_signing_alg_values_supported=None, + service_documentation=service_documentation_url, + ui_locales_supported=None, + op_policy_uri=None, + op_tos_uri=None, + introspection_endpoint=None, + code_challenge_methods_supported=["S256"], + ) + + # Add registration endpoint if supported + if client_registration_options.enabled: + metadata.registration_endpoint = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + REGISTRATION_PATH.lstrip("/") + ) + + # Add revocation endpoint if supported + if revocation_options.enabled: + metadata.revocation_endpoint = modify_url_path( + issuer_url, lambda path: path.rstrip("/") + REVOCATION_PATH.lstrip("/") + ) + metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"] + + return metadata diff --git a/src/mcp/server/auth/settings.py b/src/mcp/server/auth/settings.py new file mode 100644 index 000000000..1086bb77e --- /dev/null +++ b/src/mcp/server/auth/settings.py @@ -0,0 +1,24 @@ +from pydantic import AnyHttpUrl, BaseModel, Field + + +class ClientRegistrationOptions(BaseModel): + enabled: bool = False + client_secret_expiry_seconds: int | None = None + valid_scopes: list[str] | None = None + default_scopes: list[str] | None = None + + +class RevocationOptions(BaseModel): + enabled: bool = False + + +class AuthSettings(BaseModel): + issuer_url: AnyHttpUrl = Field( + ..., + description="URL advertised as OAuth issuer; this should be the URL the server " + "is reachable at", + ) + service_documentation_url: AnyHttpUrl | None = None + client_registration_options: ClientRegistrationOptions | None = None + revocation_options: RevocationOptions | None = None + required_scopes: list[str] | None = None diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 275bcb36c..70604b7e5 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -5,7 +5,7 @@ import inspect import json import re -from collections.abc import AsyncIterator, Callable, Iterable, Sequence +from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence from contextlib import ( AbstractAsyncContextManager, asynccontextmanager, @@ -19,10 +19,24 @@ from pydantic import BaseModel, Field from pydantic.networks import AnyUrl from pydantic_settings import BaseSettings, SettingsConfigDict +from sse_starlette import EventSourceResponse from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request -from starlette.routing import Mount, Route +from starlette.responses import Response +from starlette.routing import Mount, Route, request_response +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware +from mcp.server.auth.middleware.bearer_auth import ( + BearerAuthBackend, + RequireAuthMiddleware, +) +from mcp.server.auth.provider import OAuthServerProvider +from mcp.server.auth.settings import ( + AuthSettings, +) from mcp.server.fastmcp.exceptions import ResourceError from mcp.server.fastmcp.prompts import Prompt, PromptManager from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager @@ -63,6 +77,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): model_config = SettingsConfigDict( env_prefix="FASTMCP_", env_file=".env", + env_nested_delimiter="__", + nested_model_default_partial_update=True, extra="ignore", ) @@ -94,6 +110,8 @@ class Settings(BaseSettings, Generic[LifespanResultT]): Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None ) = Field(None, description="Lifespan context manager") + auth: AuthSettings | None = None + def lifespan_wrapper( app: FastMCP, @@ -109,7 +127,11 @@ async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]: class FastMCP: def __init__( - self, name: str | None = None, instructions: str | None = None, **settings: Any + self, + name: str | None = None, + instructions: str | None = None, + auth_provider: OAuthServerProvider[Any, Any, Any] | None = None, + **settings: Any, ): self.settings = Settings(**settings) @@ -129,6 +151,13 @@ def __init__( self._prompt_manager = PromptManager( warn_on_duplicate_prompts=self.settings.warn_on_duplicate_prompts ) + if (self.settings.auth is not None) != (auth_provider is not None): + raise ValueError( + "settings.auth must be specified if and only if auth_provider " + "is specified" + ) + self._auth_provider = auth_provider + self._custom_starlette_routes: list[Route] = [] self.dependencies = self.settings.dependencies # Set up MCP protocol handlers @@ -455,6 +484,29 @@ def decorator(func: AnyFunction) -> AnyFunction: return decorator + def custom_route( + self, + path: str, + methods: list[str], + name: str | None = None, + include_in_schema: bool = True, + ): + def decorator( + func: Callable[[Request], Awaitable[Response]], + ) -> Callable[[Request], Awaitable[Response]]: + self._custom_starlette_routes.append( + Route( + path, + endpoint=func, + methods=methods, + name=name, + include_in_schema=include_in_schema, + ) + ) + return func + + return decorator + async def run_stdio_async(self) -> None: """Run the server using stdio transport.""" async with stdio_server() as (read_stream, write_stream): @@ -465,11 +517,8 @@ async def run_stdio_async(self) -> None: ) async def run_sse_async(self) -> None: - """Run the server using SSE transport.""" - starlette_app = self.sse_app() - config = uvicorn.Config( - starlette_app, + app=self.sse_app(), host=self.settings.host, port=self.settings.port, log_level=self.settings.log_level.lower(), @@ -478,10 +527,16 @@ async def run_sse_async(self) -> None: await server.serve() def sse_app(self) -> Starlette: - """Return an instance of the SSE server app.""" + from starlette.middleware import Middleware + from starlette.routing import Mount, Route + + # Set up auth context and dependencies + sse = SseServerTransport(self.settings.message_path) - async def handle_sse(request: Request) -> None: + async def handle_sse(request: Request) -> EventSourceResponse: + # Add client ID from auth context into request context if available + async with sse.connect_sse( request.scope, request.receive, @@ -492,13 +547,71 @@ async def handle_sse(request: Request) -> None: streams[1], self._mcp_server.create_initialization_options(), ) + return streams[2] + + # Create routes + routes: list[Route | Mount] = [] + middleware: list[Middleware] = [] + required_scopes = [] + + # Add auth endpoints if auth provider is configured + if self._auth_provider: + assert self.settings.auth + from mcp.server.auth.routes import create_auth_routes + + required_scopes = self.settings.auth.required_scopes or [] + + middleware = [ + # Add CORS middleware to allow cross-origin requests + Middleware( + CORSMiddleware, + allow_origins=["*"], # Allow any origin + allow_methods=["GET", "POST", "OPTIONS"], + allow_headers=["*"], + allow_credentials=True, + ), + # extract auth info from request (but do not require it) + Middleware( + AuthenticationMiddleware, + backend=BearerAuthBackend( + provider=self._auth_provider, + ), + ), + # Add the auth context middleware to store + # authenticated user in a contextvar + Middleware(AuthContextMiddleware), + ] + routes.extend( + create_auth_routes( + provider=self._auth_provider, + issuer_url=self.settings.auth.issuer_url, + service_documentation_url=self.settings.auth.service_documentation_url, + client_registration_options=self.settings.auth.client_registration_options, + revocation_options=self.settings.auth.revocation_options, + ) + ) + routes.append( + Route( + self.settings.sse_path, + endpoint=RequireAuthMiddleware( + request_response(handle_sse), required_scopes + ), + methods=["GET"], + ) + ) + routes.append( + Mount( + self.settings.message_path, + app=RequireAuthMiddleware(sse.handle_post_message, required_scopes), + ) + ) + # mount these routes last, so they have the lowest route matching precedence + routes.extend(self._custom_starlette_routes) + + # Create Starlette app with routes and middleware return Starlette( - debug=self.settings.debug, - routes=[ - Route(self.settings.sse_path, endpoint=handle_sse), - Mount(self.settings.message_path, app=sse.handle_post_message), - ], + debug=self.settings.debug, routes=routes, middleware=middleware ) async def list_prompts(self) -> list[MCPPrompt]: @@ -652,9 +765,9 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent Returns: The resource content as either text or bytes """ - assert self._fastmcp is not None, ( - "Context is not available outside of a request" - ) + assert ( + self._fastmcp is not None + ), "Context is not available outside of a request" return await self._fastmcp.read_resource(uri) async def log( diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index e14f73e19..f588c10ca 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -576,14 +576,12 @@ async def _handle_notification(self, notify: Any): assert type(notify) in self.notification_handlers handler = self.notification_handlers[type(notify)] - logger.debug( - f"Dispatching notification of type " f"{type(notify).__name__}" - ) + logger.debug(f"Dispatching notification of type {type(notify).__name__}") try: await handler(notify) except Exception as err: - logger.error(f"Uncaught exception in notification handler: " f"{err}") + logger.error(f"Uncaught exception in notification handler: {err}") async def _ping_handler(request: types.PingRequest) -> types.ServerResult: diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index d051c25bf..a48266ca0 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -128,7 +128,7 @@ async def sse_writer(): tg.start_soon(response, scope, receive, send) logger.debug("Yielding read and write streams") - yield (read_stream, write_stream) + yield (read_stream, write_stream, response) async def handle_post_message( self, scope: Scope, receive: Receive, send: Send diff --git a/src/mcp/server/streaming_asgi_transport.py b/src/mcp/server/streaming_asgi_transport.py new file mode 100644 index 000000000..4cbd77370 --- /dev/null +++ b/src/mcp/server/streaming_asgi_transport.py @@ -0,0 +1,205 @@ +""" +A modified version of httpx.ASGITransport that supports streaming responses. + +This transport runs the ASGI app as a separate anyio task, allowing it to +handle streaming responses like SSE where the app doesn't terminate until +the connection is closed. + +This is only intended for writing tests for the SSE transport. +""" + +import typing +from typing import Any, cast + +import anyio +import anyio.abc +import anyio.streams.memory +from httpx._models import Request, Response +from httpx._transports.base import AsyncBaseTransport +from httpx._types import AsyncByteStream +from starlette.types import ASGIApp, Receive, Scope, Send + + +class StreamingASGITransport(AsyncBaseTransport): + """ + A custom AsyncTransport that handles sending requests directly to an ASGI app + and supports streaming responses like SSE. + + Unlike the standard ASGITransport, this transport runs the ASGI app in a + separate anyio task, allowing it to handle responses from apps that don't + terminate immediately (like SSE endpoints). + + Arguments: + + * `app` - The ASGI application. + * `raise_app_exceptions` - Boolean indicating if exceptions in the application + should be raised. Default to `True`. Can be set to `False` for use cases + such as testing the content of a client 500 response. + * `root_path` - The root path on which the ASGI application should be mounted. + * `client` - A two-tuple indicating the client IP and port of incoming requests. + * `response_timeout` - Timeout in seconds to wait for the initial response. + Default is 10 seconds. + """ + + def __init__( + self, + app: ASGIApp, + task_group: anyio.abc.TaskGroup, + raise_app_exceptions: bool = True, + root_path: str = "", + client: tuple[str, int] = ("127.0.0.1", 123), + ) -> None: + self.app = app + self.raise_app_exceptions = raise_app_exceptions + self.root_path = root_path + self.client = client + self.task_group = task_group + + async def handle_async_request( + self, + request: Request, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + # ASGI scope. + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": request.method, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "scheme": request.url.scheme, + "path": request.url.path, + "raw_path": request.url.raw_path.split(b"?")[0], + "query_string": request.url.query, + "server": (request.url.host, request.url.port), + "client": self.client, + "root_path": self.root_path, + } + + # Request body + request_body_chunks = request.stream.__aiter__() + request_complete = False + + # Response state + status_code = 499 + response_headers = None + response_started = False + response_complete = anyio.Event() + initial_response_ready = anyio.Event() + + # Synchronization for streaming response + asgi_send_channel, asgi_receive_channel = anyio.create_memory_object_stream[ + dict[str, Any] + ](100) + content_send_channel, content_receive_channel = ( + anyio.create_memory_object_stream[bytes](100) + ) + + # ASGI callables. + async def receive() -> dict[str, Any]: + nonlocal request_complete + + if request_complete: + await response_complete.wait() + return {"type": "http.disconnect"} + + try: + body = await request_body_chunks.__anext__() + except StopAsyncIteration: + request_complete = True + return {"type": "http.request", "body": b"", "more_body": False} + return {"type": "http.request", "body": body, "more_body": True} + + async def send(message: dict[str, Any]) -> None: + nonlocal status_code, response_headers, response_started + + await asgi_send_channel.send(message) + + # Start the ASGI application in a separate task + async def run_app() -> None: + try: + # Cast the receive and send functions to the ASGI types + await self.app( + cast(Scope, scope), cast(Receive, receive), cast(Send, send) + ) + except Exception: + if self.raise_app_exceptions: + raise + + if not response_started: + await asgi_send_channel.send( + {"type": "http.response.start", "status": 500, "headers": []} + ) + + await asgi_send_channel.send( + {"type": "http.response.body", "body": b"", "more_body": False} + ) + finally: + await asgi_send_channel.aclose() + + # Process messages from the ASGI app + async def process_messages() -> None: + nonlocal status_code, response_headers, response_started + + try: + async with asgi_receive_channel: + async for message in asgi_receive_channel: + if message["type"] == "http.response.start": + assert not response_started + status_code = message["status"] + response_headers = message.get("headers", []) + response_started = True + + # As soon as we have headers, we can return a response + initial_response_ready.set() + + elif message["type"] == "http.response.body": + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body and request.method != "HEAD": + await content_send_channel.send(body) + + if not more_body: + response_complete.set() + await content_send_channel.aclose() + break + finally: + # Ensure events are set even if there's an error + initial_response_ready.set() + response_complete.set() + + # Create tasks for running the app and processing messages + self.task_group.start_soon(run_app) + self.task_group.start_soon(process_messages) + + # Wait for the initial response or timeout + await initial_response_ready.wait() + + # Create a streaming response + return Response( + status_code, + headers=response_headers, + stream=StreamingASGIResponseStream(content_receive_channel), + ) + + +class StreamingASGIResponseStream(AsyncByteStream): + """ + A modified ASGIResponseStream that supports streaming responses. + + This class extends the standard ASGIResponseStream to handle cases where + the response body continues to be generated after the initial response + is returned. + """ + + def __init__( + self, + receive_channel: anyio.streams.memory.MemoryObjectReceiveStream[bytes], + ) -> None: + self.receive_channel = receive_channel + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async for chunk in self.receive_channel: + yield chunk diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py new file mode 100644 index 000000000..22f8a971d --- /dev/null +++ b/src/mcp/shared/auth.py @@ -0,0 +1,137 @@ +from typing import Any, Literal + +from pydantic import AnyHttpUrl, BaseModel, Field + + +class OAuthToken(BaseModel): + """ + See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1 + """ + + access_token: str + token_type: Literal["bearer"] = "bearer" + expires_in: int | None = None + scope: str | None = None + refresh_token: str | None = None + + +class InvalidScopeError(Exception): + def __init__(self, message: str): + self.message = message + + +class InvalidRedirectUriError(Exception): + def __init__(self, message: str): + self.message = message + + +class OAuthClientMetadata(BaseModel): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration metadata. + See https://datatracker.ietf.org/doc/html/rfc7591#section-2 + for the full specification. + """ + + redirect_uris: list[AnyHttpUrl] = Field(..., min_length=1) + # token_endpoint_auth_method: this implementation only supports none & + # client_secret_post; + # ie: we do not support client_secret_basic + token_endpoint_auth_method: Literal["none", "client_secret_post"] = ( + "client_secret_post" + ) + # grant_types: this implementation only supports authorization_code & refresh_token + grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + "authorization_code", + "refresh_token", + ] + # this implementation only supports code; ie: it does not support implicit grants + response_types: list[Literal["code"]] = ["code"] + scope: str | None = None + + # these fields are currently unused, but we support & store them for potential + # future use + client_name: str | None = None + client_uri: AnyHttpUrl | None = None + logo_uri: AnyHttpUrl | None = None + contacts: list[str] | None = None + tos_uri: AnyHttpUrl | None = None + policy_uri: AnyHttpUrl | None = None + jwks_uri: AnyHttpUrl | None = None + jwks: Any | None = None + software_id: str | None = None + software_version: str | None = None + + def validate_scope(self, requested_scope: str | None) -> list[str] | None: + if requested_scope is None: + return None + requested_scopes = requested_scope.split(" ") + allowed_scopes = [] if self.scope is None else self.scope.split(" ") + for scope in requested_scopes: + if scope not in allowed_scopes: + raise InvalidScopeError(f"Client was not registered with scope {scope}") + return requested_scopes + + def validate_redirect_uri(self, redirect_uri: AnyHttpUrl | None) -> AnyHttpUrl: + if redirect_uri is not None: + # Validate redirect_uri against client's registered redirect URIs + if redirect_uri not in self.redirect_uris: + raise InvalidRedirectUriError( + f"Redirect URI '{redirect_uri}' not registered for client" + ) + return redirect_uri + elif len(self.redirect_uris) == 1: + return self.redirect_uris[0] + else: + raise InvalidRedirectUriError( + "redirect_uri must be specified when client " + "has multiple registered URIs" + ) + + +class OAuthClientInformationFull(OAuthClientMetadata): + """ + RFC 7591 OAuth 2.0 Dynamic Client Registration full response + (client information plus metadata). + """ + + client_id: str + client_secret: str | None = None + client_id_issued_at: int | None = None + client_secret_expires_at: int | None = None + + +class OAuthMetadata(BaseModel): + """ + RFC 8414 OAuth 2.0 Authorization Server Metadata. + See https://datatracker.ietf.org/doc/html/rfc8414#section-2 + """ + + issuer: AnyHttpUrl + authorization_endpoint: AnyHttpUrl + token_endpoint: AnyHttpUrl + registration_endpoint: AnyHttpUrl | None = None + scopes_supported: list[str] | None = None + response_types_supported: list[Literal["code"]] = ["code"] + response_modes_supported: list[Literal["query", "fragment"]] | None = None + grant_types_supported: ( + list[Literal["authorization_code", "refresh_token"]] | None + ) = None + token_endpoint_auth_methods_supported: ( + list[Literal["none", "client_secret_post"]] | None + ) = None + token_endpoint_auth_signing_alg_values_supported: None = None + service_documentation: AnyHttpUrl | None = None + ui_locales_supported: list[str] | None = None + op_policy_uri: AnyHttpUrl | None = None + op_tos_uri: AnyHttpUrl | None = None + revocation_endpoint: AnyHttpUrl | None = None + revocation_endpoint_auth_methods_supported: ( + list[Literal["client_secret_post"]] | None + ) = None + revocation_endpoint_auth_signing_alg_values_supported: None = None + introspection_endpoint: AnyHttpUrl | None = None + introspection_endpoint_auth_methods_supported: ( + list[Literal["client_secret_post"]] | None + ) = None + introspection_endpoint_auth_signing_alg_values_supported: None = None + code_challenge_methods_supported: list[Literal["S256"]] | None = None diff --git a/tests/client/test_oauth.py b/tests/client/test_oauth.py new file mode 100644 index 000000000..90ca5683e --- /dev/null +++ b/tests/client/test_oauth.py @@ -0,0 +1,257 @@ +import json +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest +from pydantic import AnyHttpUrl + +from mcp.client.auth.oauth import ( + AccessToken, + ClientMetadata, + DynamicClientRegistration, + OAuthClient, + OAuthClientProvider, +) + + +class MockOauthClientProvider(OAuthClientProvider): + @property + def client_metadata(self) -> ClientMetadata: + return ClientMetadata( + client_name="Test Client", + redirect_uris=[AnyHttpUrl("https://client.example.com/callback")], + token_endpoint_auth_method="client_secret_post", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ) + + @property + def redirect_url(self) -> AnyHttpUrl: + return AnyHttpUrl("https://client.example.com/callback") + + async def open_user_agent(self, url: AnyHttpUrl) -> None: + pass + + async def client_registration( + self, issuer: AnyHttpUrl + ) -> DynamicClientRegistration | None: + return None + + async def store_client_registration( + self, issuer: AnyHttpUrl, metadata: DynamicClientRegistration + ) -> None: + pass + + def code_verifier(self) -> str: + return "test-code-verifier" + + async def token(self) -> AccessToken | None: + return None + + async def store_token(self, token: AccessToken) -> None: + pass + + +@pytest.fixture +def server_url(): + return AnyHttpUrl("https://example.com/v1") + + +@pytest.fixture +def http_server_urls(): + return [ + # HTTP URL should be converted to HTTPS + "http://example.com/auth", + # URL with trailing slash + "http://auth.example.org/", + # Complex path + "http://api.example.net/v1/auth/service", + # URL with query parameters (these should be ignored) + "http://example.io/oauth?version=2.0&debug=true", + # URL with port + "http://auth.example.com:8080/v1", + ] + + +@pytest.fixture +def auth_client(server_url): + return OAuthClient(server_url, MockOauthClientProvider()) + + +@pytest.fixture +def mock_http_response(): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.raise_for_status = MagicMock() + mock_response.aread = AsyncMock( + return_value=json.dumps( + { + "issuer": "https://example.com/v1", + "authorization_endpoint": "https://example.com/v1/authorize", + "token_endpoint": "https://example.com/v1/token", + "registration_endpoint": "https://example.com/v1/register", + "response_types_supported": ["code"], + } + ) + ) + return mock_response + + +@pytest.fixture +def client_metadata(): + return ClientMetadata( + client_name="Test Client", + redirect_uris=[AnyHttpUrl("https://client.example.com/callback")], + token_endpoint_auth_method="client_secret_post", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + ) + + +@pytest.mark.anyio +async def test_discover_auth_metadata(auth_client, mock_http_response): + # Mock the HTTP client's stream method + auth_client.http_client.get = AsyncMock(return_value=mock_http_response) + + # Call the method under test + result = await auth_client.discover_auth_metadata() + + # Assertions + assert result is not None + assert result.issuer == AnyHttpUrl("https://example.com/v1") + assert result.authorization_endpoint == AnyHttpUrl( + "https://example.com/v1/authorize" + ) + assert result.token_endpoint == AnyHttpUrl("https://example.com/v1/token") + assert result.registration_endpoint == AnyHttpUrl("https://example.com/v1/register") + + # Verify the correct URL was used + expected_url = "https://example.com/.well-known/oauth-authorization-server" + auth_client.http_client.get.assert_called_once_with(expected_url) + + +@pytest.mark.anyio +async def test_discover_auth_metadata_not_found(auth_client): + # Mock 404 response + mock_response = MagicMock() + mock_response.status_code = 404 + auth_client.http_client.get = AsyncMock(return_value=mock_response) + + # Call the method under test + result = await auth_client.discover_auth_metadata() + + # Assertions + assert result is None + + +@pytest.mark.anyio +async def test_dynamic_client_registration( + auth_client, client_metadata, mock_http_response +): + # Setup mock response for registration + registration_response = { + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "client_name": "Test Client", + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + } + mock_http_response.aread = AsyncMock(return_value=json.dumps(registration_response)) + auth_client.http_client.post = AsyncMock(return_value=mock_http_response) + + # Call the method under test + registration_endpoint = "https://example.com/v1/register" + result = await auth_client.dynamic_client_registration( + client_metadata, registration_endpoint + ) + + # Assertions + assert result is not None + assert result.client_id == "test-client-id" + assert result.client_secret == "test-client-secret" + assert result.client_name == "Test Client" + + # Verify the request was made correctly + auth_client.http_client.post.assert_called_once_with( + registration_endpoint, + json=client_metadata.model_dump(exclude_none=True), + headers={"Content-Type": "application/json", "Accept": "application/json"}, + ) + + +@pytest.mark.anyio +async def test_dynamic_client_registration_error(auth_client, client_metadata): + # Mock error response + mock_error_response = AsyncMock() + mock_error_response.__aenter__ = AsyncMock(return_value=mock_error_response) + mock_error_response.__aexit__ = AsyncMock(return_value=None) + mock_error_response.status_code = 400 + mock_error_response.raise_for_status = AsyncMock( + side_effect=httpx.HTTPStatusError( + "Client error '400 Bad Request'", + request=MagicMock(), + response=MagicMock( + status_code=400, + content=json.dumps({"error": "invalid_client_metadata"}), + ), + ) + ) + error_json = json.dumps({"error": "invalid_client_metadata"}) + mock_error_response.content = error_json.encode() + + auth_client.http_client.post = AsyncMock(return_value=mock_error_response) + + # Call the method under test + registration_endpoint = "https://example.com/v1/register" + result = await auth_client.dynamic_client_registration( + client_metadata, registration_endpoint + ) + + # Assertions + assert result is None + + +@pytest.mark.parametrize( + "input_url,expected_discovery_url", + [ + # Basic HTTP URL: protocol should be changed to HTTPS + ( + "http://example.com", + "https://example.com/.well-known/oauth-authorization-server", + ), + # URL with trailing slash: should be normalized + ( + "https://example.com/", + "https://example.com/.well-known/oauth-authorization-server", + ), + # URL with complex path: .well-known should be at the root + ( + "https://example.com/api/v1/auth", + "https://example.com/.well-known/oauth-authorization-server", + ), + # URL with query parameters: parameters should be ignored + ( + "https://auth.example.org?version=2.0&debug=true", + "https://auth.example.org/.well-known/oauth-authorization-server", + ), + # URL with port: port should be preserved + ( + "http://auth.example.net:8080", + "https://auth.example.net:8080/.well-known/oauth-authorization-server", + ), + # URL with subdomain, path, and trailing slash: .well-known should be at the + # root + ( + "http://api.auth.example.com/oauth/v2/", + "https://api.auth.example.com/.well-known/oauth-authorization-server", + ), + ], +) +def test_build_discovery_url_with_various_formats(input_url, expected_discovery_url): + # Create auth client with the given URL + auth_client = OAuthClient(AnyHttpUrl(input_url), MockOauthClientProvider()) + + # Assertions + assert auth_client.discovery_url == AnyHttpUrl(expected_discovery_url) diff --git a/tests/server/auth/middleware/test_bearer_auth.py b/tests/server/auth/middleware/test_bearer_auth.py new file mode 100644 index 000000000..a6da24e39 --- /dev/null +++ b/tests/server/auth/middleware/test_bearer_auth.py @@ -0,0 +1,389 @@ +""" +Tests for the BearerAuth middleware components. +""" + +import time +from typing import Any, cast + +import pytest +from starlette.authentication import AuthCredentials +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.types import Message, Receive, Scope, Send + +from mcp.server.auth.middleware.bearer_auth import ( + AuthenticatedUser, + BearerAuthBackend, + RequireAuthMiddleware, +) +from mcp.server.auth.provider import ( + AccessToken, + OAuthServerProvider, +) + + +class MockOAuthProvider: + """Mock OAuth provider for testing. + + This is a simplified version that only implements the methods needed for testing + the BearerAuthMiddleware components. + """ + + def __init__(self): + self.tokens = {} # token -> AccessToken + + def add_token(self, token: str, access_token: AccessToken) -> None: + """Add a token to the provider.""" + self.tokens[token] = access_token + + async def load_access_token(self, token: str) -> AccessToken | None: + """Load an access token.""" + return self.tokens.get(token) + + +def add_token_to_provider( + provider: OAuthServerProvider[Any, Any, Any], token: str, access_token: AccessToken +) -> None: + """Helper function to add a token to a provider. + + This is used to work around type checking issues with our mock provider. + """ + # We know this is actually a MockOAuthProvider + mock_provider = cast(MockOAuthProvider, provider) + mock_provider.add_token(token, access_token) + + +class MockApp: + """Mock ASGI app for testing.""" + + def __init__(self): + self.called = False + self.scope: Scope | None = None + self.receive: Receive | None = None + self.send: Send | None = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.called = True + self.scope = scope + self.receive = receive + self.send = send + + +@pytest.fixture +def mock_oauth_provider() -> OAuthServerProvider[Any, Any, Any]: + """Create a mock OAuth provider.""" + # Use type casting to satisfy the type checker + return cast(OAuthServerProvider[Any, Any, Any], MockOAuthProvider()) + + +@pytest.fixture +def valid_access_token() -> AccessToken: + """Create a valid access token.""" + return AccessToken( + token="valid_token", + client_id="test_client", + scopes=["read", "write"], + expires_at=int(time.time()) + 3600, # 1 hour from now + ) + + +@pytest.fixture +def expired_access_token() -> AccessToken: + """Create an expired access token.""" + return AccessToken( + token="expired_token", + client_id="test_client", + scopes=["read"], + expires_at=int(time.time()) - 3600, # 1 hour ago + ) + + +@pytest.fixture +def no_expiry_access_token() -> AccessToken: + """Create an access token with no expiry.""" + return AccessToken( + token="no_expiry_token", + client_id="test_client", + scopes=["read", "write"], + expires_at=None, + ) + + +@pytest.mark.anyio +class TestBearerAuthBackend: + """Tests for the BearerAuthBackend class.""" + + async def test_no_auth_header( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any] + ): + """Test authentication with no Authorization header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + request = Request({"type": "http", "headers": []}) + result = await backend.authenticate(request) + assert result is None + + async def test_non_bearer_auth_header( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any] + ): + """Test authentication with non-Bearer Authorization header.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Basic dXNlcjpwYXNz")], + } + ) + result = await backend.authenticate(request) + assert result is None + + async def test_invalid_token( + self, mock_oauth_provider: OAuthServerProvider[Any, Any, Any] + ): + """Test authentication with invalid token.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer invalid_token")], + } + ) + result = await backend.authenticate(request) + assert result is None + + async def test_expired_token( + self, + mock_oauth_provider: OAuthServerProvider[Any, Any, Any], + expired_access_token: AccessToken, + ): + """Test authentication with expired token.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider( + mock_oauth_provider, "expired_token", expired_access_token + ) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer expired_token")], + } + ) + result = await backend.authenticate(request) + assert result is None + + async def test_valid_token( + self, + mock_oauth_provider: OAuthServerProvider[Any, Any, Any], + valid_access_token: AccessToken, + ): + """Test authentication with valid token.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider(mock_oauth_provider, "valid_token", valid_access_token) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer valid_token")], + } + ) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == valid_access_token + assert user.scopes == ["read", "write"] + + async def test_token_without_expiry( + self, + mock_oauth_provider: OAuthServerProvider[Any, Any, Any], + no_expiry_access_token: AccessToken, + ): + """Test authentication with token that has no expiry.""" + backend = BearerAuthBackend(provider=mock_oauth_provider) + add_token_to_provider( + mock_oauth_provider, "no_expiry_token", no_expiry_access_token + ) + request = Request( + { + "type": "http", + "headers": [(b"authorization", b"Bearer no_expiry_token")], + } + ) + result = await backend.authenticate(request) + assert result is not None + credentials, user = result + assert isinstance(credentials, AuthCredentials) + assert isinstance(user, AuthenticatedUser) + assert credentials.scopes == ["read", "write"] + assert user.display_name == "test_client" + assert user.access_token == no_expiry_access_token + assert user.scopes == ["read", "write"] + + +@pytest.mark.anyio +class TestRequireAuthMiddleware: + """Tests for the RequireAuthMiddleware class.""" + + async def test_no_user(self): + """Test middleware with no user in scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + scope: Scope = {"type": "http"} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "Unauthorized" + assert not app.called + + async def test_non_authenticated_user(self): + """Test middleware with non-authenticated user in scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + scope: Scope = {"type": "http", "user": object()} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 401 + assert excinfo.value.detail == "Unauthorized" + assert not app.called + + async def test_missing_required_scope(self, valid_access_token: AccessToken): + """Test middleware with user missing required scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["admin"]) + + # Create a user with read/write scopes but not admin + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 403 + assert excinfo.value.detail == "Insufficient scope" + assert not app.called + + async def test_no_auth_credentials(self, valid_access_token: AccessToken): + """Test middleware with no auth credentials in scope.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + + scope: Scope = {"type": "http", "user": user} # No auth credentials + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + with pytest.raises(HTTPException) as excinfo: + await middleware(scope, receive, send) + + assert excinfo.value.status_code == 403 + assert excinfo.value.detail == "Insufficient scope" + assert not app.called + + async def test_has_required_scopes(self, valid_access_token: AccessToken): + """Test middleware with user having all required scopes.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read"]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + await middleware(scope, receive, send) + + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send + + async def test_multiple_required_scopes(self, valid_access_token: AccessToken): + """Test middleware with multiple required scopes.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=["read", "write"]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + await middleware(scope, receive, send) + + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send + + async def test_no_required_scopes(self, valid_access_token: AccessToken): + """Test middleware with no required scopes.""" + app = MockApp() + middleware = RequireAuthMiddleware(app, required_scopes=[]) + + # Create a user with read/write scopes + user = AuthenticatedUser(valid_access_token) + auth = AuthCredentials(["read", "write"]) + + scope: Scope = {"type": "http", "user": user, "auth": auth} + + # Create dummy async functions for receive and send + async def receive() -> Message: + return {"type": "http.request"} + + async def send(message: Message) -> None: + pass + + await middleware(scope, receive, send) + + assert app.called + assert app.scope == scope + assert app.receive == receive + assert app.send == send diff --git a/tests/server/auth/test_error_handling.py b/tests/server/auth/test_error_handling.py new file mode 100644 index 000000000..18e9933e7 --- /dev/null +++ b/tests/server/auth/test_error_handling.py @@ -0,0 +1,294 @@ +""" +Tests for OAuth error handling in the auth handlers. +""" + +import unittest.mock +from urllib.parse import parse_qs, urlparse + +import httpx +import pytest +from httpx import ASGITransport +from pydantic import AnyHttpUrl +from starlette.applications import Starlette + +from mcp.server.auth.provider import ( + AuthorizeError, + RegistrationError, + TokenError, +) +from mcp.server.auth.routes import create_auth_routes +from tests.server.fastmcp.auth.test_auth_integration import ( + MockOAuthProvider, +) + + +@pytest.fixture +def oauth_provider(): + """Return a MockOAuthProvider instance that can be configured to raise errors.""" + return MockOAuthProvider() + + +@pytest.fixture +def app(oauth_provider): + from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions + + # Enable client registration + client_registration_options = ClientRegistrationOptions(enabled=True) + revocation_options = RevocationOptions(enabled=True) + + # Create auth routes + auth_routes = create_auth_routes( + oauth_provider, + issuer_url=AnyHttpUrl("http://localhost"), + client_registration_options=client_registration_options, + revocation_options=revocation_options, + ) + + # Create Starlette app with routes directly + return Starlette(routes=auth_routes) + + +@pytest.fixture +def client(app): + transport = ASGITransport(app=app) + # Use base_url without a path since routes are directly on the app + return httpx.AsyncClient(transport=transport, base_url="http://localhost") + + +@pytest.fixture +def pkce_challenge(): + """Create a PKCE challenge with code_verifier and code_challenge.""" + import base64 + import hashlib + import secrets + + # Generate a code verifier + code_verifier = secrets.token_urlsafe(64)[:128] + + # Create code challenge using S256 method + code_verifier_bytes = code_verifier.encode("ascii") + sha256 = hashlib.sha256(code_verifier_bytes).digest() + code_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=") + + return {"code_verifier": code_verifier, "code_challenge": code_challenge} + + +@pytest.fixture +async def registered_client(client): + """Create and register a test client.""" + # Default client metadata + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "client_name": "Test Client", + } + + response = await client.post("/register", json=client_metadata) + assert response.status_code == 201, f"Failed to register client: {response.content}" + + client_info = response.json() + return client_info + + +class TestRegistrationErrorHandling: + @pytest.mark.anyio + async def test_registration_error_handling(self, client, oauth_provider): + # Mock the register_client method to raise a registration error + with unittest.mock.patch.object( + oauth_provider, + "register_client", + side_effect=RegistrationError( + error="invalid_redirect_uri", + error_description="The redirect URI is invalid", + ), + ): + # Prepare a client registration request + client_data = { + "redirect_uris": ["https://client.example.com/callback"], + "token_endpoint_auth_method": "client_secret_post", + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "client_name": "Test Client", + } + + # Send the registration request + response = await client.post( + "/register", + json=client_data, + ) + + # Verify the response + assert response.status_code == 400, response.content + data = response.json() + assert data["error"] == "invalid_redirect_uri" + assert data["error_description"] == "The redirect URI is invalid" + + +class TestAuthorizeErrorHandling: + @pytest.mark.anyio + async def test_authorize_error_handling( + self, client, oauth_provider, registered_client, pkce_challenge + ): + # Mock the authorize method to raise an authorize error + with unittest.mock.patch.object( + oauth_provider, + "authorize", + side_effect=AuthorizeError( + error="access_denied", error_description="The user denied the request" + ), + ): + # Register the client + client_id = registered_client["client_id"] + redirect_uri = registered_client["redirect_uris"][0] + + # Prepare an authorization request + params = { + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + } + + # Send the authorization request + response = await client.get("/authorize", params=params) + + # Verify the response is a redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert query_params["error"][0] == "access_denied" + assert "error_description" in query_params + assert query_params["state"][0] == "test_state" + + +class TestTokenErrorHandling: + @pytest.mark.anyio + async def test_token_error_handling_auth_code( + self, client, oauth_provider, registered_client, pkce_challenge + ): + # Register the client and get an auth code + client_id = registered_client["client_id"] + client_secret = registered_client["client_secret"] + redirect_uri = registered_client["redirect_uris"][0] + + # First get an authorization code + auth_response = await client.get( + "/authorize", + params={ + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + redirect_url = auth_response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + code = query_params["code"][0] + + # Mock the exchange_authorization_code method to raise a token error + with unittest.mock.patch.object( + oauth_provider, + "exchange_authorization_code", + side_effect=TokenError( + error="invalid_grant", + error_description="The authorization code is invalid", + ), + ): + # Try to exchange the code for tokens + token_response = await client.post( + "/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + "code_verifier": pkce_challenge["code_verifier"], + }, + ) + + # Verify the response + assert token_response.status_code == 400 + data = token_response.json() + assert data["error"] == "invalid_grant" + assert data["error_description"] == "The authorization code is invalid" + + @pytest.mark.anyio + async def test_token_error_handling_refresh_token( + self, client, oauth_provider, registered_client, pkce_challenge + ): + # Register the client and get tokens + client_id = registered_client["client_id"] + client_secret = registered_client["client_secret"] + redirect_uri = registered_client["redirect_uris"][0] + + # First get an authorization code + auth_response = await client.get( + "/authorize", + params={ + "client_id": client_id, + "redirect_uri": redirect_uri, + "response_type": "code", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + assert auth_response.status_code == 302, auth_response.content + + redirect_url = auth_response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + code = query_params["code"][0] + + # Exchange the code for tokens + token_response = await client.post( + "/token", + data={ + "grant_type": "authorization_code", + "code": code, + "redirect_uri": redirect_uri, + "client_id": client_id, + "client_secret": client_secret, + "code_verifier": pkce_challenge["code_verifier"], + }, + ) + + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Mock the exchange_refresh_token method to raise a token error + with unittest.mock.patch.object( + oauth_provider, + "exchange_refresh_token", + side_effect=TokenError( + error="invalid_scope", + error_description="The requested scope is invalid", + ), + ): + # Try to use the refresh token + refresh_response = await client.post( + "/token", + data={ + "grant_type": "refresh_token", + "refresh_token": refresh_token, + "client_id": client_id, + "client_secret": client_secret, + }, + ) + + # Verify the response + assert refresh_response.status_code == 400 + data = refresh_response.json() + assert data["error"] == "invalid_scope" + assert data["error_description"] == "The requested scope is invalid" diff --git a/tests/server/fastmcp/auth/__init__.py b/tests/server/fastmcp/auth/__init__.py new file mode 100644 index 000000000..64d318ec4 --- /dev/null +++ b/tests/server/fastmcp/auth/__init__.py @@ -0,0 +1,3 @@ +""" +Tests for the MCP server auth components. +""" diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py new file mode 100644 index 000000000..e4c310f7b --- /dev/null +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -0,0 +1,1428 @@ +""" +Integration tests for MCP authorization components. +""" + +import base64 +import hashlib +import json +import secrets +import time +import unittest.mock +from urllib.parse import parse_qs, urlparse + +import anyio +import httpx +import pytest +from httpx_sse import aconnect_sse +from pydantic import AnyHttpUrl +from starlette.applications import Starlette + +from mcp.server.auth.provider import ( + AccessToken, + AuthorizationCode, + AuthorizationParams, + OAuthServerProvider, + RefreshToken, + construct_redirect_uri, +) +from mcp.server.auth.routes import ( + ClientRegistrationOptions, + RevocationOptions, + create_auth_routes, +) +from mcp.server.auth.settings import AuthSettings +from mcp.server.fastmcp import FastMCP +from mcp.server.streaming_asgi_transport import StreamingASGITransport +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthToken, +) +from mcp.types import JSONRPCRequest + + +# Mock OAuth provider for testing +class MockOAuthProvider(OAuthServerProvider): + def __init__(self): + self.clients = {} + self.auth_codes = {} # code -> {client_id, code_challenge, redirect_uri} + self.tokens = {} # token -> {client_id, scopes, expires_at} + self.refresh_tokens = {} # refresh_token -> access_token + + async def get_client(self, client_id: str) -> OAuthClientInformationFull | None: + return self.clients.get(client_id) + + async def register_client(self, client_info: OAuthClientInformationFull): + self.clients[client_info.client_id] = client_info + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + # toy authorize implementation which just immediately generates an authorization + # code and completes the redirect + code = AuthorizationCode( + code=f"code_{int(time.time())}", + client_id=client.client_id, + code_challenge=params.code_challenge, + redirect_uri=params.redirect_uri, + expires_at=time.time() + 300, + scopes=params.scopes or ["read", "write"], + ) + self.auth_codes[code.code] = code + + return construct_redirect_uri( + str(params.redirect_uri), code=code.code, state=params.state + ) + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> AuthorizationCode | None: + return self.auth_codes.get(authorization_code) + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode + ) -> OAuthToken: + assert authorization_code.code in self.auth_codes + + # Generate an access token and refresh token + access_token = f"access_{secrets.token_hex(32)}" + refresh_token = f"refresh_{secrets.token_hex(32)}" + + # Store the tokens + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=authorization_code.scopes, + expires_at=int(time.time()) + 3600, + ) + + self.refresh_tokens[refresh_token] = access_token + + # Remove the used code + del self.auth_codes[authorization_code.code] + + return OAuthToken( + access_token=access_token, + token_type="bearer", + expires_in=3600, + scope="read write", + refresh_token=refresh_token, + ) + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> RefreshToken | None: + old_access_token = self.refresh_tokens.get(refresh_token) + if old_access_token is None: + return None + token_info = self.tokens.get(old_access_token) + if token_info is None: + return None + + # Create a RefreshToken object that matches what is expected in later code + refresh_obj = RefreshToken( + token=refresh_token, + client_id=token_info.client_id, + scopes=token_info.scopes, + expires_at=token_info.expires_at, + ) + + return refresh_obj + + async def exchange_refresh_token( + self, + client: OAuthClientInformationFull, + refresh_token: RefreshToken, + scopes: list[str], + ) -> OAuthToken: + # Check if refresh token exists + assert refresh_token.token in self.refresh_tokens + + old_access_token = self.refresh_tokens[refresh_token.token] + + # Check if the access token exists + assert old_access_token in self.tokens + + # Check if the token was issued to this client + token_info = self.tokens[old_access_token] + assert token_info.client_id == client.client_id + + # Generate a new access token and refresh token + new_access_token = f"access_{secrets.token_hex(32)}" + new_refresh_token = f"refresh_{secrets.token_hex(32)}" + + # Store the new tokens + self.tokens[new_access_token] = AccessToken( + token=new_access_token, + client_id=client.client_id, + scopes=scopes or token_info.scopes, + expires_at=int(time.time()) + 3600, + ) + + self.refresh_tokens[new_refresh_token] = new_access_token + + # Remove the old tokens + del self.refresh_tokens[refresh_token.token] + del self.tokens[old_access_token] + + return OAuthToken( + access_token=new_access_token, + token_type="bearer", + expires_in=3600, + scope=" ".join(scopes) if scopes else " ".join(token_info.scopes), + refresh_token=new_refresh_token, + ) + + async def load_access_token(self, token: str) -> AccessToken | None: + token_info = self.tokens.get(token) + + # Check if token is expired + # if token_info.expires_at < int(time.time()): + # raise InvalidTokenError("Access token has expired") + + return token_info and AccessToken( + token=token, + client_id=token_info.client_id, + scopes=token_info.scopes, + expires_at=token_info.expires_at, + ) + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + match token: + case RefreshToken(): + # Remove the refresh token + del self.refresh_tokens[token.token] + + case AccessToken(): + # Remove the access token + del self.tokens[token.token] + + # Also remove any refresh tokens that point to this access token + for refresh_token, access_token in list(self.refresh_tokens.items()): + if access_token == token.token: + del self.refresh_tokens[refresh_token] + + +@pytest.fixture +def mock_oauth_provider(): + return MockOAuthProvider() + + +@pytest.fixture +def auth_app(mock_oauth_provider): + # Create auth router + auth_routes = create_auth_routes( + mock_oauth_provider, + AnyHttpUrl("https://auth.example.com"), + AnyHttpUrl("https://docs.example.com"), + client_registration_options=ClientRegistrationOptions( + enabled=True, + valid_scopes=["read", "write", "profile"], + default_scopes=["read", "write"], + ), + revocation_options=RevocationOptions(enabled=True), + ) + + # Create Starlette app + app = Starlette(routes=auth_routes) + + return app + + +@pytest.fixture +def test_client(auth_app) -> httpx.AsyncClient: + return httpx.AsyncClient( + transport=httpx.ASGITransport(app=auth_app), base_url="https://mcptest.com" + ) + + +@pytest.fixture +async def registered_client(test_client: httpx.AsyncClient, request): + """Create and register a test client. + + Parameters can be customized via indirect parameterization: + @pytest.mark.parametrize("registered_client", + [{"grant_types": ["authorization_code"]}], + indirect=True) + """ + # Default client metadata + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code", "refresh_token"], + } + + # Override with any parameters from the test + if hasattr(request, "param") and request.param: + client_metadata.update(request.param) + + response = await test_client.post("/register", json=client_metadata) + assert response.status_code == 201, f"Failed to register client: {response.content}" + + client_info = response.json() + return client_info + + +@pytest.fixture +def pkce_challenge(): + """Create a PKCE challenge with code_verifier and code_challenge.""" + code_verifier = "some_random_verifier_string" + code_challenge = ( + base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest()) + .decode() + .rstrip("=") + ) + + return {"code_verifier": code_verifier, "code_challenge": code_challenge} + + +@pytest.fixture +async def auth_code(test_client, registered_client, pkce_challenge, request): + """Get an authorization code. + + Parameters can be customized via indirect parameterization: + @pytest.mark.parametrize("auth_code", + [{"redirect_uri": "https://client.example.com/other-callback"}], + indirect=True) + """ + # Default authorize params + auth_params = { + "response_type": "code", + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + } + + # Override with any parameters from the test + if hasattr(request, "param") and request.param: + auth_params.update(request.param) + + response = await test_client.get("/authorize", params=auth_params) + assert response.status_code == 302, f"Failed to get auth code: {response.content}" + + # Extract the authorization code + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params, f"No code in response: {query_params}" + auth_code = query_params["code"][0] + + return { + "code": auth_code, + "redirect_uri": auth_params["redirect_uri"], + "state": query_params.get("state", [None])[0], + } + + +@pytest.fixture +async def tokens(test_client, registered_client, auth_code, pkce_challenge, request): + """Exchange authorization code for tokens. + + Parameters can be customized via indirect parameterization: + @pytest.mark.parametrize("tokens", + [{"code_verifier": "wrong_verifier"}], + indirect=True) + """ + # Default token request params + token_params = { + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + } + + # Override with any parameters from the test + if hasattr(request, "param") and request.param: + token_params.update(request.param) + + response = await test_client.post("/token", data=token_params) + + # Don't assert success here since some tests will intentionally cause errors + return { + "response": response, + "params": token_params, + } + + +class TestAuthEndpoints: + @pytest.mark.anyio + async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): + """Test the OAuth 2.0 metadata endpoint.""" + print("Sending request to metadata endpoint") + response = await test_client.get("/.well-known/oauth-authorization-server") + print(f"Got response: {response.status_code}") + if response.status_code != 200: + print(f"Response content: {response.content}") + assert response.status_code == 200 + + metadata = response.json() + assert metadata["issuer"] == "https://auth.example.com/" + assert ( + metadata["authorization_endpoint"] == "https://auth.example.com/authorize" + ) + assert metadata["token_endpoint"] == "https://auth.example.com/token" + assert metadata["registration_endpoint"] == "https://auth.example.com/register" + assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke" + assert metadata["response_types_supported"] == ["code"] + assert metadata["code_challenge_methods_supported"] == ["S256"] + assert metadata["token_endpoint_auth_methods_supported"] == [ + "client_secret_post" + ] + assert metadata["grant_types_supported"] == [ + "authorization_code", + "refresh_token", + ] + assert metadata["service_documentation"] == "https://docs.example.com/" + + @pytest.mark.anyio + async def test_token_validation_error(self, test_client: httpx.AsyncClient): + """Test token endpoint error - validation error.""" + # Missing required fields + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + # Missing code, code_verifier, client_id, etc. + }, + ) + error_response = response.json() + assert error_response["error"] == "invalid_request" + assert ( + "error_description" in error_response + ) # Contains validation error messages + + @pytest.mark.anyio + async def test_token_invalid_auth_code( + self, test_client, registered_client, pkce_challenge + ): + """Test token endpoint error - authorization code does not exist.""" + # Try to use a non-existent authorization code + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": "non_existent_auth_code", + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + print(f"Status code: {response.status_code}") + print(f"Response body: {response.content}") + print(f"Response JSON: {response.json()}") + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert ( + "authorization code does not exist" in error_response["error_description"] + ) + + @pytest.mark.anyio + async def test_token_expired_auth_code( + self, + test_client, + registered_client, + auth_code, + pkce_challenge, + mock_oauth_provider, + ): + """Test token endpoint error - authorization code has expired.""" + # Get the current time for our time mocking + current_time = time.time() + + # Find the auth code object + code_value = auth_code["code"] + found_code = None + for code_obj in mock_oauth_provider.auth_codes.values(): + if code_obj.code == code_value: + found_code = code_obj + break + + assert found_code is not None + + # Authorization codes are typically short-lived (5 minutes = 300 seconds) + # So we'll mock time to be 10 minutes (600 seconds) in the future + with unittest.mock.patch("time.time", return_value=current_time + 600): + # Try to use the expired authorization code + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": code_value, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert ( + "authorization code has expired" in error_response["error_description"] + ) + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [ + { + "redirect_uris": [ + "https://client.example.com/callback", + "https://client.example.com/other-callback", + ] + } + ], + indirect=True, + ) + async def test_token_redirect_uri_mismatch( + self, test_client, registered_client, auth_code, pkce_challenge + ): + """Test token endpoint error - redirect URI mismatch.""" + # Try to use the code with a different redirect URI + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + # Different from the one used in /authorize + "redirect_uri": "https://client.example.com/other-callback", + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_request" + assert "redirect_uri did not match" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_code_verifier_mismatch( + self, test_client, registered_client, auth_code + ): + """Test token endpoint error - PKCE code verifier mismatch.""" + # Try to use the code with an incorrect code verifier + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + # Different from the one used to create challenge + "code_verifier": "incorrect_code_verifier", + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "incorrect code_verifier" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_invalid_refresh_token(self, test_client, registered_client): + """Test token endpoint error - refresh token does not exist.""" + # Try to use a non-existent refresh token + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": "non_existent_refresh_token", + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "refresh token does not exist" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_expired_refresh_token( + self, + test_client, + registered_client, + auth_code, + pkce_challenge, + mock_oauth_provider, + ): + """Test token endpoint error - refresh token has expired.""" + # Step 1: First, let's create a token and refresh token at the current time + current_time = time.time() + + # Exchange authorization code for tokens normally + token_response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert token_response.status_code == 200 + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Step 2: Time travel forward 4 hours (tokens expire in 1 hour by default) + # Mock the time.time() function to return a value 4 hours in the future + with unittest.mock.patch( + "time.time", return_value=current_time + 14400 + ): # 4 hours = 14400 seconds + # Try to use the refresh token which should now be considered expired + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": refresh_token, + }, + ) + + # In the "future", the token should be considered expired + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_grant" + assert "refresh token has expired" in error_response["error_description"] + + @pytest.mark.anyio + async def test_token_invalid_scope( + self, test_client, registered_client, auth_code, pkce_challenge + ): + """Test token endpoint error - invalid scope in refresh token request.""" + # Exchange authorization code for tokens + token_response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "code": auth_code["code"], + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": auth_code["redirect_uri"], + }, + ) + assert token_response.status_code == 200 + + tokens = token_response.json() + refresh_token = tokens["refresh_token"] + + # Try to use refresh token with an invalid scope + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "refresh_token": refresh_token, + "scope": "read write invalid_scope", # Adding an invalid scope + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_scope" + assert "cannot request scope" in error_response["error_description"] + + @pytest.mark.anyio + async def test_client_registration( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ): + """Test client registration.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "client_uri": "https://client.example.com", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201, response.content + + client_info = response.json() + assert "client_id" in client_info + assert "client_secret" in client_info + assert client_info["client_name"] == "Test Client" + assert client_info["redirect_uris"] == ["https://client.example.com/callback"] + + # Verify that the client was registered + # assert await mock_oauth_provider.clients_store.get_client( + # client_info["client_id"] + # ) is not None + + @pytest.mark.anyio + async def test_client_registration_missing_required_fields( + self, test_client: httpx.AsyncClient + ): + """Test client registration with missing required fields.""" + # Missing redirect_uris which is a required field + client_metadata = { + "client_name": "Test Client", + "client_uri": "https://client.example.com", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert error_data["error_description"] == "redirect_uris: Field required" + + @pytest.mark.anyio + async def test_client_registration_invalid_uri( + self, test_client: httpx.AsyncClient + ): + """Test client registration with invalid URIs.""" + # Invalid redirect_uri format + client_metadata = { + "redirect_uris": ["not-a-valid-uri"], + "client_name": "Test Client", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert error_data["error_description"] == ( + "redirect_uris.0: Input should be a valid URL, " + "relative URL without a base" + ) + + @pytest.mark.anyio + async def test_client_registration_empty_redirect_uris( + self, test_client: httpx.AsyncClient + ): + """Test client registration with empty redirect_uris array.""" + client_metadata = { + "redirect_uris": [], # Empty array + "client_name": "Test Client", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert ( + error_data["error_description"] + == "redirect_uris: List should have at least 1 item after validation, not 0" + ) + + @pytest.mark.anyio + async def test_authorize_form_post( + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge, + ): + """Test the authorization endpoint using POST with form-encoded data.""" + # Register a client + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # Use POST with form-encoded data for authorization + response = await test_client.post( + "/authorize", + data={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_form_state", + }, + ) + assert response.status_code == 302 + + # Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + assert query_params["state"][0] == "test_form_state" + + @pytest.mark.anyio + async def test_authorization_get( + self, + test_client: httpx.AsyncClient, + mock_oauth_provider: MockOAuthProvider, + pkce_challenge, + ): + """Test the full authorization flow.""" + # 1. Register a client + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code", "refresh_token"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # 2. Request authorization using GET with query params + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + assert response.status_code == 302 + + # 3. Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + assert query_params["state"][0] == "test_state" + auth_code = query_params["code"][0] + + # 4. Exchange the authorization code for tokens + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + + token_response = response.json() + assert "access_token" in token_response + assert "token_type" in token_response + assert "refresh_token" in token_response + assert "expires_in" in token_response + assert token_response["token_type"] == "bearer" + + # 5. Verify the access token + access_token = token_response["access_token"] + refresh_token = token_response["refresh_token"] + + # Create a test client with the token + auth_info = await mock_oauth_provider.load_access_token(access_token) + assert auth_info + assert auth_info.client_id == client_info["client_id"] + assert "read" in auth_info.scopes + assert "write" in auth_info.scopes + + # 6. Refresh the token + response = await test_client.post( + "/token", + data={ + "grant_type": "refresh_token", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "refresh_token": refresh_token, + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + + new_token_response = response.json() + assert "access_token" in new_token_response + assert "refresh_token" in new_token_response + assert new_token_response["access_token"] != access_token + assert new_token_response["refresh_token"] != refresh_token + + # 7. Revoke the token + response = await test_client.post( + "/revoke", + data={ + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "token": new_token_response["access_token"], + }, + ) + assert response.status_code == 200 + + # Verify that the token was revoked + assert ( + await mock_oauth_provider.load_access_token( + new_token_response["access_token"] + ) + is None + ) + + @pytest.mark.anyio + async def test_revoke_invalid_token(self, test_client, registered_client): + """Test revoking an invalid token.""" + response = await test_client.post( + "/revoke", + data={ + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "token": "invalid_token", + }, + ) + # per RFC, this should return 200 even if the token is invalid + assert response.status_code == 200 + + @pytest.mark.anyio + async def test_revoke_with_malformed_token(self, test_client, registered_client): + response = await test_client.post( + "/revoke", + data={ + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "token": 123, + "token_type_hint": "asdf", + }, + ) + assert response.status_code == 400 + error_response = response.json() + assert error_response["error"] == "invalid_request" + assert "token_type_hint" in error_response["error_description"] + + @pytest.mark.anyio + async def test_client_registration_disallowed_scopes( + self, test_client: httpx.AsyncClient + ): + """Test client registration with scopes that are not allowed.""" + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "scope": "read write profile admin", # 'admin' is not in valid_scopes + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert "scope" in error_data["error_description"] + assert "admin" in error_data["error_description"] + + @pytest.mark.anyio + async def test_client_registration_default_scopes( + self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider + ): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + # No scope specified + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # Verify client was registered successfully + assert client_info["scope"] == "read write" + + # Retrieve the client from the store to verify default scopes + registered_client = await mock_oauth_provider.get_client( + client_info["client_id"] + ) + assert registered_client is not None + + # Check that default scopes were applied + assert registered_client.scope == "read write" + + @pytest.mark.anyio + async def test_client_registration_invalid_grant_type( + self, test_client: httpx.AsyncClient + ): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + "grant_types": ["authorization_code"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 400 + error_data = response.json() + assert "error" in error_data + assert error_data["error"] == "invalid_client_metadata" + assert ( + error_data["error_description"] + == "grant_types must be authorization_code and refresh_token" + ) + + +class TestFastMCPWithAuth: + """Test FastMCP server with authentication.""" + + @pytest.mark.anyio + async def test_fastmcp_with_auth( + self, mock_oauth_provider: MockOAuthProvider, pkce_challenge + ): + """Test creating a FastMCP server with authentication.""" + # Create FastMCP server with auth provider + mcp = FastMCP( + auth_provider=mock_oauth_provider, + require_auth=True, + auth=AuthSettings( + issuer_url=AnyHttpUrl("https://auth.example.com"), + client_registration_options=ClientRegistrationOptions(enabled=True), + revocation_options=RevocationOptions(enabled=True), + required_scopes=["read", "write"], + ), + ) + + # Add a test tool + @mcp.tool() + def test_tool(x: int) -> str: + return f"Result: {x}" + + async with anyio.create_task_group() as task_group: + transport = StreamingASGITransport( + app=mcp.sse_app(), + task_group=task_group, + ) + test_client = httpx.AsyncClient( + transport=transport, base_url="http://mcptest.com" + ) + + # Test metadata endpoint + response = await test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + + # Test that auth is required for protected endpoints + response = await test_client.get("/sse") + assert response.status_code == 401 + + response = await test_client.post("/messages/") + assert response.status_code == 401, response.content + + response = await test_client.post( + "/messages/", + headers={"Authorization": "invalid"}, + ) + assert response.status_code == 401 + + response = await test_client.post( + "/messages/", + headers={"Authorization": "Bearer invalid"}, + ) + assert response.status_code == 401 + + # now, become authenticated and try to go through the flow again + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "Test Client", + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + assert response.status_code == 201 + client_info = response.json() + + # Request authorization using POST with form-encoded data + response = await test_client.post( + "/authorize", + data={ + "response_type": "code", + "client_id": client_info["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + assert response.status_code == 302 + + # Extract the authorization code from the redirect URL + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "code" in query_params + auth_code = query_params["code"][0] + + # Exchange the authorization code for tokens + response = await test_client.post( + "/token", + data={ + "grant_type": "authorization_code", + "client_id": client_info["client_id"], + "client_secret": client_info["client_secret"], + "code": auth_code, + "code_verifier": pkce_challenge["code_verifier"], + "redirect_uri": "https://client.example.com/callback", + }, + ) + assert response.status_code == 200 + + token_response = response.json() + assert "access_token" in token_response + authorization = f"Bearer {token_response['access_token']}" + + # Test the authenticated endpoint with valid token + async with aconnect_sse( + test_client, "GET", "/sse", headers={"Authorization": authorization} + ) as event_source: + assert event_source.response.status_code == 200 + events = event_source.aiter_sse() + sse = await events.__anext__() + assert sse.event == "endpoint" + assert sse.data.startswith("/messages/?session_id=") + messages_uri = sse.data + + # verify that we can now post to the /messages endpoint, + # and get a response on the /sse endpoint + response = await test_client.post( + messages_uri, + headers={"Authorization": authorization}, + content=JSONRPCRequest( + jsonrpc="2.0", + id="123", + method="initialize", + params={ + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": {"listChanged": True}, + "sampling": {}, + }, + "clientInfo": {"name": "ExampleClient", "version": "1.0.0"}, + }, + ).model_dump_json(), + ) + assert response.status_code == 202 + assert response.content == b"Accepted" + + sse = await events.__anext__() + assert sse.event == "message" + sse_data = json.loads(sse.data) + assert sse_data["id"] == "123" + assert set(sse_data["result"]["capabilities"].keys()) == set( + ("experimental", "prompts", "resources", "tools") + ) + # the /sse endpoint will never finish; normally, the client could just + # disconnect, but in tests the easiest way to do this is to cancel the + # task group + task_group.cancel_scope.cancel() + + +class TestAuthorizeEndpointErrors: + """Test error handling in the OAuth authorization endpoint.""" + + @pytest.mark.anyio + async def test_authorize_missing_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge + ): + """Test authorization endpoint with missing client_id. + + According to the OAuth2.0 spec, if client_id is missing, the server should + inform the resource owner and NOT redirect. + """ + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + # Missing client_id + "redirect_uri": "https://client.example.com/callback", + "state": "test_state", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + }, + ) + + # Should NOT redirect, should show an error page + assert response.status_code == 400 + # The response should include an error message about missing client_id + assert "client_id" in response.text.lower() + + @pytest.mark.anyio + async def test_authorize_invalid_client_id( + self, test_client: httpx.AsyncClient, pkce_challenge + ): + """Test authorization endpoint with invalid client_id. + + According to the OAuth2.0 spec, if client_id is invalid, the server should + inform the resource owner and NOT redirect. + """ + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": "invalid_client_id_that_does_not_exist", + "redirect_uri": "https://client.example.com/callback", + "state": "test_state", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + }, + ) + + # Should NOT redirect, should show an error page + assert response.status_code == 400 + # The response should include an error message about invalid client_id + assert "client" in response.text.lower() + + @pytest.mark.anyio + async def test_authorize_missing_redirect_uri( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with missing redirect_uri. + + If client has only one registered redirect_uri, it can be omitted. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Missing redirect_uri + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should redirect to the registered redirect_uri + assert response.status_code == 302, response.content + redirect_url = response.headers["location"] + assert redirect_url.startswith("https://client.example.com/callback") + + @pytest.mark.anyio + async def test_authorize_invalid_redirect_uri( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with invalid redirect_uri. + + According to the OAuth2.0 spec, if redirect_uri is invalid or doesn't match, + the server should inform the resource owner and NOT redirect. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Non-matching URI + "redirect_uri": "https://attacker.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should NOT redirect, should show an error page + assert response.status_code == 400, response.content + # The response should include an error message about redirect_uri mismatch + assert "redirect" in response.text.lower() + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [ + { + "redirect_uris": [ + "https://client.example.com/callback", + "https://client.example.com/other-callback", + ] + } + ], + indirect=True, + ) + async def test_authorize_missing_redirect_uri_multiple_registered( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test endpoint with missing redirect_uri with multiple registered URIs. + + If client has multiple registered redirect_uris, redirect_uri must be provided. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Missing redirect_uri + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should NOT redirect, should return a 400 error + assert response.status_code == 400 + # The response should include an error message about missing redirect_uri + assert "redirect_uri" in response.text.lower() + + @pytest.mark.anyio + async def test_authorize_unsupported_response_type( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with unsupported response_type. + + According to the OAuth2.0 spec, for other errors like unsupported_response_type, + the server should redirect with error parameters. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "token", # Unsupported (we only support "code") + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "unsupported_response_type" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + async def test_authorize_missing_response_type( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with missing response_type. + + Missing required parameter should result in invalid_request error. + """ + + response = await test_client.get( + "/authorize", + params={ + # Missing response_type + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "state": "test_state", + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "invalid_request" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + async def test_authorize_missing_pkce_challenge( + self, test_client: httpx.AsyncClient, registered_client + ): + """Test authorization endpoint with missing PKCE code_challenge. + + Missing PKCE parameters should result in invalid_request error. + """ + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + # Missing code_challenge + "state": "test_state", + # using default URL + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "invalid_request" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + async def test_authorize_invalid_scope( + self, test_client: httpx.AsyncClient, registered_client, pkce_challenge + ): + """Test authorization endpoint with invalid scope. + + Invalid scope should redirect with invalid_scope error. + """ + + response = await test_client.get( + "/authorize", + params={ + "response_type": "code", + "client_id": registered_client["client_id"], + "redirect_uri": "https://client.example.com/callback", + "code_challenge": pkce_challenge["code_challenge"], + "code_challenge_method": "S256", + "scope": "invalid_scope_that_does_not_exist", + "state": "test_state", + }, + ) + + # Should redirect with error parameters + assert response.status_code == 302 + redirect_url = response.headers["location"] + parsed_url = urlparse(redirect_url) + query_params = parse_qs(parsed_url.query) + + assert "error" in query_params + assert query_params["error"][0] == "invalid_scope" + # State should be preserved + assert "state" in query_params + assert query_params["state"][0] == "test_state" diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 43107b597..f5158c3c3 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -138,9 +138,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 2aca97e15..1381c8153 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -134,9 +134,7 @@ def server(server_port: int) -> Generator[None, None, None]: time.sleep(0.1) attempt += 1 else: - raise RuntimeError( - f"Server failed to start after {max_attempts} attempts" - ) + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") yield diff --git a/uv.lock b/uv.lock index 664256f67..a1e9add6a 100644 --- a/uv.lock +++ b/uv.lock @@ -236,6 +236,7 @@ dependencies = [ { name = "sse-starlette" }, { name = "starlette" }, { name = "uvicorn" }, + { name = "python-multipart" }, ] [package.optional-dependencies] @@ -280,7 +281,7 @@ provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [ - { name = "pyright", specifier = ">=1.1.391" }, + { name = "pyright", specifier = ">=1.1.396" }, { name = "pytest", specifier = ">=8.3.4" }, { name = "pytest-examples", specifier = ">=0.0.14" }, { name = "pytest-flakefinder", specifier = ">=1.1.0" }, @@ -585,15 +586,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.391" +version = "1.1.396" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/11/05/4ea52a8a45cc28897edb485b4102d37cbfd5fce8445d679cdeb62bfad221/pyright-1.1.391.tar.gz", hash = "sha256:66b2d42cdf5c3cbab05f2f4b76e8bec8aa78e679bfa0b6ad7b923d9e027cadb2", size = 21965 } +sdist = { url = "https://files.pythonhosted.org/packages/bd/73/f20cb1dea1bdc1774e7f860fb69dc0718c7d8dea854a345faec845eb086a/pyright-1.1.396.tar.gz", hash = "sha256:142901f5908f5a0895be3d3befcc18bedcdb8cc1798deecaec86ef7233a29b03", size = 3814400 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/89/66f49552fbeb21944c8077d11834b2201514a56fd1b7747ffff9630f1bd9/pyright-1.1.391-py3-none-any.whl", hash = "sha256:54fa186f8b3e8a55a44ebfa842636635688670c6896dcf6cf4a7fc75062f4d15", size = 18579 }, + { url = "https://files.pythonhosted.org/packages/80/be/ecb7cfb42d242b7ee764b52e6ff4782beeec00e3b943a3ec832b281f9da6/pyright-1.1.396-py3-none-any.whl", hash = "sha256:c635e473095b9138c471abccca22b9fedbe63858e0b40d4fc4b67da041891844", size = 5689355 }, ] [[package]] @@ -904,3 +905,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884 }, { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743 }, ] + +[[package]] + +name = "python-multipart" +version = "0.0.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/87/f44d7c9f274c7ee665a29b885ec97089ec5dc034c7f3fafa03da9e39a09e/python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13", size = 85321 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/58/38b5afbc1a800eeea951b9285d3912613f2603bdf897a4ab0f4bd7f405fc/python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104", size=11111 }, +]