66authorization specification.
77"""
88
9+ import base64
10+ import hashlib
911import json
1012import logging
1113from datetime import datetime , timedelta
1214from typing import Any , Protocol
13- from urllib .parse import urlparse
15+ from urllib .parse import urlencode , urlparse
1416
1517import httpx
1618from pydantic import AnyHttpUrl , BaseModel , ConfigDict , Field
@@ -373,7 +375,49 @@ class OAuthClientProvider(Protocol):
373375 @property
374376 def client_metadata (self ) -> ClientMetadata : ...
375377
376- def save_client_information (self , metadata : DynamicClientRegistration ) -> None : ...
378+ @property
379+ def redirect_url (self ) -> AnyHttpUrl : ...
380+
381+ async def open_user_agent (self , url : AnyHttpUrl ) -> None :
382+ """
383+ Opens the user agent to the given URL.
384+ """
385+ ...
386+
387+ async def client_registration (
388+ self , endpoint : AnyHttpUrl
389+ ) -> DynamicClientRegistration | None :
390+ """
391+ Loads the client registration for the given endpoint.
392+ """
393+ ...
394+
395+ async def store_client_registration (
396+ self , endpoint : AnyHttpUrl , metadata : DynamicClientRegistration
397+ ) -> None :
398+ """
399+ Stores the client registration to be retreived for the next session
400+ """
401+ ...
402+
403+ def code_verifier (self ) -> str :
404+ """
405+ Loads the PKCE code verifier for the current session.
406+ See https://www.rfc-editor.org/rfc/rfc7636.html#section-4.1
407+ """
408+ ...
409+
410+ async def token (self ) -> AccessToken | None :
411+ """
412+ Loads the token for the current session.
413+ """
414+ ...
415+
416+ async def store_token (self , token : AccessToken ) -> None :
417+ """
418+ Stores the token to be retreived for the next session
419+ """
420+ ...
377421
378422
379423class NotFoundError (Exception ):
@@ -388,29 +432,64 @@ class RegistrationFailedError(Exception):
388432 pass
389433
390434
435+ class GrantNotSupported (Exception ):
436+ """Exception raised when a grant type is not supported."""
437+
438+ pass
439+
440+
391441class OAuthClient :
392442 WELL_KNOWN = "/.well-known/oauth-authorization-server"
393-
394- def __init__ (self , server_url : AnyHttpUrl , provider : OAuthClientProvider ):
443+ GRANT_TYPE : str = "authorization_code"
444+
445+ def __init__ (
446+ self ,
447+ server_url : AnyHttpUrl ,
448+ provider : OAuthClientProvider ,
449+ scope : str | None = None ,
450+ ):
395451 self .server_url = server_url
396452 self .http_client = httpx .AsyncClient ()
397453 self .provider = provider
398- self ._registration : DynamicClientRegistration | None = None
454+ self .scope = scope
399455
400- async def auth (self ):
401- metadata = await self .discover_auth_metadata () or self ._default_metadata ()
456+ @property
457+ def discovery_url (self ) -> AnyHttpUrl :
458+ base_url = str (self .server_url ).rstrip ("/" )
459+ parsed_url = urlparse (base_url )
460+ # HTTPS is required by RFC 8414
461+ discovery_url = f"https://{ parsed_url .netloc } { self .WELL_KNOWN } "
462+ return AnyHttpUrl (discovery_url )
463+
464+ async def _obtain_client (
465+ self , metadata : ServerMetadataDiscovery
466+ ) -> DynamicClientRegistration :
467+ """
468+ Obtain a client by either reading it from the OAuthProvider or registering it.
469+ """
402470 if metadata .registration_endpoint is None :
403471 raise NotFoundError ("Registration endpoint not found" )
404- self . _registration = await self . dynamic_client_registration (
405- self .provider .client_metadata , metadata .registration_endpoint
406- )
407- if self . _registration is None :
408- raise RegistrationFailedError (
409- f"Registration at { metadata .registration_endpoint } failed"
472+
473+ if registration := await self .provider .client_registration ( metadata .issuer ):
474+ return registration
475+ else :
476+ registration = await self . dynamic_client_registration (
477+ self . provider . client_metadata , metadata .registration_endpoint
410478 )
411- self .provider .save_client_information (self ._registration )
479+ if registration is None :
480+ raise RegistrationFailedError (
481+ f"Registration at { metadata .registration_endpoint } failed"
482+ )
412483
413- def _default_metadata (self ) -> ServerMetadataDiscovery :
484+ await self .provider .store_client_registration (metadata .issuer , registration )
485+ return registration
486+
487+ def default_metadata (self ) -> ServerMetadataDiscovery :
488+ """
489+ Returns default endpoints as specified in
490+ https://spec.modelcontextprotocol.io/specification/draft/basic/authorization/
491+ for the server.
492+ """
414493 base_url = AnyHttpUrl (str (self .server_url ).rstrip ("/" ))
415494 return ServerMetadataDiscovery (
416495 issuer = base_url ,
@@ -423,10 +502,11 @@ def _default_metadata(self) -> ServerMetadataDiscovery:
423502 )
424503
425504 async def discover_auth_metadata (self ) -> ServerMetadataDiscovery | None :
426- discovery_url = self ._build_discovery_url ()
427-
505+ """
506+ Use RFC 8414 to discover the authorization server metadata.
507+ """
428508 try :
429- response = await self .http_client .get (str (discovery_url ))
509+ response = await self .http_client .get (str (self . discovery_url ))
430510 if response .status_code == 404 :
431511 return None
432512 response .raise_for_status ()
@@ -439,31 +519,12 @@ async def discover_auth_metadata(self) -> ServerMetadataDiscovery | None:
439519 logger .error (f"Error during auth metadata discovery: { e } " )
440520 raise
441521
442- def _build_discovery_url (self ) -> AnyHttpUrl :
443- base_url = str (self .server_url ).rstrip ("/" )
444- parsed_url = urlparse (base_url )
445- # HTTPS is required by RFC 8414
446- discovery_url = f"https://{ parsed_url .netloc } { self .WELL_KNOWN } "
447- return AnyHttpUrl (discovery_url )
448-
449522 async def dynamic_client_registration (
450523 self , client_metadata : ClientMetadata , registration_endpoint : AnyHttpUrl
451524 ) -> DynamicClientRegistration | None :
452525 """
453526 Register a client dynamically with an OAuth 2.0 authorization server
454527 following RFC 7591.
455-
456- Args:
457- client_metadata: Typed client registration metadata
458- registration_endpoint: Where to register clients.
459- If None, will use discovery
460-
461- Returns:
462- DynamicClientRegistrationResponse if successful, None otherwise
463-
464- Raises:
465- httpx.HTTPStatusError: If the server returns an error status code
466- Exception: For other errors during registration
467528 """
468529 headers = {"Content-Type" : "application/json" , "Accept" : "application/json" }
469530
@@ -493,3 +554,145 @@ async def dynamic_client_registration(
493554 logger .error (f"Unexpected error during registration: { e } " )
494555
495556 return None
557+
558+ async def exchange_authorization (
559+ self ,
560+ metadata : ServerMetadataDiscovery ,
561+ registration : DynamicClientRegistration ,
562+ code_verifier : str ,
563+ authorization_code : str ,
564+ ) -> AccessToken :
565+ """Exchange an authorization code for an access token using OAuth 2.1 with PKCE.
566+
567+ Args:
568+ registration: The client registration information
569+ code_verifier: The PKCE code verifier used to generate the code challenge
570+ authorization_code: The authorization code received from the authorization
571+ server
572+
573+ Returns:
574+ AccessToken: The resulting access token
575+
576+ Raises:
577+ GrantNotSupported: If the grant type is not supported
578+ httpx.HTTPStatusError: If the token endpoint request fails
579+ """
580+ if self .GRANT_TYPE not in (registration .grant_types or []):
581+ raise GrantNotSupported (f"Grant type { self .GRANT_TYPE } not supported" )
582+
583+ code_verifier = self .provider .code_verifier ()
584+ # Get token endpoint from server metadata or use default
585+ token_endpoint = str (metadata .token_endpoint )
586+
587+ # Prepare token request parameters
588+ data = {
589+ "grant_type" : self .GRANT_TYPE ,
590+ "code" : authorization_code ,
591+ "redirect_uri" : str (self .provider .redirect_url ),
592+ "client_id" : registration .client_id ,
593+ "code_verifier" : code_verifier ,
594+ }
595+
596+ # Add client secret if available (optional in OAuth 2.1)
597+ if registration .client_secret :
598+ data ["client_secret" ] = registration .client_secret
599+
600+ headers = {
601+ "Content-Type" : "application/x-www-form-urlencoded" ,
602+ "Accept" : "application/json" ,
603+ }
604+
605+ try :
606+ response = await self .http_client .post (
607+ token_endpoint , data = data , headers = headers
608+ )
609+ response .raise_for_status ()
610+ token_data = response .json ()
611+
612+ # Create and return the token
613+ return AccessToken (** token_data )
614+
615+ except httpx .HTTPStatusError as e :
616+ logger .error (f"HTTP error during token exchange: { e .response .status_code } " )
617+ if e .response .content :
618+ try :
619+ error_data = json .loads (e .response .content )
620+ logger .error (f"Error details: { error_data } " )
621+ except json .JSONDecodeError :
622+ logger .error (f"Error content: { e .response .content } " )
623+ raise
624+ except Exception as e :
625+ logger .error (f"Unexpected error during token exchange: { e } " )
626+ raise
627+
628+ async def auth (self , authorization_code : str , code_verifier : str ) -> AccessToken :
629+ """
630+ Complete the OAuth 2.1 authorization flow by exchanging authorization code
631+ for tokens.
632+
633+ Args:
634+ authorization_code: The authorization code received from the authorization
635+ server
636+ code_verifier: The PKCE code verifier used to generate the code challenge
637+
638+ Returns:
639+ AccessToken: The resulting access token
640+ """
641+ metadata = await self .discover_auth_metadata () or self .default_metadata ()
642+ registration = await self ._obtain_client (metadata )
643+
644+ code_verifier = self .provider .code_verifier ()
645+
646+ authorization_url = self .get_authorization_url (
647+ metadata .authorization_endpoint ,
648+ self .provider .redirect_url ,
649+ registration .client_id ,
650+ code_verifier ,
651+ self .scope ,
652+ )
653+
654+ await self .provider .open_user_agent (AnyHttpUrl (authorization_url ))
655+
656+ return await self .exchange_authorization (
657+ metadata , registration , code_verifier , authorization_code
658+ )
659+
660+ def get_authorization_url (
661+ self ,
662+ authorization_endpoint : AnyHttpUrl ,
663+ redirect_uri : AnyHttpUrl ,
664+ client_id : str ,
665+ code_verifier : str ,
666+ scope : str | None = None ,
667+ ) -> AnyHttpUrl :
668+ """Generate an OAuth 2.1 authorization URL for the user agent.
669+
670+ This method generates a URL that the user agent (browser) should visit to
671+ authenticate the user and authorize the application. It includes PKCE
672+ (Proof Key for Code Exchange) for enhanced security as required by OAuth 2.1.
673+ """
674+ # Create a custom verifier for this authorization request
675+ code_verifier = self .provider .code_verifier ()
676+
677+ # Generate code challenge from verifier using SHA-256
678+ code_challenge = (
679+ base64 .urlsafe_b64encode (hashlib .sha256 (code_verifier .encode ()).digest ())
680+ .decode ()
681+ .rstrip ("=" )
682+ )
683+
684+ # Build authorization URL with necessary parameters
685+ params = {
686+ "response_type" : "code" ,
687+ "client_id" : client_id ,
688+ "redirect_uri" : str (redirect_uri ),
689+ "code_challenge" : code_challenge ,
690+ "code_challenge_method" : "S256" ,
691+ }
692+
693+ # Add scope if provided or use the one from registration
694+ if scope :
695+ params ["scope" ] = scope
696+
697+ # Construct the full authorization URL
698+ return AnyHttpUrl (f"{ authorization_endpoint } ?{ urlencode (params )} " )
0 commit comments