|
1 | | -from fastapi_keycloak import FastAPIKeycloak |
| 1 | +import jwt |
| 2 | +from fastapi import Depends, HTTPException, WebSocket, status |
| 3 | +from fastapi.security import OAuth2AuthorizationCodeBearer |
| 4 | +from fastapi_keycloak import OIDCUser |
| 5 | +from jwt import PyJWKClient |
| 6 | +from loguru import logger |
| 7 | + |
2 | 8 | from .config.settings import settings |
3 | 9 |
|
4 | | -# create the FastAPIKeycloak instance — used to protect routes |
5 | | -# The server_url must include trailing slash for library |
6 | | -keycloak = FastAPIKeycloak( |
7 | | - server_url=str(settings.keycloak_server_url), |
8 | | - client_id=settings.keycloak_client_id, |
9 | | - client_secret=settings.keycloak_client_secret, |
10 | | - admin_client_secret=settings.keycloak_client_secret, # optional for admin operations |
11 | | - realm=settings.keycloak_realm, |
12 | | - callback_uri="http://localhost:8000/callback", # for auth code flow if needed |
| 10 | +# Keycloak OIDC info |
| 11 | +KEYCLOAK_BASE_URL = f"https://{settings.keycloak_host}/realms/{settings.keycloak_realm}" |
| 12 | +JWKS_URL = f"{KEYCLOAK_BASE_URL}/protocol/openid-connect/certs" |
| 13 | +ALGORITHM = "RS256" |
| 14 | + |
| 15 | + |
| 16 | +# Keycloak OIDC endpoints |
| 17 | +oauth2_scheme = OAuth2AuthorizationCodeBearer( |
| 18 | + authorizationUrl=f"https://{settings.keycloak_host}/realms/{settings.keycloak_realm}/" |
| 19 | + "protocol/openid-connect/auth", |
| 20 | + tokenUrl=f"https://{settings.keycloak_host}/realms/{settings.keycloak_realm}/" |
| 21 | + "protocol/openid-connect/token", |
13 | 22 | ) |
14 | 23 |
|
15 | | -# expose a helper dependency for current user |
16 | | -get_current_user = keycloak.get_current_user |
17 | | -get_current_active_user = keycloak.get_current_user |
| 24 | +# PyJWT helper to fetch and cache keys |
| 25 | +jwks_client = PyJWKClient(JWKS_URL, cache_keys=True) |
| 26 | + |
| 27 | + |
| 28 | +def _decode_token(token: str): |
| 29 | + try: |
| 30 | + signing_key = jwks_client.get_signing_key_from_jwt(token).key |
| 31 | + payload = jwt.decode( |
| 32 | + token, |
| 33 | + signing_key, |
| 34 | + algorithms=[ALGORITHM], |
| 35 | + issuer=KEYCLOAK_BASE_URL, |
| 36 | + ) |
| 37 | + return payload |
| 38 | + except Exception: |
| 39 | + raise HTTPException( |
| 40 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 41 | + detail="Could not validate credentials", |
| 42 | + ) |
| 43 | + |
| 44 | + |
| 45 | +def get_current_user_id(token: str = Depends(oauth2_scheme)): |
| 46 | + user: OIDCUser = _decode_token(token) |
| 47 | + return user["sub"] |
| 48 | + |
| 49 | + |
| 50 | +async def websocket_authenticate(websocket: WebSocket) -> str | None: |
| 51 | + """ |
| 52 | + Authenticate a WebSocket connection using a JWT token from query params. |
| 53 | + Returns the ID of the authenticated user payload if valid, otherwise closes the connection. |
| 54 | + """ |
| 55 | + logger.debug("Authenticating websocket") |
| 56 | + token = websocket.query_params.get("token") |
| 57 | + if not token: |
| 58 | + logger.error("Token is missing from websocket authentication") |
| 59 | + await websocket.close(code=1008, reason="Missing token") |
| 60 | + return None |
| 61 | + |
| 62 | + try: |
| 63 | + user_id = get_current_user_id(token) |
| 64 | + await websocket.accept() |
| 65 | + return user_id |
| 66 | + except Exception as e: |
| 67 | + logger.error(f"Invalid token in websocket authentication: {e}") |
| 68 | + await websocket.close(code=1008, reason="Invalid token") |
| 69 | + return None |
0 commit comments