1313from collections .abc import AsyncGenerator , Awaitable , Callable
1414from dataclasses import dataclass , field
1515from typing import Any , Protocol
16- from urllib .parse import urlencode , urljoin , urlparse
16+ from urllib .parse import quote , urlencode , urljoin , urlparse
1717
1818import anyio
1919import httpx
2020from pydantic import BaseModel , Field , ValidationError
2121
22- from mcp .client .auth import OAuthFlowError , OAuthTokenError
22+ from mcp .client .auth . exceptions import OAuthFlowError , OAuthRegistrationError , OAuthTokenError
2323from mcp .client .auth .utils import (
2424 build_oauth_authorization_server_metadata_discovery_urls ,
2525 build_protected_resource_metadata_discovery_urls ,
@@ -173,6 +173,42 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
173173 # Version format is YYYY-MM-DD, so string comparison works
174174 return protocol_version >= "2025-06-18"
175175
176+ def prepare_token_auth (
177+ self , data : dict [str , str ], headers : dict [str , str ] | None = None
178+ ) -> tuple [dict [str , str ], dict [str , str ]]:
179+ """Prepare authentication for token requests.
180+
181+ Args:
182+ data: The form data to send
183+ headers: Optional headers dict to update
184+
185+ Returns:
186+ Tuple of (updated_data, updated_headers)
187+ """
188+ if headers is None :
189+ headers = {} # pragma: no cover
190+
191+ if not self .client_info :
192+ return data , headers # pragma: no cover
193+
194+ auth_method = self .client_info .token_endpoint_auth_method
195+
196+ if auth_method == "client_secret_basic" and self .client_info .client_id and self .client_info .client_secret :
197+ # URL-encode client ID and secret per RFC 6749 Section 2.3.1
198+ encoded_id = quote (self .client_info .client_id , safe = "" )
199+ encoded_secret = quote (self .client_info .client_secret , safe = "" )
200+ credentials = f"{ encoded_id } :{ encoded_secret } "
201+ encoded_credentials = base64 .b64encode (credentials .encode ()).decode ()
202+ headers ["Authorization" ] = f"Basic { encoded_credentials } "
203+ # Don't include client_secret in body for basic auth
204+ data = {k : v for k , v in data .items () if k != "client_secret" }
205+ elif auth_method == "client_secret_post" and self .client_info .client_secret :
206+ # Include client_secret in request body
207+ data ["client_secret" ] = self .client_info .client_secret
208+ # For auth_method == "none", don't add any client_secret
209+
210+ return data , headers
211+
176212
177213class OAuthClientProvider (httpx .Auth ):
178214 """
@@ -247,6 +283,27 @@ async def _register_client(self) -> httpx.Request | None:
247283
248284 registration_data = self .context .client_metadata .model_dump (by_alias = True , mode = "json" , exclude_none = True )
249285
286+ # If token_endpoint_auth_method is None, auto-select based on server support
287+ if self .context .client_metadata .token_endpoint_auth_method is None :
288+ preference_order = ["client_secret_basic" , "client_secret_post" , "none" ]
289+
290+ if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint_auth_methods_supported :
291+ supported = self .context .oauth_metadata .token_endpoint_auth_methods_supported
292+ for method in preference_order :
293+ if method in supported :
294+ registration_data ["token_endpoint_auth_method" ] = method
295+ break
296+ else :
297+ # No compatible methods between client and server
298+ raise OAuthRegistrationError (
299+ f"No compatible authentication methods. "
300+ f"Server supports: { supported } , "
301+ f"Client supports: { preference_order } "
302+ )
303+ else :
304+ # No server metadata available, use our default preference
305+ registration_data ["token_endpoint_auth_method" ] = preference_order [0 ]
306+
250307 return httpx .Request (
251308 "POST" , registration_url , json = registration_data , headers = {"Content-Type" : "application/json" }
252309 )
@@ -343,12 +400,11 @@ async def _exchange_token_authorization_code(
343400 if self .context .should_include_resource_param (self .context .protocol_version ):
344401 token_data ["resource" ] = self .context .get_resource_url () # RFC 8707
345402
346- if self .context .client_info .client_secret :
347- token_data ["client_secret" ] = self .context .client_info .client_secret
403+ # Prepare authentication based on preferred method
404+ headers = {"Content-Type" : "application/x-www-form-urlencoded" }
405+ token_data , headers = self .context .prepare_token_auth (token_data , headers )
348406
349- return httpx .Request (
350- "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
351- )
407+ return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
352408
353409 async def _handle_token_response (self , response : httpx .Response ) -> None :
354410 """Handle token exchange response."""
@@ -370,7 +426,7 @@ async def _refresh_token(self) -> httpx.Request:
370426 if not self .context .current_tokens or not self .context .current_tokens .refresh_token :
371427 raise OAuthTokenError ("No refresh token available" ) # pragma: no cover
372428
373- if not self .context .client_info :
429+ if not self .context .client_info or not self . context . client_info . client_id :
374430 raise OAuthTokenError ("No client info available" ) # pragma: no cover
375431
376432 if self .context .oauth_metadata and self .context .oauth_metadata .token_endpoint :
@@ -379,7 +435,7 @@ async def _refresh_token(self) -> httpx.Request:
379435 auth_base_url = self .context .get_authorization_base_url (self .context .server_url )
380436 token_url = urljoin (auth_base_url , "/token" )
381437
382- refresh_data = {
438+ refresh_data : dict [ str , str ] = {
383439 "grant_type" : "refresh_token" ,
384440 "refresh_token" : self .context .current_tokens .refresh_token ,
385441 "client_id" : self .context .client_info .client_id ,
@@ -389,12 +445,11 @@ async def _refresh_token(self) -> httpx.Request:
389445 if self .context .should_include_resource_param (self .context .protocol_version ):
390446 refresh_data ["resource" ] = self .context .get_resource_url () # RFC 8707
391447
392- if self .context .client_info .client_secret : # pragma: no branch
393- refresh_data ["client_secret" ] = self .context .client_info .client_secret
448+ # Prepare authentication based on preferred method
449+ headers = {"Content-Type" : "application/x-www-form-urlencoded" }
450+ refresh_data , headers = self .context .prepare_token_auth (refresh_data , headers )
394451
395- return httpx .Request (
396- "POST" , token_url , data = refresh_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
397- )
452+ return httpx .Request ("POST" , token_url , data = refresh_data , headers = headers )
398453
399454 async def _handle_refresh_response (self , response : httpx .Response ) -> bool : # pragma: no cover
400455 """Handle token refresh response. Returns True if successful."""
0 commit comments