|
| 1 | +import time |
| 2 | +from typing import Annotated, Any, Optional |
| 3 | + |
| 4 | +import jwt |
| 5 | +from fastapi import Depends, HTTPException, Request, status |
| 6 | +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer |
| 7 | + |
| 8 | +from common import AppConfig |
| 9 | +from http_app.dependencies import app_config |
| 10 | + |
| 11 | + |
| 12 | +class MissingAuthorizationServerException(HTTPException): |
| 13 | + def __init__(self, **kwargs): |
| 14 | + super().__init__( |
| 15 | + status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 16 | + detail="Authorization server not available", |
| 17 | + ) |
| 18 | + |
| 19 | + |
| 20 | +class UnauthorizedException(HTTPException): |
| 21 | + def __init__(self, detail: str, **kwargs): |
| 22 | + super().__init__(status.HTTP_403_FORBIDDEN, detail=detail) |
| 23 | + |
| 24 | + |
| 25 | +class UnauthenticatedException(HTTPException): |
| 26 | + def __init__(self): |
| 27 | + super().__init__( |
| 28 | + status_code=status.HTTP_401_UNAUTHORIZED, detail="Requires authentication" |
| 29 | + ) |
| 30 | + |
| 31 | + |
| 32 | +def _jwks_client(config: Annotated[AppConfig, Depends(app_config)]) -> jwt.PyJWKClient: |
| 33 | + if not config.AUTH.JWKS_URL: |
| 34 | + raise MissingAuthorizationServerException() |
| 35 | + return jwt.PyJWKClient(config.AUTH.JWKS_URL) |
| 36 | + |
| 37 | + |
| 38 | +class JWTBearer(HTTPBearer): |
| 39 | + async def __call__( |
| 40 | + self, |
| 41 | + request: Request, |
| 42 | + ) -> Optional[HTTPAuthorizationCredentials]: |
| 43 | + credentials = await super(JWTBearer, self).__call__(request) |
| 44 | + |
| 45 | + await self.decode(request) |
| 46 | + |
| 47 | + return credentials |
| 48 | + |
| 49 | + async def decode( |
| 50 | + self, |
| 51 | + request: Request, |
| 52 | + jwks_client: jwt.PyJWKClient = Depends(_jwks_client), |
| 53 | + config: AppConfig = Depends(app_config), |
| 54 | + ) -> dict[str, Any]: |
| 55 | + credentials = await super(JWTBearer, self).__call__(request) |
| 56 | + |
| 57 | + if not credentials: |
| 58 | + raise UnauthenticatedException() |
| 59 | + |
| 60 | + if not credentials.scheme == "Bearer": |
| 61 | + raise UnauthorizedException("Invalid authentication scheme.") |
| 62 | + |
| 63 | + try: |
| 64 | + signing_key = jwks_client.get_signing_key_from_jwt( |
| 65 | + credentials.credentials |
| 66 | + ).key |
| 67 | + except jwt.exceptions.PyJWKClientError as error: |
| 68 | + raise UnauthorizedException(str(error)) |
| 69 | + except jwt.exceptions.DecodeError as error: |
| 70 | + raise UnauthorizedException(str(error)) |
| 71 | + |
| 72 | + try: |
| 73 | + # TODO: Review decode setup and verifications |
| 74 | + # https://pyjwt.readthedocs.io/en/stable/api.html#jwt.decode |
| 75 | + payload = jwt.decode( |
| 76 | + jwt=credentials.credentials, |
| 77 | + key=signing_key, |
| 78 | + algorithms=[config.AUTH.JWT_ALGORITHM], |
| 79 | + ) |
| 80 | + except Exception as error: |
| 81 | + raise UnauthorizedException(str(error)) |
| 82 | + |
| 83 | + if payload["expires"] < time.time(): |
| 84 | + raise UnauthorizedException("Expired token") |
| 85 | + |
| 86 | + return payload |
0 commit comments