diff --git a/fastapi_third_party_auth/auth.py b/fastapi_third_party_auth/auth.py index 850211d..5dd792d 100644 --- a/fastapi_third_party_auth/auth.py +++ b/fastapi_third_party_auth/auth.py @@ -15,6 +15,7 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)): return f"Hello {authenticated_user.preferred_username}" """ +from logging import getLogger from typing import List from typing import Optional from typing import Type @@ -33,14 +34,16 @@ def test_auth(authenticated_user: IDToken = Security(auth.required)): from fastapi.security import OAuth2 from fastapi.security import SecurityScopes from jose import ExpiredSignatureError -from jose import JWTError from jose import jwt -from jose.exceptions import JWTClaimsError +from jose.exceptions import JWTClaimsError, JWKError, JWTError, JWSError +from requests.exceptions import ConnectionError from fastapi_third_party_auth import discovery from fastapi_third_party_auth.grant_types import GrantType from fastapi_third_party_auth.idtoken_types import IDToken +logger = getLogger(__name__) + class Auth(OAuth2): def __init__( @@ -81,8 +84,19 @@ def __init__( self.client_id = client_id self.idtoken_model = idtoken_model self.scopes = scopes - + self.discover = discovery.configure(cache_ttl=signature_cache_ttl) + self.grant_types = grant_types + + try: + flows = self.get_flows() + except ConnectionError as e: + logger.warning("Could not discover OIDC flows %s", e) + flows = OAuthFlows() + + super().__init__(scheme_name="OIDC", flows=flows, auto_error=False) + + def get_flows(self) -> OAuthFlows: oidc_discoveries = self.discover.auth_server( openid_connect_url=self.openid_connect_url ) @@ -91,36 +105,32 @@ def __init__( # } flows = OAuthFlows() - if GrantType.AUTHORIZATION_CODE in grant_types: + if GrantType.AUTHORIZATION_CODE in self.grant_types: flows.authorizationCode = OAuthFlowAuthorizationCode( authorizationUrl=self.discover.authorization_url(oidc_discoveries), tokenUrl=self.discover.token_url(oidc_discoveries), # scopes=scopes_dict, ) - if GrantType.CLIENT_CREDENTIALS in grant_types: + if GrantType.CLIENT_CREDENTIALS in self.grant_types: flows.clientCredentials = OAuthFlowClientCredentials( tokenUrl=self.discover.token_url(oidc_discoveries), # scopes=scopes_dict, ) - if GrantType.PASSWORD in grant_types: + if GrantType.PASSWORD in self.grant_types: flows.password = OAuthFlowPassword( tokenUrl=self.discover.token_url(oidc_discoveries), # scopes=scopes_dict, ) - if GrantType.IMPLICIT in grant_types: + if GrantType.IMPLICIT in self.grant_types: flows.implicit = OAuthFlowImplicit( authorizationUrl=self.discover.authorization_url(oidc_discoveries), # scopes=scopes_dict, ) - - super().__init__( - scheme_name="OIDC", - flows=flows, - auto_error=False, - ) + + return flows async def __call__(self, request: Request) -> None: return None @@ -189,6 +199,33 @@ def optional( auto_error=False, ) + + def _find_key(self, token: str) -> dict: + oidc_discoveries = self.discover.auth_server( + openid_connect_url=self.openid_connect_url + ) + try: + keys = self.discover.public_keys(oidc_discoveries)["keys"] + except KeyError as e: + raise JWKError("Badly formed JWKs_uri") from e + + header = jwt.get_unverified_header(token) + try: + kid = header['kid'] + except KeyError as e: + raise JWTError("field 'kid' is missing from JWT headers") from e + + for key in keys: + try: + key_kid = key['kid'] + except KeyError as e: + raise JWKError("field 'kid' is missing from JWK") from e + if key_kid == kid: + return key + + raise JWKError(f"Could not find JWK 'kid'={kid}") + + def authenticate_user( self, security_scopes: SecurityScopes, @@ -222,12 +259,16 @@ def authenticate_user( ) else: return None - - oidc_discoveries = self.discover.auth_server( - openid_connect_url=self.openid_connect_url - ) - key = self.discover.public_keys(oidc_discoveries) + + try: + oidc_discoveries = self.discover.auth_server( + openid_connect_url=self.openid_connect_url + ) + except ConnectionError as e: + logger.warning("Could not reach auth server %e", e) + raise HTTPException(503, detail="Could not reach auth server") from e algorithms = self.discover.signing_algos(oidc_discoveries) + key = self._find_key(authorization_credentials.credentials) try: id_token = jwt.decode( @@ -245,7 +286,8 @@ def authenticate_user( ) if ( - type(id_token["aud"]) == list + "aud" in id_token + and type(id_token["aud"]) == list and len(id_token["aud"]) >= 1 and "azp" not in id_token ):