|
| 1 | +""" |
| 2 | +Automated OAuth server client for MCP servers requiring OAuth authentication. |
| 3 | +
|
| 4 | +This client handles OAuth using client credentials grant (machine-to-machine authentication) |
| 5 | +without browser interaction. |
| 6 | +""" |
| 7 | + |
| 8 | +import logging |
| 9 | +import time |
| 10 | +from datetime import timedelta |
| 11 | +from typing import Optional |
| 12 | + |
| 13 | +import httpx |
| 14 | + |
| 15 | +from mcp.client.auth import OAuthClientProvider, TokenStorage |
| 16 | +from mcp.client.streamable_http import streamablehttp_client, MCP_PROTOCOL_VERSION |
| 17 | +from mcp.shared.auth import ( |
| 18 | + OAuthClientInformationFull, |
| 19 | + OAuthClientMetadata, |
| 20 | + OAuthMetadata, |
| 21 | + OAuthToken, |
| 22 | + ProtectedResourceMetadata, |
| 23 | +) |
| 24 | +from mcp.types import LATEST_PROTOCOL_VERSION |
| 25 | + |
| 26 | + |
| 27 | +class InMemoryTokenStorage(TokenStorage): |
| 28 | + """ |
| 29 | + Simple in-memory token storage implementation for automated OAuth. |
| 30 | + In production, you should persist tokens securely. However, for |
| 31 | + the demo chatbot, it's ok to ask the user to re-authenticate |
| 32 | + in the browser each time they run the chatbot. |
| 33 | + """ |
| 34 | + |
| 35 | + def __init__(self): |
| 36 | + self._client_info: Optional[OAuthClientInformationFull] = None |
| 37 | + self._token: Optional[OAuthToken] = None |
| 38 | + |
| 39 | + async def get_client_info(self) -> Optional[OAuthClientInformationFull]: |
| 40 | + return self._client_info |
| 41 | + |
| 42 | + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: |
| 43 | + self._client_info = client_info |
| 44 | + |
| 45 | + async def get_token(self) -> Optional[OAuthToken]: |
| 46 | + return self._token |
| 47 | + |
| 48 | + async def set_token(self, token: OAuthToken) -> None: |
| 49 | + self._token = token |
| 50 | + |
| 51 | + async def clear_token(self) -> None: |
| 52 | + self._token = None |
| 53 | + |
| 54 | + |
| 55 | +class AutomatedOAuthClientProvider(OAuthClientProvider): |
| 56 | + """ |
| 57 | + OAuth client provider for automated (client credentials) OAuth flows. |
| 58 | + This provider handles machine-to-machine authentication without user interaction. |
| 59 | + """ |
| 60 | + |
| 61 | + def __init__( |
| 62 | + self, |
| 63 | + server_url: str, |
| 64 | + client_metadata: OAuthClientMetadata, |
| 65 | + storage: TokenStorage, |
| 66 | + client_id: str, |
| 67 | + client_secret: str, |
| 68 | + authorization_server_url: str, |
| 69 | + ): |
| 70 | + # Create dummy handlers since they won't be used in client credentials flow |
| 71 | + async def dummy_redirect_handler(url: str) -> None: |
| 72 | + raise RuntimeError( |
| 73 | + "Redirect handler should not be called in automated OAuth flow" |
| 74 | + ) |
| 75 | + |
| 76 | + async def dummy_callback_handler() -> tuple[str, Optional[str]]: |
| 77 | + raise RuntimeError( |
| 78 | + "Callback handler should not be called in automated OAuth flow" |
| 79 | + ) |
| 80 | + |
| 81 | + super().__init__( |
| 82 | + server_url=server_url, |
| 83 | + client_metadata=client_metadata, |
| 84 | + storage=storage, |
| 85 | + redirect_handler=dummy_redirect_handler, |
| 86 | + callback_handler=dummy_callback_handler, |
| 87 | + ) |
| 88 | + |
| 89 | + self.client_id = client_id |
| 90 | + self.client_secret = client_secret |
| 91 | + self.authorization_server_url = authorization_server_url |
| 92 | + |
| 93 | + async def perform_client_credentials_flow(self) -> None: |
| 94 | + """Performs the client credentials OAuth flow to obtain access tokens.""" |
| 95 | + try: |
| 96 | + # Check if we already have valid tokens |
| 97 | + current_tokens = await self.context.storage.get_token() |
| 98 | + if current_tokens and current_tokens.access_token: |
| 99 | + self.context.current_tokens = current_tokens |
| 100 | + self.context.update_token_expiry(current_tokens) |
| 101 | + if self.context.is_token_valid(): |
| 102 | + logging.debug("Using existing valid access token") |
| 103 | + return |
| 104 | + |
| 105 | + logging.debug("Performing client credentials flow...") |
| 106 | + |
| 107 | + # Set auth server URL and discover OAuth metadata using upstream logic |
| 108 | + self.context.auth_server_url = self.authorization_server_url |
| 109 | + await self._discover_oauth_metadata() |
| 110 | + |
| 111 | + if ( |
| 112 | + not self.context.oauth_metadata |
| 113 | + or not self.context.oauth_metadata.token_endpoint |
| 114 | + ): |
| 115 | + raise RuntimeError("No token endpoint found in OAuth metadata") |
| 116 | + |
| 117 | + # Create client info and store it |
| 118 | + client_info = OAuthClientInformationFull( |
| 119 | + client_id=self.client_id, |
| 120 | + client_secret=self.client_secret, |
| 121 | + client_id_issued_at=int(time.time()), |
| 122 | + **self.context.client_metadata.model_dump(exclude_unset=True), |
| 123 | + ) |
| 124 | + await self.context.storage.set_client_info(client_info) |
| 125 | + self.context.client_info = client_info |
| 126 | + |
| 127 | + # Perform client credentials token request |
| 128 | + token_data = { |
| 129 | + "grant_type": "client_credentials", |
| 130 | + "client_id": self.client_id, |
| 131 | + "client_secret": self.client_secret, |
| 132 | + } |
| 133 | + |
| 134 | + # Add scope if specified |
| 135 | + if self.context.client_metadata.scope: |
| 136 | + token_data["scope"] = self.context.client_metadata.scope |
| 137 | + |
| 138 | + logging.debug( |
| 139 | + f"Making token request to: {self.context.oauth_metadata.token_endpoint}" |
| 140 | + ) |
| 141 | + |
| 142 | + async with httpx.AsyncClient() as client: |
| 143 | + response = await client.post( |
| 144 | + str(self.context.oauth_metadata.token_endpoint), |
| 145 | + data=token_data, |
| 146 | + headers={"Content-Type": "application/x-www-form-urlencoded"}, |
| 147 | + ) |
| 148 | + |
| 149 | + if response.status_code != 200: |
| 150 | + error_text = response.text |
| 151 | + raise RuntimeError( |
| 152 | + f"Token request failed: HTTP {response.status_code} - {error_text}" |
| 153 | + ) |
| 154 | + |
| 155 | + token_response = response.json() |
| 156 | + |
| 157 | + # Create and store tokens |
| 158 | + tokens = OAuthToken( |
| 159 | + access_token=token_response["access_token"], |
| 160 | + token_type=token_response.get("token_type", "Bearer"), |
| 161 | + expires_in=token_response.get("expires_in"), |
| 162 | + refresh_token=token_response.get("refresh_token"), |
| 163 | + scope=token_response.get("scope"), |
| 164 | + ) |
| 165 | + |
| 166 | + await self.context.storage.set_token(tokens) |
| 167 | + self.context.current_tokens = tokens |
| 168 | + self.context.update_token_expiry(tokens) |
| 169 | + |
| 170 | + logging.debug( |
| 171 | + "Successfully obtained access token via client credentials flow" |
| 172 | + ) |
| 173 | + |
| 174 | + except Exception as error: |
| 175 | + logging.error(f"Client credentials flow failed: {error}") |
| 176 | + raise |
| 177 | + |
| 178 | + async def _discover_oauth_metadata(self) -> None: |
| 179 | + """Discover OAuth metadata using upstream MCP SDK discovery logic.""" |
| 180 | + # Use the inherited _get_discovery_urls method from parent class |
| 181 | + discovery_urls = self._get_discovery_urls() |
| 182 | + |
| 183 | + async with httpx.AsyncClient() as client: |
| 184 | + for metadata_url in discovery_urls: |
| 185 | + try: |
| 186 | + response = await client.get(metadata_url, follow_redirects=True) |
| 187 | + if response.status_code == 200: |
| 188 | + metadata = OAuthMetadata.model_validate_json(response.content) |
| 189 | + self.context.oauth_metadata = metadata |
| 190 | + logging.debug( |
| 191 | + f"Successfully discovered OAuth metadata from: {metadata_url}" |
| 192 | + ) |
| 193 | + return |
| 194 | + except Exception as e: |
| 195 | + logging.debug( |
| 196 | + f"Failed to discover OAuth metadata from {metadata_url}: {e}" |
| 197 | + ) |
| 198 | + continue |
| 199 | + |
| 200 | + raise RuntimeError( |
| 201 | + "Failed to discover OAuth metadata from any well-known endpoint" |
| 202 | + ) |
| 203 | + |
| 204 | + |
| 205 | +class AutomatedOAuthClient: |
| 206 | + """ |
| 207 | + Manages OAuth authentication for MCP servers requiring automated OAuth. |
| 208 | +
|
| 209 | + This client handles OAuth using client credentials grant (machine-to-machine authentication) |
| 210 | + without browser interaction. |
| 211 | + """ |
| 212 | + |
| 213 | + def __init__(self, name: str, server_url: str, client_id: str, client_secret: str): |
| 214 | + self.name = name |
| 215 | + self.server_url = server_url |
| 216 | + self.client_id = client_id |
| 217 | + self.client_secret = client_secret |
| 218 | + |
| 219 | + async def create_transport(self): |
| 220 | + """Create OAuth-authenticated transport for MCP communication.""" |
| 221 | + logging.debug(f"Connecting to OAuth-protected MCP server: {self.server_url}") |
| 222 | + |
| 223 | + # Discover the required scope from the server |
| 224 | + scope, authorization_server_url = await self._discover_scope_and_auth_server(self.server_url) |
| 225 | + |
| 226 | + # Get OAuth client configuration (handled by mcp_clients.py) |
| 227 | + if not self.client_id or not self.client_secret: |
| 228 | + raise ValueError("client_id and client_secret must be provided") |
| 229 | + |
| 230 | + # Create client metadata |
| 231 | + client_metadata = OAuthClientMetadata( |
| 232 | + client_name=f"MCP Client - {self.name}", |
| 233 | + redirect_uris=["http://localhost"], # Required but not used in client credentials flow |
| 234 | + grant_types=["authorization_code"], # Required format, though we use client_credentials |
| 235 | + response_types=["code"], # Required for authorization_code grant type |
| 236 | + token_endpoint_auth_method="client_secret_post", |
| 237 | + scope=scope, |
| 238 | + ) |
| 239 | + |
| 240 | + # Create storage and OAuth provider |
| 241 | + storage = InMemoryTokenStorage() |
| 242 | + oauth_provider = AutomatedOAuthClientProvider( |
| 243 | + server_url=self.server_url, |
| 244 | + client_metadata=client_metadata, |
| 245 | + storage=storage, |
| 246 | + client_id=self.client_id, |
| 247 | + client_secret=self.client_secret, |
| 248 | + authorization_server_url=authorization_server_url, |
| 249 | + ) |
| 250 | + |
| 251 | + # Perform client credentials flow |
| 252 | + logging.debug("Starting automated OAuth flow...") |
| 253 | + await oauth_provider.perform_client_credentials_flow() |
| 254 | + |
| 255 | + # Create transport with OAuth provider |
| 256 | + logging.debug("Creating transport with automated OAuth provider...") |
| 257 | + return streamablehttp_client( |
| 258 | + url=self.server_url, |
| 259 | + auth=oauth_provider, |
| 260 | + timeout=timedelta(seconds=60), |
| 261 | + ) |
| 262 | + |
| 263 | + async def _discover_scope_and_auth_server(self, server_url: str) -> tuple[str, str]: |
| 264 | + """Discovers the required scope and authorization server from OAuth protected resource metadata.""" |
| 265 | + logging.debug("Making initial request to discover OAuth metadata...") |
| 266 | + |
| 267 | + async with httpx.AsyncClient() as client: |
| 268 | + headers = {MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} |
| 269 | + response = await client.post( |
| 270 | + server_url, |
| 271 | + headers=headers, |
| 272 | + follow_redirects=True, |
| 273 | + json={"jsonrpc": "2.0", "method": "ping", "id": 1}, |
| 274 | + ) |
| 275 | + |
| 276 | + if response.status_code != 401: |
| 277 | + raise RuntimeError( |
| 278 | + f"Expected 401 response for OAuth discovery, got {response.status_code}" |
| 279 | + ) |
| 280 | + |
| 281 | + # Extract resource metadata URL from WWW-Authenticate header |
| 282 | + www_auth_header = response.headers.get("WWW-Authenticate") |
| 283 | + resource_metadata_url = None |
| 284 | + |
| 285 | + if www_auth_header: |
| 286 | + # Simple extraction of resource_metadata URL |
| 287 | + import re |
| 288 | + |
| 289 | + pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))' |
| 290 | + match = re.search(pattern, www_auth_header) |
| 291 | + if match: |
| 292 | + resource_metadata_url = match.group(1) or match.group(2) |
| 293 | + |
| 294 | + if not resource_metadata_url: |
| 295 | + # Fallback to well-known discovery |
| 296 | + from urllib.parse import urlparse, urljoin |
| 297 | + |
| 298 | + parsed = urlparse(server_url) |
| 299 | + base_url = f"{parsed.scheme}://{parsed.netloc}" |
| 300 | + resource_metadata_url = urljoin( |
| 301 | + base_url, "/.well-known/oauth-protected-resource" |
| 302 | + ) |
| 303 | + |
| 304 | + logging.debug(f"Discovered resource metadata URL: {resource_metadata_url}") |
| 305 | + |
| 306 | + # Fetch protected resource metadata |
| 307 | + logging.debug("Fetching OAuth protected resource metadata...") |
| 308 | + metadata_response = await client.get( |
| 309 | + resource_metadata_url, headers=headers, follow_redirects=True |
| 310 | + ) |
| 311 | + |
| 312 | + if metadata_response.status_code != 200: |
| 313 | + raise RuntimeError( |
| 314 | + f"Failed to fetch resource metadata: HTTP {metadata_response.status_code}" |
| 315 | + ) |
| 316 | + |
| 317 | + resource_metadata = ProtectedResourceMetadata.model_validate_json( |
| 318 | + metadata_response.content |
| 319 | + ) |
| 320 | + |
| 321 | + # Extract authorization server |
| 322 | + if not resource_metadata.authorization_servers: |
| 323 | + raise RuntimeError( |
| 324 | + "No authorization server found in OAuth protected resource metadata" |
| 325 | + ) |
| 326 | + |
| 327 | + authorization_server_url = str(resource_metadata.authorization_servers[0]) |
| 328 | + |
| 329 | + # Extract scope |
| 330 | + if not resource_metadata.scopes_supported: |
| 331 | + logging.warning( |
| 332 | + "No scopes found in OAuth protected resource metadata. Using empty scope." |
| 333 | + ) |
| 334 | + scope = "" |
| 335 | + else: |
| 336 | + scope = " ".join(resource_metadata.scopes_supported) |
| 337 | + |
| 338 | + logging.debug(f"Discovered scope: {scope}") |
| 339 | + logging.debug(f"Discovered authorization server: {authorization_server_url}") |
| 340 | + |
| 341 | + return scope, authorization_server_url |
0 commit comments