Skip to content

Commit a91c878

Browse files
authored
Merge branch 'main' into poc_parameters_endpoint
2 parents 6f19971 + 19df780 commit a91c878

File tree

15 files changed

+517
-161
lines changed

15 files changed

+517
-161
lines changed

app/auth.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from typing import Any, Dict
22
import httpx
33
import jwt
4-
from fastapi import Depends, HTTPException, WebSocket, status
4+
from fastapi import Depends, WebSocket, status
55
from fastapi.security import OAuth2AuthorizationCodeBearer
66
from jwt import PyJWKClient
77
from loguru import logger
88

9+
from app.error import AuthException, DispatcherException
10+
from app.schemas.websockets import WSStatusMessage
11+
912
from .config.settings import settings
1013

1114
# Keycloak OIDC info
@@ -37,9 +40,9 @@ def _decode_token(token: str):
3740
)
3841
return payload
3942
except Exception:
40-
raise HTTPException(
41-
status_code=status.HTTP_401_UNAUTHORIZED,
42-
detail="Could not validate credentials",
43+
raise AuthException(
44+
http_status=status.HTTP_401_UNAUTHORIZED,
45+
message="Could not validate credentials. Please retry signing in.",
4346
)
4447

4548

@@ -55,6 +58,7 @@ async def websocket_authenticate(websocket: WebSocket) -> str | None:
5558
"""
5659
logger.debug("Authenticating websocket")
5760
token = websocket.query_params.get("token")
61+
5862
if not token:
5963
logger.error("Token is missing from websocket authentication")
6064
await websocket.close(code=1008, reason="Missing token")
@@ -63,9 +67,22 @@ async def websocket_authenticate(websocket: WebSocket) -> str | None:
6367
try:
6468
await websocket.accept()
6569
return token
70+
except DispatcherException as ae:
71+
logger.error(f"Dispatcher exception detected: {ae.message}")
72+
await websocket.send_json(
73+
WSStatusMessage(type="error", message=ae.message).model_dump()
74+
)
75+
await websocket.close(code=1008, reason=ae.error_code)
76+
return None
6677
except Exception as e:
67-
logger.error(f"Invalid token in websocket authentication: {e}")
68-
await websocket.close(code=1008, reason="Invalid token")
78+
logger.error(f"Unexpected error occurred during websocket authentication: {e}")
79+
await websocket.send_json(
80+
WSStatusMessage(
81+
type="error",
82+
message="Something went wrong during authentication. Please try again.",
83+
).model_dump()
84+
)
85+
await websocket.close(code=1008, reason="INTERNAL_ERROR")
6986
return None
7087

7188

@@ -81,15 +98,15 @@ async def exchange_token_for_provider(
8198
8299
:return: The token response (dict) on success.
83100
84-
:raise: Raises HTTPException with an appropriate status and message on error.
101+
:raise: Raises AuthException with an appropriate status and message on error.
85102
"""
86103
token_url = f"{KEYCLOAK_BASE_URL}/protocol/openid-connect/token"
87104

88105
# Check if the necessary settings are in place
89106
if not settings.keycloak_client_id or not settings.keycloak_client_secret:
90-
raise HTTPException(
91-
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
92-
detail="Token exchange not configured on the server (missing client credentials).",
107+
raise AuthException(
108+
http_status=status.HTTP_500_INTERNAL_SERVER_ERROR,
109+
message="Token exchange not configured on the server (missing client credentials).",
93110
)
94111

95112
payload = {
@@ -105,9 +122,12 @@ async def exchange_token_for_provider(
105122
resp = await client.post(token_url, data=payload)
106123
except httpx.RequestError as exc:
107124
logger.error(f"Token exchange network error for provider={provider}: {exc}")
108-
raise HTTPException(
109-
status_code=status.HTTP_502_BAD_GATEWAY,
110-
detail="Failed to contact the identity provider for token exchange.",
125+
raise AuthException(
126+
http_status=status.HTTP_502_BAD_GATEWAY,
127+
message=(
128+
f"Could not authenticate with {provider}. Please contact APEx support or reach out "
129+
"through the <a href='https://forum.apex.esa.int/'>APEx User Forum</a>."
130+
),
111131
)
112132

113133
# Parse response
@@ -117,9 +137,12 @@ async def exchange_token_for_provider(
117137
logger.error(
118138
f"Token exchange invalid JSON response (status={resp.status_code})"
119139
)
120-
raise HTTPException(
121-
status_code=status.HTTP_502_BAD_GATEWAY,
122-
detail="Invalid response from identity provider during token exchange.",
140+
raise AuthException(
141+
http_status=status.HTTP_502_BAD_GATEWAY,
142+
message=(
143+
f"Could not authenticate with {provider}. Please contact APEx support or reach out "
144+
"through the <a href='https://forum.apex.esa.int/'>APEx User Forum</a>."
145+
),
123146
)
124147

125148
if resp.status_code != 200:
@@ -136,7 +159,16 @@ async def exchange_token_for_provider(
136159
else status.HTTP_502_BAD_GATEWAY
137160
)
138161

139-
raise HTTPException(client_status, detail=body)
162+
raise AuthException(
163+
http_status=client_status,
164+
message=(
165+
f"Please link your account with {provider} in your "
166+
"<a href='https://{settings.keycloak_host}/realms/{settings.keycloak_realm}/"
167+
"account'>Account Dashboard</a>"
168+
if body.get("error", "") == "not_linked"
169+
else f"Could not authenticate with {provider}: {err}"
170+
),
171+
)
140172

141173
# Successful exchange, return token response (access_token, expires_in, etc.)
142174
return body

app/database/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_db():
3535
yield db
3636
db.commit()
3737
except Exception:
38-
logger.exception("An error occurred during database retrieval")
38+
logger.error("An error occurred during database retrieval")
3939
db.rollback()
4040
raise
4141
finally:

app/error.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Any, Dict, Optional
2+
from fastapi import status
3+
from pydantic import BaseModel
4+
5+
6+
class ErrorResponse(BaseModel):
7+
status: str = "error"
8+
error_code: str
9+
message: str
10+
details: Optional[Dict[str, Any]] = None
11+
request_id: Optional[str] = None
12+
13+
14+
class DispatcherException(Exception):
15+
"""
16+
Base domain exception for the APEx Dispatch API.
17+
"""
18+
19+
http_status: int = status.HTTP_400_BAD_REQUEST
20+
error_code: str = "APEX_ERROR"
21+
message: str = "An error occurred."
22+
details: Optional[Dict[str, Any]] = None
23+
24+
def __init__(
25+
self,
26+
message: Optional[str] = None,
27+
error_code: Optional[str] = None,
28+
http_status: Optional[int] = None,
29+
details: Optional[Dict[str, Any]] = None,
30+
):
31+
if message:
32+
self.message = message
33+
if error_code:
34+
self.error_code = error_code
35+
if http_status:
36+
self.http_status = http_status
37+
if details:
38+
self.details = details
39+
40+
def __str__(self):
41+
return f"{self.error_code}: {self.message}"
42+
43+
44+
class AuthException(DispatcherException):
45+
def __init__(
46+
self,
47+
http_status: Optional[int] = status.HTTP_401_UNAUTHORIZED,
48+
message: Optional[str] = "Authentication failed.",
49+
):
50+
super().__init__(message, "AUTHENTICATION_FAILED", http_status)
51+
52+
53+
class JobNotFoundException(DispatcherException):
54+
http_status: int = status.HTTP_404_NOT_FOUND
55+
error_code: str = "JOB_NOT_FOUND"
56+
message: str = "The requested job was not found."
57+
58+
59+
class TaskNotFoundException(DispatcherException):
60+
http_status: int = status.HTTP_404_NOT_FOUND
61+
error_code: str = "TASK_NOT_FOUND"
62+
message: str = "The requested task was not found."
63+
64+
65+
class InternalException(DispatcherException):
66+
http_status: int = status.HTTP_500_INTERNAL_SERVER_ERROR
67+
error_code: str = "INTERNAL_ERROR"
68+
message: str = "An internal server error occurred."

app/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from fastapi.middleware.cors import CORSMiddleware
33

44
from app.middleware.correlation_id import add_correlation_id
5+
from app.middleware.error_handling import register_exception_handlers
56
from app.platforms.dispatcher import load_processing_platforms
67
from app.services.tiles.base import load_grids
78
from app.config.logger import setup_logging
@@ -36,6 +37,7 @@
3637
)
3738

3839
app.middleware("http")(add_correlation_id)
40+
register_exception_handlers(app)
3941

4042
# include routers
4143
app.include_router(tiles.router)

app/middleware/error_handling.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from typing import Any
2+
from fastapi import Request, status
3+
from fastapi.exceptions import RequestValidationError
4+
from fastapi.responses import JSONResponse
5+
from app.error import DispatcherException, ErrorResponse
6+
from app.middleware.correlation_id import correlation_id_ctx
7+
from loguru import logger
8+
9+
10+
def get_dispatcher_error_response(
11+
exc: DispatcherException, request_id: str
12+
) -> ErrorResponse:
13+
return ErrorResponse(
14+
error_code=exc.error_code,
15+
message=exc.message,
16+
details=exc.details,
17+
request_id=request_id,
18+
)
19+
20+
21+
async def dispatch_exception_handler(request: Request, exc: DispatcherException):
22+
23+
content = get_dispatcher_error_response(exc, correlation_id_ctx.get())
24+
logger.exception(f"DispatcherException raised: {exc.message}")
25+
return JSONResponse(status_code=exc.http_status, content=content.dict())
26+
27+
28+
async def generic_exception_handler(request: Request, exc: Exception):
29+
30+
# DO NOT expose internal exceptions to the client
31+
content = ErrorResponse(
32+
error_code="INTERNAL_SERVER_ERROR",
33+
message="An unexpected error occurred.",
34+
details=None,
35+
request_id=correlation_id_ctx.get(),
36+
)
37+
38+
logger.exception(f"GenericException raised: {exc}")
39+
return JSONResponse(
40+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=content.dict()
41+
)
42+
43+
44+
def _parse_validation_error(err: Any):
45+
if "ctx" in err:
46+
del err["ctx"]
47+
return err
48+
49+
50+
async def validation_exception_handler(request: Request, exc: RequestValidationError):
51+
52+
logger.error(f"Request validation error: {exc.__class__.__name__}: {exc}")
53+
content = ErrorResponse(
54+
error_code="VALIDATION_ERROR",
55+
message="Request validation failed.",
56+
details={"errors": [_parse_validation_error(error) for error in exc.errors()]},
57+
request_id=correlation_id_ctx.get(),
58+
)
59+
60+
logger.error(content.dict())
61+
62+
return JSONResponse(
63+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content=content.dict()
64+
)
65+
66+
67+
def register_exception_handlers(app):
68+
"""
69+
Call this in main.py after creating the FastAPI() instance.
70+
"""
71+
72+
app.add_exception_handler(DispatcherException, dispatch_exception_handler)
73+
app.add_exception_handler(RequestValidationError, validation_exception_handler)
74+
app.add_exception_handler(Exception, generic_exception_handler)

app/routers/jobs_status.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from loguru import logger
88

99
from app.database.db import SessionLocal, get_db
10+
from app.error import DispatcherException, ErrorResponse, InternalException
11+
from app.middleware.error_handling import get_dispatcher_error_response
1012
from app.schemas.jobs_status import JobsFilter, JobsStatusResponse
1113
from app.schemas.websockets import WSStatusMessage
1214
from app.services.processing import get_processing_jobs_by_user_id
@@ -22,6 +24,19 @@
2224
"/jobs_status",
2325
tags=["Upscale Tasks", "Unit Jobs"],
2426
summary="Get a list of all upscaling tasks & processing jobs for the authenticated user",
27+
responses={
28+
InternalException.http_status: {
29+
"description": "Internal server error",
30+
"model": ErrorResponse,
31+
"content": {
32+
"application/json": {
33+
"example": get_dispatcher_error_response(
34+
InternalException(), "request-id"
35+
)
36+
}
37+
},
38+
},
39+
},
2540
)
2641
async def get_jobs_status(
2742
db: Session = Depends(get_db),
@@ -34,21 +49,29 @@ async def get_jobs_status(
3449
"""
3550
Return combined list of upscaling tasks and processing jobs for the authenticated user.
3651
"""
37-
logger.debug("Fetching jobs list")
38-
upscaling_tasks = (
39-
await get_upscaling_tasks_by_user_id(token, db)
40-
if JobsFilter.upscaling in filter
41-
else []
42-
)
43-
processing_jobs = (
44-
await get_processing_jobs_by_user_id(token, db)
45-
if JobsFilter.processing in filter
46-
else []
47-
)
48-
return JobsStatusResponse(
49-
upscaling_tasks=upscaling_tasks,
50-
processing_jobs=processing_jobs,
51-
)
52+
try:
53+
logger.debug("Fetching jobs list")
54+
upscaling_tasks = (
55+
await get_upscaling_tasks_by_user_id(token, db)
56+
if JobsFilter.upscaling in filter
57+
else []
58+
)
59+
processing_jobs = (
60+
await get_processing_jobs_by_user_id(token, db)
61+
if JobsFilter.processing in filter
62+
else []
63+
)
64+
return JobsStatusResponse(
65+
upscaling_tasks=upscaling_tasks,
66+
processing_jobs=processing_jobs,
67+
)
68+
except DispatcherException as de:
69+
raise de
70+
except Exception as e:
71+
logger.error(f"Error retrieving job status: {e}")
72+
raise InternalException(
73+
message="An error occurred while retrieving the job status."
74+
)
5275

5376

5477
@router.websocket(
@@ -91,8 +114,20 @@ async def ws_jobs_status(
91114

92115
except WebSocketDisconnect:
93116
logger.info("WebSocket disconnected")
117+
except DispatcherException as ae:
118+
logger.error(f"Dispatcher exception detected: {ae.message}")
119+
await websocket.send_json(
120+
WSStatusMessage(type="error", message=ae.message).model_dump()
121+
)
122+
await websocket.close(code=1011, reason=ae.error_code)
94123
except Exception as e:
95-
logger.exception(f"Error in jobs_status_ws: {e}")
96-
await websocket.close(code=1011, reason="Error in job status websocket: {e}")
124+
logger.error(f"Unexpected error occurred during websocket : {e}")
125+
await websocket.send_json(
126+
WSStatusMessage(
127+
type="error",
128+
message="An error occurred while monitoring the job status.",
129+
).model_dump()
130+
)
131+
await websocket.close(code=1011, reason="INTERNAL_ERROR")
97132
finally:
98133
db.close()

0 commit comments

Comments
 (0)