@@ -48,6 +48,44 @@ async def set_client_info(self, client_info: OAuthClientInformationFull) -> None
4848 ...
4949
5050
51+ def _get_authorization_base_url (server_url : str ) -> str :
52+ """Return the authorization base URL for ``server_url``.
53+
54+ Per MCP spec 2.3.2, the path component must be discarded so that
55+ ``https://api.example.com/v1/mcp`` becomes ``https://api.example.com``.
56+ """
57+ from urllib .parse import urlparse , urlunparse
58+
59+ parsed = urlparse (server_url )
60+ return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
61+
62+
63+ async def _discover_oauth_metadata (server_url : str ) -> OAuthMetadata | None :
64+ """Discover OAuth metadata from the server's well-known endpoint."""
65+
66+ auth_base_url = _get_authorization_base_url (server_url )
67+ url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
68+ headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
69+
70+ async with httpx .AsyncClient () as client :
71+ try :
72+ response = await client .get (url , headers = headers )
73+ if response .status_code == 404 :
74+ return None
75+ response .raise_for_status ()
76+ return OAuthMetadata .model_validate (response .json ())
77+ except Exception :
78+ try :
79+ response = await client .get (url )
80+ if response .status_code == 404 :
81+ return None
82+ response .raise_for_status ()
83+ return OAuthMetadata .model_validate (response .json ())
84+ except Exception :
85+ logger .exception ("Failed to discover OAuth metadata" )
86+ return None
87+
88+
5189class OAuthClientProvider (httpx .Auth ):
5290 """
5391 Authentication for httpx using anyio.
@@ -110,52 +148,6 @@ def _generate_code_challenge(self, code_verifier: str) -> str:
110148 digest = hashlib .sha256 (code_verifier .encode ()).digest ()
111149 return base64 .urlsafe_b64encode (digest ).decode ().rstrip ("=" )
112150
113- def _get_authorization_base_url (self , server_url : str ) -> str :
114- """
115- Extract base URL by removing path component.
116-
117- Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
118- """
119- from urllib .parse import urlparse , urlunparse
120-
121- parsed = urlparse (server_url )
122- # Remove path component
123- return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
124-
125- async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
126- """
127- Discover OAuth metadata from server's well-known endpoint.
128- """
129- # Extract base URL per MCP spec
130- auth_base_url = self ._get_authorization_base_url (server_url )
131- url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
132- headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
133-
134- async with httpx .AsyncClient () as client :
135- try :
136- response = await client .get (url , headers = headers )
137- if response .status_code == 404 :
138- return None
139- response .raise_for_status ()
140- metadata_json = response .json ()
141- logger .debug (f"OAuth metadata discovered: { metadata_json } " )
142- return OAuthMetadata .model_validate (metadata_json )
143- except Exception :
144- # Retry without MCP header for CORS compatibility
145- try :
146- response = await client .get (url )
147- if response .status_code == 404 :
148- return None
149- response .raise_for_status ()
150- metadata_json = response .json ()
151- logger .debug (
152- f"OAuth metadata discovered (no MCP header): { metadata_json } "
153- )
154- return OAuthMetadata .model_validate (metadata_json )
155- except Exception :
156- logger .exception ("Failed to discover OAuth metadata" )
157- return None
158-
159151 async def _register_oauth_client (
160152 self ,
161153 server_url : str ,
@@ -166,13 +158,13 @@ async def _register_oauth_client(
166158 Register OAuth client with server.
167159 """
168160 if not metadata :
169- metadata = await self . _discover_oauth_metadata (server_url )
161+ metadata = await _discover_oauth_metadata (server_url )
170162
171163 if metadata and metadata .registration_endpoint :
172164 registration_url = str (metadata .registration_endpoint )
173165 else :
174166 # Use fallback registration endpoint
175- auth_base_url = self . _get_authorization_base_url (server_url )
167+ auth_base_url = _get_authorization_base_url (server_url )
176168 registration_url = urljoin (auth_base_url , "/register" )
177169
178170 # Handle default scope
@@ -321,7 +313,7 @@ async def _perform_oauth_flow(self) -> None:
321313
322314 # Discover OAuth metadata
323315 if not self ._metadata :
324- self ._metadata = await self . _discover_oauth_metadata (self .server_url )
316+ self ._metadata = await _discover_oauth_metadata (self .server_url )
325317
326318 # Ensure client registration
327319 client_info = await self ._get_or_register_client ()
@@ -335,7 +327,7 @@ async def _perform_oauth_flow(self) -> None:
335327 auth_url_base = str (self ._metadata .authorization_endpoint )
336328 else :
337329 # Use fallback authorization endpoint
338- auth_base_url = self . _get_authorization_base_url (self .server_url )
330+ auth_base_url = _get_authorization_base_url (self .server_url )
339331 auth_url_base = urljoin (auth_base_url , "/authorize" )
340332
341333 # Build authorization URL
@@ -386,7 +378,7 @@ async def _exchange_code_for_token(
386378 token_url = str (self ._metadata .token_endpoint )
387379 else :
388380 # Use fallback token endpoint
389- auth_base_url = self . _get_authorization_base_url (self .server_url )
381+ auth_base_url = _get_authorization_base_url (self .server_url )
390382 token_url = urljoin (auth_base_url , "/token" )
391383
392384 token_data = {
@@ -453,7 +445,7 @@ async def _refresh_access_token(self) -> bool:
453445 token_url = str (self ._metadata .token_endpoint )
454446 else :
455447 # Use fallback token endpoint
456- auth_base_url = self . _get_authorization_base_url (self .server_url )
448+ auth_base_url = _get_authorization_base_url (self .server_url )
457449 token_url = urljoin (auth_base_url , "/token" )
458450
459451 refresh_data = {
@@ -523,48 +515,19 @@ def __init__(
523515
524516 self ._token_lock = anyio .Lock ()
525517
526- def _get_authorization_base_url (self , server_url : str ) -> str :
527- from urllib .parse import urlparse , urlunparse
528-
529- parsed = urlparse (server_url )
530- return urlunparse ((parsed .scheme , parsed .netloc , "" , "" , "" , "" ))
531-
532- async def _discover_oauth_metadata (self , server_url : str ) -> OAuthMetadata | None :
533- auth_base_url = self ._get_authorization_base_url (server_url )
534- url = urljoin (auth_base_url , "/.well-known/oauth-authorization-server" )
535- headers = {"MCP-Protocol-Version" : LATEST_PROTOCOL_VERSION }
536-
537- async with httpx .AsyncClient () as client :
538- try :
539- response = await client .get (url , headers = headers )
540- if response .status_code == 404 :
541- return None
542- response .raise_for_status ()
543- return OAuthMetadata .model_validate (response .json ())
544- except Exception :
545- try :
546- response = await client .get (url )
547- if response .status_code == 404 :
548- return None
549- response .raise_for_status ()
550- return OAuthMetadata .model_validate (response .json ())
551- except Exception :
552- logger .exception ("Failed to discover OAuth metadata" )
553- return None
554-
555518 async def _register_oauth_client (
556519 self ,
557520 server_url : str ,
558521 client_metadata : OAuthClientMetadata ,
559522 metadata : OAuthMetadata | None = None ,
560523 ) -> OAuthClientInformationFull :
561524 if not metadata :
562- metadata = await self . _discover_oauth_metadata (server_url )
525+ metadata = await _discover_oauth_metadata (server_url )
563526
564527 if metadata and metadata .registration_endpoint :
565528 registration_url = str (metadata .registration_endpoint )
566529 else :
567- auth_base_url = self . _get_authorization_base_url (server_url )
530+ auth_base_url = _get_authorization_base_url (server_url )
568531 registration_url = urljoin (auth_base_url , "/register" )
569532
570533 if (
@@ -636,14 +599,14 @@ async def _get_or_register_client(self) -> OAuthClientInformationFull:
636599
637600 async def _request_token (self ) -> None :
638601 if not self ._metadata :
639- self ._metadata = await self . _discover_oauth_metadata (self .server_url )
602+ self ._metadata = await _discover_oauth_metadata (self .server_url )
640603
641604 client_info = await self ._get_or_register_client ()
642605
643606 if self ._metadata and self ._metadata .token_endpoint :
644607 token_url = str (self ._metadata .token_endpoint )
645608 else :
646- auth_base_url = self . _get_authorization_base_url (self .server_url )
609+ auth_base_url = _get_authorization_base_url (self .server_url )
647610 token_url = urljoin (auth_base_url , "/token" )
648611
649612 token_data = {
0 commit comments