Skip to content

Commit 8e80248

Browse files
committed
feat: integration of Keycloak
1 parent 56e9af2 commit 8e80248

File tree

12 files changed

+269
-112
lines changed

12 files changed

+269
-112
lines changed

app/auth.py

Lines changed: 65 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,69 @@
1-
from fastapi_keycloak import FastAPIKeycloak
1+
import jwt
2+
from fastapi import Depends, HTTPException, WebSocket, status
3+
from fastapi.security import OAuth2AuthorizationCodeBearer
4+
from fastapi_keycloak import OIDCUser
5+
from jwt import PyJWKClient
6+
from loguru import logger
7+
28
from .config.settings import settings
39

4-
# create the FastAPIKeycloak instance — used to protect routes
5-
# The server_url must include trailing slash for library
6-
keycloak = FastAPIKeycloak(
7-
server_url=str(settings.keycloak_server_url),
8-
client_id=settings.keycloak_client_id,
9-
client_secret=settings.keycloak_client_secret,
10-
admin_client_secret=settings.keycloak_client_secret, # optional for admin operations
11-
realm=settings.keycloak_realm,
12-
callback_uri="http://localhost:8000/callback", # for auth code flow if needed
10+
# Keycloak OIDC info
11+
KEYCLOAK_BASE_URL = f"https://{settings.keycloak_host}/realms/{settings.keycloak_realm}"
12+
JWKS_URL = f"{KEYCLOAK_BASE_URL}/protocol/openid-connect/certs"
13+
ALGORITHM = "RS256"
14+
15+
16+
# Keycloak OIDC endpoints
17+
oauth2_scheme = OAuth2AuthorizationCodeBearer(
18+
authorizationUrl=f"https://{settings.keycloak_host}/realms/{settings.keycloak_realm}/"
19+
"protocol/openid-connect/auth",
20+
tokenUrl=f"https://{settings.keycloak_host}/realms/{settings.keycloak_realm}/"
21+
"protocol/openid-connect/token",
1322
)
1423

15-
# expose a helper dependency for current user
16-
get_current_user = keycloak.get_current_user
17-
get_current_active_user = keycloak.get_current_user
24+
# PyJWT helper to fetch and cache keys
25+
jwks_client = PyJWKClient(JWKS_URL, cache_keys=True)
26+
27+
28+
def _decode_token(token: str):
29+
try:
30+
signing_key = jwks_client.get_signing_key_from_jwt(token).key
31+
payload = jwt.decode(
32+
token,
33+
signing_key,
34+
algorithms=[ALGORITHM],
35+
issuer=KEYCLOAK_BASE_URL,
36+
)
37+
return payload
38+
except Exception:
39+
raise HTTPException(
40+
status_code=status.HTTP_401_UNAUTHORIZED,
41+
detail="Could not validate credentials",
42+
)
43+
44+
45+
def get_current_user_id(token: str = Depends(oauth2_scheme)):
46+
user: OIDCUser = _decode_token(token)
47+
return user["sub"]
48+
49+
50+
async def websocket_authenticate(websocket: WebSocket) -> str | None:
51+
"""
52+
Authenticate a WebSocket connection using a JWT token from query params.
53+
Returns the ID of the authenticated user payload if valid, otherwise closes the connection.
54+
"""
55+
logger.debug("Authenticating websocket")
56+
token = websocket.query_params.get("token")
57+
if not token:
58+
logger.error("Token is missing from websocket authentication")
59+
await websocket.close(code=1008, reason="Missing token")
60+
return None
61+
62+
try:
63+
user_id = get_current_user_id(token)
64+
await websocket.accept()
65+
return user_id
66+
except Exception as e:
67+
logger.error(f"Invalid token in websocket authentication: {e}")
68+
await websocket.close(code=1008, reason="Invalid token")
69+
return None

app/config/settings.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import AnyHttpUrl, Field
1+
from pydantic import Field
22
from pydantic_settings import BaseSettings, SettingsConfigDict
33

44

@@ -13,9 +13,9 @@ class Settings(BaseSettings):
1313
env: str = Field(default="development", json_schema_extra={"env": "APP_ENV"})
1414

1515
# Keycloak / OIDC
16-
keycloak_server_url: AnyHttpUrl = Field(
17-
default=AnyHttpUrl("https://localhost"),
18-
json_schema_extra={"env": "KEYCLOAK_SERVER_URL"},
16+
keycloak_host: str = Field(
17+
default=str("localhost"),
18+
json_schema_extra={"env": "KEYCLOAK_HOST"},
1919
)
2020
keycloak_realm: str = Field(default="", json_schema_extra={"env": "KEYCLOAK_REALM"})
2121
keycloak_client_id: str = Field(

app/main.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from app.middleware.correlation_id import add_correlation_id
44
from app.platforms.dispatcher import load_processing_platforms
55
from app.services.tiles.base import load_grids
6-
from .config.logger import setup_logging
7-
from .config.settings import settings
8-
from .routers import jobs_status, unit_jobs, health, tiles, upscale_tasks
6+
from app.config.logger import setup_logging
7+
from app.config.settings import settings
8+
from app.routers import jobs_status, unit_jobs, health, tiles, upscale_tasks
99

1010
setup_logging()
1111

@@ -20,9 +20,6 @@
2020

2121
app.middleware("http")(add_correlation_id)
2222

23-
# Register Keycloak - must be done after FastAPI app creation
24-
# keycloak.register(app, prefix="/auth") # mounts OIDC endpoints for login if needed
25-
2623
# include routers
2724
app.include_router(tiles.router)
2825
app.include_router(jobs_status.router)

app/routers/jobs_status.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from app.schemas.websockets import WSStatusMessage
1111
from app.services.processing import get_processing_jobs_by_user_id
1212
from app.services.upscaling import get_upscaling_tasks_by_user_id
13+
from app.auth import get_current_user_id, websocket_authenticate
1314

1415
router = APIRouter()
1516

@@ -21,7 +22,7 @@
2122
)
2223
async def get_jobs_status(
2324
db: Session = Depends(get_db),
24-
user: str = "foobar",
25+
user: str = Depends(get_current_user_id),
2526
) -> JobsStatusResponse:
2627
"""
2728
Return combined list of upscaling tasks and processing jobs for the authenticated user.
@@ -37,14 +38,16 @@ async def get_jobs_status(
3738
"/ws/jobs_status",
3839
)
3940
async def ws_jobs_status(
40-
websocket: WebSocket, user: str = "foobar", interval: int = 10
41+
websocket: WebSocket,
42+
interval: int = 10,
4143
):
4244
"""
4345
Return combined list of upscaling tasks and processing jobs for the authenticated user.
4446
"""
4547

46-
await websocket.accept()
47-
logger.debug(f"WebSocket connected for user {user}")
48+
user = await websocket_authenticate(websocket)
49+
if not user:
50+
return
4851

4952
await websocket.send_json(
5053
WSStatusMessage(type="init", message="Starting status stream").model_dump()

app/routers/unit_jobs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from loguru import logger
44
from sqlalchemy.orm import Session
55

6+
from app.auth import get_current_user_id
67
from app.database.db import get_db
78
from app.schemas.enum import OutputFormatEnum, ProcessTypeEnum
89
from app.schemas.unit_job import (
@@ -99,7 +100,7 @@ async def create_unit_job(
99100
),
100101
],
101102
db: Session = Depends(get_db),
102-
user: str = "foobar",
103+
user: str = Depends(get_current_user_id),
103104
) -> ProcessingJobSummary:
104105
"""Create a new processing job with the provided data."""
105106
try:
@@ -118,7 +119,9 @@ async def create_unit_job(
118119
responses={404: {"description": "Processing job not found"}},
119120
)
120121
async def get_job(
121-
job_id: int, db: Session = Depends(get_db), user: str = "foobar"
122+
job_id: int,
123+
db: Session = Depends(get_db),
124+
user: str = Depends(get_current_user_id),
122125
) -> ProcessingJob:
123126
job = get_processing_job_by_user_id(db, job_id, user)
124127
if not job:
@@ -136,7 +139,9 @@ async def get_job(
136139
responses={404: {"description": "Processing job not found"}},
137140
)
138141
async def get_job_results(
139-
job_id: int, db: Session = Depends(get_db), user: str = "foobar"
142+
job_id: int,
143+
db: Session = Depends(get_db),
144+
user: str = Depends(get_current_user_id),
140145
) -> Collection | None:
141146
try:
142147
result = get_processing_job_results(db, job_id, user)

app/routers/upscale_tasks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from loguru import logger
1515
from sqlalchemy.orm import Session
1616

17+
from app.auth import get_current_user_id, websocket_authenticate
1718
from app.database.db import SessionLocal, get_db
1819
from app.schemas.enum import OutputFormatEnum, ProcessTypeEnum
1920
from app.schemas.unit_job import (
@@ -106,7 +107,7 @@ async def create_upscale_task(
106107
],
107108
background_tasks: BackgroundTasks,
108109
db: Session = Depends(get_db),
109-
user: str = "foobar",
110+
user: str = Depends(get_current_user_id),
110111
) -> UpscalingTaskSummary:
111112
"""Create a new upscaling job with the provided data."""
112113
try:
@@ -133,7 +134,9 @@ async def create_upscale_task(
133134
responses={404: {"description": "Upscale task not found"}},
134135
)
135136
async def get_upscale_task(
136-
task_id: int, db: Session = Depends(get_db), user: str = "foobar"
137+
task_id: int,
138+
db: Session = Depends(get_db),
139+
user: str = Depends(get_current_user_id),
137140
) -> UpscalingTask:
138141
job = get_upscaling_task_by_user_id(db, task_id, user)
139142
if not job:
@@ -151,10 +154,12 @@ async def get_upscale_task(
151154
async def ws_task_status(
152155
websocket: WebSocket,
153156
task_id: int,
154-
user: str = "foobar",
155157
interval: int = 10,
156158
):
157-
await websocket.accept()
159+
user = await websocket_authenticate(websocket)
160+
if not user:
161+
return
162+
158163
logger.info("WebSocket connected", extra={"user": user, "task_id": task_id})
159164

160165
try:

env.example

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
# Keycloak
2-
KEYCLOAK_SERVER_URL=
2+
KEYCLOAK_HOST=
33
KEYCLOAK_REALM=
4-
KEYCLOAK_CLIENT_ID=
5-
KEYCLOAK_CLIENT_SECRET=
64

75
# App
86
APP_NAME="APEx Dispatch API"

0 commit comments

Comments
 (0)