1616import httpx
1717from botocore .exceptions import ClientError
1818
19- from mcp .client .auth import OAuthClientProvider , TokenStorage
19+ from mcp .client .auth import OAuthClientProvider , TokenStorage , OAuthContext
2020from mcp .client .session import ClientSession
2121from mcp .client .streamable_http import streamablehttp_client , MCP_PROTOCOL_VERSION
2222from mcp .shared .auth import (
@@ -131,20 +131,14 @@ async def perform_client_credentials_flow(self) -> None:
131131
132132 logging .debug ("Performing client credentials flow..." )
133133
134- # Discover OAuth metadata
135- async with httpx .AsyncClient () as client :
136- base_url = self .authorization_server_url .rstrip ("/" )
137- metadata_url = f"{ base_url } /.well-known/oauth-authorization-server"
138- response = await client .get (metadata_url , follow_redirects = True )
139-
140- if response .status_code != 200 :
141- raise RuntimeError (
142- f"Failed to discover OAuth metadata: HTTP { response .status_code } "
143- )
144-
145- metadata = OAuthMetadata .model_validate_json (response .content )
134+ # Set auth server URL and discover OAuth metadata using upstream logic
135+ self .context .auth_server_url = self .authorization_server_url
136+ await self ._discover_oauth_metadata ()
146137
147- if not metadata .token_endpoint :
138+ if (
139+ not self .context .oauth_metadata
140+ or not self .context .oauth_metadata .token_endpoint
141+ ):
148142 raise RuntimeError ("No token endpoint found in OAuth metadata" )
149143
150144 # Create client info and store it
@@ -168,11 +162,13 @@ async def perform_client_credentials_flow(self) -> None:
168162 if self .context .client_metadata .scope :
169163 token_data ["scope" ] = self .context .client_metadata .scope
170164
171- logging .debug (f"Making token request to: { metadata .token_endpoint } " )
165+ logging .debug (
166+ f"Making token request to: { self .context .oauth_metadata .token_endpoint } "
167+ )
172168
173169 async with httpx .AsyncClient () as client :
174170 response = await client .post (
175- str (metadata .token_endpoint ),
171+ str (self . context . oauth_metadata .token_endpoint ),
176172 data = token_data ,
177173 headers = {"Content-Type" : "application/x-www-form-urlencoded" },
178174 )
@@ -206,6 +202,32 @@ async def perform_client_credentials_flow(self) -> None:
206202 logging .error (f"Client credentials flow failed: { error } " )
207203 raise
208204
205+ async def _discover_oauth_metadata (self ) -> None :
206+ """Discover OAuth metadata using upstream MCP SDK discovery logic."""
207+ # Use the inherited _get_discovery_urls method from parent class
208+ discovery_urls = self ._get_discovery_urls ()
209+
210+ async with httpx .AsyncClient () as client :
211+ for metadata_url in discovery_urls :
212+ try :
213+ response = await client .get (metadata_url , follow_redirects = True )
214+ if response .status_code == 200 :
215+ metadata = OAuthMetadata .model_validate_json (response .content )
216+ self .context .oauth_metadata = metadata
217+ logging .debug (
218+ f"Successfully discovered OAuth metadata from: { metadata_url } "
219+ )
220+ return
221+ except Exception as e :
222+ logging .debug (
223+ f"Failed to discover OAuth metadata from { metadata_url } : { e } "
224+ )
225+ continue
226+
227+ raise RuntimeError (
228+ "Failed to discover OAuth metadata from any well-known endpoint"
229+ )
230+
209231
210232class AutomatedOAuthClient (Server ):
211233 """
@@ -359,16 +381,18 @@ async def _discover_scope_and_auth_server(self, server_url: str) -> tuple[str, s
359381
360382 # Extract resource metadata URL from WWW-Authenticate header
361383 www_auth_header = response .headers .get ("WWW-Authenticate" )
362- if not www_auth_header :
363- raise RuntimeError ("No WWW-Authenticate header found in response" )
384+ resource_metadata_url = None
364385
365- # Simple extraction of resource_metadata URL
366- import re
386+ if www_auth_header :
387+ # Simple extraction of resource_metadata URL
388+ import re
367389
368- pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
369- match = re .search (pattern , www_auth_header )
390+ pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
391+ match = re .search (pattern , www_auth_header )
392+ if match :
393+ resource_metadata_url = match .group (1 ) or match .group (2 )
370394
371- if not match :
395+ if not resource_metadata_url :
372396 # Fallback to well-known discovery
373397 from urllib .parse import urlparse , urljoin
374398
@@ -377,8 +401,6 @@ async def _discover_scope_and_auth_server(self, server_url: str) -> tuple[str, s
377401 resource_metadata_url = urljoin (
378402 base_url , "/.well-known/oauth-protected-resource"
379403 )
380- else :
381- resource_metadata_url = match .group (1 ) or match .group (2 )
382404
383405 logging .debug (f"Discovered resource metadata URL: { resource_metadata_url } " )
384406
0 commit comments