44import logging
55import urllib .request
66from dataclasses import dataclass , field
7- from typing import Annotated , Any , Callable , Optional , Sequence
7+ from typing import Annotated , Optional , Sequence
88
99import jwt
1010from fastapi import HTTPException , Security , security , status
@@ -25,8 +25,6 @@ class OpenIdConnectAuth:
2525 # Generated attributes
2626 auth_scheme : SecurityBase = field (init = False )
2727 jwks_client : jwt .PyJWKClient = field (init = False )
28- validated_user : Callable [..., Any ] = field (init = False )
29- maybe_validated_user : Callable [..., Any ] = field (init = False )
3028
3129 def __post_init__ (self ):
3230 """Initialize the OIDC authentication class."""
@@ -50,70 +48,80 @@ def __post_init__(self):
5048 openIdConnectUrl = str (self .openid_configuration_url ),
5149 auto_error = False ,
5250 )
53- self .validated_user = self ._build (auto_error = True )
54- self .maybe_validated_user = self ._build (auto_error = False )
55-
56- def _build (self , auto_error : bool = True ):
57- """Build a dependency for validating an OIDC token."""
58-
59- def valid_token_dependency (
60- auth_header : Annotated [str , Security (self .auth_scheme )],
61- required_scopes : security .SecurityScopes ,
62- ):
63- """Dependency to validate an OIDC token."""
64- if not auth_header :
51+
52+ # Update annotations to support FastAPI's dependency injection
53+ for endpoint in [self .validated_user , self .maybe_validated_user ]:
54+ endpoint .__annotations__ ["auth_header" ] = Annotated [
55+ str ,
56+ Security (self .auth_scheme ),
57+ ]
58+
59+ def maybe_validated_user (
60+ self ,
61+ auth_header : Annotated [str , Security (...)],
62+ required_scopes : security .SecurityScopes ,
63+ ):
64+ """Dependency to validate an OIDC token."""
65+ return self .validated_user (auth_header , required_scopes , auto_error = False )
66+
67+ def validated_user (
68+ self ,
69+ auth_header : Annotated [str , Security (...)],
70+ required_scopes : security .SecurityScopes ,
71+ auto_error : bool = True ,
72+ ):
73+ """Dependency to validate an OIDC token."""
74+ if not auth_header :
75+ if auto_error :
76+ raise HTTPException (
77+ status_code = status .HTTP_403_FORBIDDEN ,
78+ detail = "Not authenticated" ,
79+ )
80+ return None
81+
82+ # Extract token from header
83+ token_parts = auth_header .split (" " )
84+ if len (token_parts ) != 2 or token_parts [0 ].lower () != "bearer" :
85+ logger .error (f"Invalid token: { auth_header } " )
86+ raise HTTPException (
87+ status_code = status .HTTP_401_UNAUTHORIZED ,
88+ detail = "Could not validate credentials" ,
89+ headers = {"WWW-Authenticate" : "Bearer" },
90+ )
91+ [_ , token ] = token_parts
92+
93+ # Parse & validate token
94+ try :
95+ key = self .jwks_client .get_signing_key_from_jwt (token ).key
96+ payload = jwt .decode (
97+ token ,
98+ key ,
99+ algorithms = ["RS256" ],
100+ # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
101+ audience = self .allowed_jwt_audiences ,
102+ )
103+ except (jwt .exceptions .InvalidTokenError , jwt .exceptions .DecodeError ) as e :
104+ logger .exception (f"InvalidTokenError: { e = } " )
105+ raise HTTPException (
106+ status_code = status .HTTP_401_UNAUTHORIZED ,
107+ detail = "Could not validate credentials" ,
108+ headers = {"WWW-Authenticate" : "Bearer" },
109+ ) from e
110+
111+ # Validate scopes (if required)
112+ for scope in required_scopes .scopes :
113+ if scope not in payload ["scope" ]:
65114 if auto_error :
66115 raise HTTPException (
67- status_code = status .HTTP_403_FORBIDDEN ,
68- detail = "Not authenticated" ,
116+ status_code = status .HTTP_401_UNAUTHORIZED ,
117+ detail = "Not enough permissions" ,
118+ headers = {
119+ "WWW-Authenticate" : f'Bearer scope="{ required_scopes .scope_str } "'
120+ },
69121 )
70122 return None
71123
72- # Extract token from header
73- token_parts = auth_header .split (" " )
74- if len (token_parts ) != 2 or token_parts [0 ].lower () != "bearer" :
75- logger .error (f"Invalid token: { auth_header } " )
76- raise HTTPException (
77- status_code = status .HTTP_401_UNAUTHORIZED ,
78- detail = "Could not validate credentials" ,
79- headers = {"WWW-Authenticate" : "Bearer" },
80- )
81- [_ , token ] = token_parts
82-
83- # Parse & validate token
84- try :
85- key = self .jwks_client .get_signing_key_from_jwt (token ).key
86- payload = jwt .decode (
87- token ,
88- key ,
89- algorithms = ["RS256" ],
90- # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
91- audience = self .allowed_jwt_audiences ,
92- )
93- except (jwt .exceptions .InvalidTokenError , jwt .exceptions .DecodeError ) as e :
94- logger .exception (f"InvalidTokenError: { e = } " )
95- raise HTTPException (
96- status_code = status .HTTP_401_UNAUTHORIZED ,
97- detail = "Could not validate credentials" ,
98- headers = {"WWW-Authenticate" : "Bearer" },
99- ) from e
100-
101- # Validate scopes (if required)
102- for scope in required_scopes .scopes :
103- if scope not in payload ["scope" ]:
104- if auto_error :
105- raise HTTPException (
106- status_code = status .HTTP_401_UNAUTHORIZED ,
107- detail = "Not enough permissions" ,
108- headers = {
109- "WWW-Authenticate" : f'Bearer scope="{ required_scopes .scope_str } "'
110- },
111- )
112- return None
113-
114- return payload
115-
116- return valid_token_dependency
124+ return payload
117125
118126
119127class OidcFetchError (Exception ):
0 commit comments