|
| 1 | +# filename: main.py |
| 2 | +import os |
| 3 | +import json |
| 4 | +import requests |
| 5 | +from typing import Dict, List, Any |
| 6 | +import logging |
| 7 | + |
| 8 | +# Third-party imports |
| 9 | +from fastapi import FastAPI, Request, HTTPException, status |
| 10 | +from jose import jwt, jwk |
| 11 | +from jose.exceptions import ExpiredSignatureError, JWTError |
| 12 | + |
| 13 | +# --- Configuration --- |
| 14 | +KEYCLOAK_URL = os.getenv("KEYCLOAK_URL", "http://keycloak:8080") |
| 15 | +KEYCLOAK_REALM = os.getenv("KEYCLOAK_REALM", "my_realm") |
| 16 | +KEYCLOAK_AUDIENCE = os.getenv("KEYCLOAK_AUDIENCE", "my_client") |
| 17 | + |
| 18 | +# --- Path-to-role mapping --- |
| 19 | +ROLE_MAPPING: Dict[str, List[str]] = { |
| 20 | + "/api/admin": ["admin-role"], |
| 21 | + "/api/manager": ["manager-role", "admin-role"], |
| 22 | + "/api/public": ["authenticated-user"], # Anyone with a valid token |
| 23 | + "/api/v2/special": ["special-role"] |
| 24 | +} |
| 25 | + |
| 26 | +# JWKS Cache |
| 27 | +JWKS_CACHE: Dict[str, Any] = {} |
| 28 | + |
| 29 | +# Configure logging |
| 30 | +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| 31 | +app_logger = logging.getLogger(__name__) |
| 32 | + |
| 33 | +app = FastAPI( |
| 34 | + title="Traefik Auth (M2M) Middleware", |
| 35 | + description="Custom Traefik ForwardAuth service to verify Keycloak JWT tokens and enforce role-based access control.", |
| 36 | + version="1.0.0" |
| 37 | +) |
| 38 | + |
| 39 | +# --- JWKS and JWT Validation Functions --- |
| 40 | +def get_jwks() -> Dict[str, Any]: |
| 41 | + """ |
| 42 | + Fetches and caches the public keys (JWKS) from the Keycloak server. |
| 43 | + """ |
| 44 | + global JWKS_CACHE |
| 45 | + if not JWKS_CACHE: |
| 46 | + jwks_url = f"{KEYCLOAK_URL}/realms/{KEYCLOAK_REALM}/protocol/openid-connect/certs" |
| 47 | + app_logger.info(f"get_jwks::jwks_url: {jwks_url}") |
| 48 | + try: |
| 49 | + response = requests.get(jwks_url, timeout=5) |
| 50 | + app_logger.info(f"get_jwks::response: {response}") |
| 51 | + response.raise_for_status() |
| 52 | + JWKS_CACHE = response.json() |
| 53 | + except requests.exceptions.RequestException as e: |
| 54 | + app_logger.error(f"get_jwks: '{str}': {e}") |
| 55 | + raise HTTPException( |
| 56 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 57 | + detail=f"Could not fetch JWKS from Keycloak: {e}" |
| 58 | + ) |
| 59 | + app_logger.info(f"get_jwks::JWKS_CACHE: {JWKS_CACHE}") |
| 60 | + return JWKS_CACHE |
| 61 | + |
| 62 | +def get_public_key(kid: str) -> str: |
| 63 | + """ |
| 64 | + Extracts the public key from the JWKS cache using the key ID (kid). |
| 65 | + """ |
| 66 | + jwks = get_jwks() |
| 67 | + keys = jwks.get("keys", []) |
| 68 | + for key in keys: |
| 69 | + if key.get("kid") == kid: |
| 70 | + return json.dumps(key) |
| 71 | + raise HTTPException( |
| 72 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 73 | + detail="No public key found for the given kid in the JWT header" |
| 74 | + ) |
| 75 | + |
| 76 | +def verify_token(token: str) -> Dict[str, Any]: |
| 77 | + """ |
| 78 | + Verifies the JWT token and returns its payload if valid. |
| 79 | + """ |
| 80 | + try: |
| 81 | + # Get the key ID from the token header to find the right public key. |
| 82 | + unverified_header = jwt.get_unverified_header(token) |
| 83 | + kid = unverified_header.get("kid") |
| 84 | + app_logger.info(f"verify_token::unverified_header: {unverified_header}") |
| 85 | + if not kid: |
| 86 | + raise HTTPException( |
| 87 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 88 | + detail="JWT header is missing 'kid' claim" |
| 89 | + ) |
| 90 | + |
| 91 | + public_key_json = get_public_key(kid) |
| 92 | + public_key = jwk.JWK.from_json(public_key_json).public_key() |
| 93 | + |
| 94 | + # Decode and verify the token. |
| 95 | + # Keycloak JWTs use RS256 algorithm. |
| 96 | + payload = jwt.decode( |
| 97 | + token, |
| 98 | + public_key, |
| 99 | + algorithms=["RS256"], |
| 100 | + audience=KEYCLOAK_AUDIENCE, |
| 101 | + options={"verify_aud": True, "verify_signature": True} |
| 102 | + ) |
| 103 | + return payload |
| 104 | + except ExpiredSignatureError: |
| 105 | + raise HTTPException( |
| 106 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 107 | + detail="Token has expired" |
| 108 | + ) |
| 109 | + except JWTError as e: |
| 110 | + raise HTTPException( |
| 111 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 112 | + detail=f"Invalid token: {e}" |
| 113 | + ) |
| 114 | + except Exception as e: |
| 115 | + raise HTTPException( |
| 116 | + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, |
| 117 | + detail=f"An error occurred during token validation: {e}" |
| 118 | + ) |
| 119 | + |
| 120 | +# --- FastAPI Endpoint for Traefik ForwardAuth --- |
| 121 | +@app.get("/auth") |
| 122 | +@app.post("/auth") |
| 123 | +async def authenticate(request: Request): |
| 124 | + """ |
| 125 | + Endpoint for Traefik's ForwardAuth middleware. |
| 126 | + Authenticates the request and authorizes based on path and roles. |
| 127 | + """ |
| 128 | + auth_header = request.headers.get("Authorization") |
| 129 | + if not auth_header or not auth_header.startswith("Bearer "): |
| 130 | + raise HTTPException( |
| 131 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 132 | + detail="Authorization header is missing or malformed" |
| 133 | + ) |
| 134 | + |
| 135 | + token = auth_header.split(" ")[1] |
| 136 | + app_logger.info(f"Parsed token: {token}") |
| 137 | + |
| 138 | + # Verify the JWT token |
| 139 | + payload = verify_token(token) |
| 140 | + app_logger.info(f"Payload: {payload}") |
| 141 | + |
| 142 | + # Extract roles from the token. Keycloak typically puts roles under |
| 143 | + # realm_access.roles or resource_access.<client_id>.roles. |
| 144 | + token_roles = set() |
| 145 | + realm_access = payload.get("realm_access", {}) |
| 146 | + if realm_access and isinstance(realm_access, dict): |
| 147 | + token_roles.update(realm_access.get("roles", [])) |
| 148 | + |
| 149 | + resource_access = payload.get("resource_access", {}) |
| 150 | + client_roles = resource_access.get(KEYCLOAK_AUDIENCE, {}).get("roles", []) |
| 151 | + token_roles.update(client_roles) |
| 152 | + |
| 153 | + if not token_roles: |
| 154 | + # A valid token with no roles might not be enough for any path. |
| 155 | + raise HTTPException( |
| 156 | + status_code=status.HTTP_403_FORBIDDEN, |
| 157 | + detail="Token does not contain any roles" |
| 158 | + ) |
| 159 | + |
| 160 | + # Get the request path from the Traefik header |
| 161 | + forwarded_uri = request.headers.get("X-Forwarded-Uri", "/") |
| 162 | + |
| 163 | + # Find the required roles for the current path |
| 164 | + required_roles: List[str] = [] |
| 165 | + # Find the most specific match for the path |
| 166 | + best_match = "" |
| 167 | + for path_prefix, roles in ROLE_MAPPING.items(): |
| 168 | + if forwarded_uri.startswith(path_prefix) and len(path_prefix) > len(best_match): |
| 169 | + best_match = path_prefix |
| 170 | + required_roles = roles |
| 171 | + |
| 172 | + # If no matching path is found in the mapping, assume it's unauthorized |
| 173 | + if not required_roles: |
| 174 | + raise HTTPException( |
| 175 | + status_code=status.HTTP_403_FORBIDDEN, |
| 176 | + detail=f"Path '{forwarded_uri}' is not configured for access control" |
| 177 | + ) |
| 178 | + |
| 179 | + # Check if the user has any of the required roles |
| 180 | + if not token_roles.intersection(required_roles): |
| 181 | + raise HTTPException( |
| 182 | + status_code=status.HTTP_403_FORBIDDEN, |
| 183 | + detail=f"Insufficient permissions. Required roles: {required_roles}" |
| 184 | + ) |
| 185 | + |
| 186 | + # If all checks pass, return a 200 OK. |
| 187 | + return { |
| 188 | + "message": "Authentication successful", |
| 189 | + "user_id": payload.get("sub"), |
| 190 | + "roles": list(token_roles) |
| 191 | + } |
| 192 | + |
0 commit comments