Skip to content

Commit aea2676

Browse files
authored
Feat/jwt validation from jwks (#897)
* feat: add ISSUER_CONFIG field with validation * feat: implement JWKSManager for JWT validation and user info extraction * feat: add JWT validation to get_accounts endpoint and log user information
1 parent 26c93bd commit aea2676

File tree

3 files changed

+175
-2
lines changed

3 files changed

+175
-2
lines changed

app/api/v1/routes/access.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111

1212
from core.logging import get_module_logger
13+
from core.security import validate_jwt_token
1314
from api.dependencies.rate_limits import get_limiter
1415

1516

@@ -89,7 +90,9 @@ async def create_access_request(
8990

9091
@router.get("/accounts")
9192
@limiter.limit("5/minute")
92-
async def get_accounts(request: Request, user: dict = Depends(get_current_user)):
93+
async def get_accounts(
94+
request: Request, token_data: dict = Depends(validate_jwt_token)
95+
):
9396
"""
9497
Endpoint to retrieve active AWS account names.
9598
@@ -103,6 +106,13 @@ async def get_accounts(request: Request, user: dict = Depends(get_current_user))
103106
Returns:
104107
list: A list of active AWS account names.
105108
"""
109+
logger.info(
110+
"get_accounts",
111+
user=token_data["sub"],
112+
email=token_data["email"],
113+
issuer=token_data["iss"],
114+
token_data=token_data,
115+
)
106116
return get_active_account_names()
107117

108118

app/core/config.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""SRE Bot configuration settings."""
22

3-
from pydantic import Field
3+
from typing import Any, Dict, Optional
4+
from pydantic import Field, field_validator
45
from pydantic_settings import BaseSettings, SettingsConfigDict
56
import structlog
67

@@ -241,6 +242,27 @@ class ServerSettings(BaseSettings):
241242
SECRET_KEY: str | None = Field(default=None, alias="SESSION_SECRET_KEY")
242243
ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
243244
ACCESS_TOKEN_MAX_AGE_MINUTES: int = 1440 # Defaults to 24 hours
245+
ISSUER_CONFIG: Optional[Dict[str, Dict[str, Any]]] = Field(
246+
default=None,
247+
alias="ISSUER_CONFIG",
248+
)
249+
250+
@field_validator("ISSUER_CONFIG", mode="before")
251+
@classmethod
252+
def validate_issuer_config(cls, v: Optional[Dict[str, Dict[str, Any]]]) -> Any:
253+
"""Validate the ISSUER_CONFIG field.
254+
255+
Args:
256+
cls: The class itself.
257+
v: The value of the ISSUER_CONFIG field.
258+
259+
Returns:
260+
The validated value of the ISSUER_CONFIG field.
261+
"""
262+
if v is None or not isinstance(v, dict):
263+
return {}
264+
return v
265+
244266
model_config = SettingsConfigDict(
245267
env_file=".env",
246268
case_sensitive=True,

app/core/security.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from typing import Any, Dict, Optional, Tuple
2+
from fastapi import HTTPException, Security
3+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
4+
from jwt import PyJWKClient, PyJWTError, PyJWKClientError, decode
5+
6+
from core.config import settings
7+
from core.logging import get_module_logger
8+
9+
10+
ISSUER_CONFIG = settings.server.ISSUER_CONFIG
11+
12+
logger = get_module_logger()
13+
security = HTTPBearer()
14+
15+
16+
class JWKSManager:
17+
"""
18+
A class to manage JWKS clients for different issuers.
19+
It initializes a JWKS client for each issuer in the provided configuration.
20+
Attributes:
21+
issuer_config (Dict[str, Dict[str, Any]]): A dictionary containing issuer configurations.
22+
jwks_clients (Dict[str, PyJWKClient]): A dictionary to store JWKS clients for each issuer.
23+
"""
24+
25+
def __init__(self, issuer_config: Optional[Dict[str, Dict[str, Any]]]):
26+
self.issuer_config = issuer_config
27+
self.jwks_clients: Dict[str, PyJWKClient] = {}
28+
29+
def get_jwks_client(self, issuer: str) -> Optional[PyJWKClient]:
30+
"""Get the JWKS client for the specified issuer.
31+
32+
Args:
33+
issuer (str): The issuer for which to get the JWKS client.
34+
Returns:
35+
Optional[PyJWKClient]: The JWKS client for the specified issuer, or None if not found.
36+
"""
37+
if not self.issuer_config or issuer not in self.issuer_config:
38+
return None
39+
if issuer not in self.jwks_clients:
40+
try:
41+
cfg = self.issuer_config[issuer]
42+
self.jwks_clients[issuer] = PyJWKClient(
43+
cfg["jwks_uri"], cache_jwk_set=True, lifespan=3600, timeout=10
44+
)
45+
except Exception as e:
46+
logger.warning(
47+
"jwks_client_initialization_failed", error=str(e), issuer=issuer
48+
)
49+
return None
50+
return self.jwks_clients[issuer]
51+
52+
53+
jwks_manager = JWKSManager(getattr(settings.server, "ISSUER_CONFIG", None))
54+
55+
56+
def get_issuer_from_token(token: str) -> Optional[str]:
57+
"""
58+
Extract the issuer from the JWT token without verifying the signature.
59+
Args:
60+
token (str): The JWT token.
61+
Returns:
62+
str | None: The issuer (iss) claim from the token if present, otherwise None.
63+
"""
64+
try:
65+
unverified_payload = decode(token, options={"verify_signature": False})
66+
return unverified_payload.get("iss")
67+
except Exception:
68+
return None
69+
70+
71+
def extract_user_info_from_token(token: str) -> Tuple[Optional[str], Optional[str]]:
72+
"""
73+
Extract user ID and email from the JWT token without verifying the signature.
74+
Args:
75+
token (str): The JWT token.
76+
Returns:
77+
Tuple[str, str] | Tuple[None, None]: A tuple containing the user ID and email if present,
78+
otherwise (None, None).
79+
"""
80+
try:
81+
payload = decode(token, options={"verify_signature": False})
82+
user_id = None
83+
user_email = None
84+
85+
# For user JWTs, email may be a top-level claim
86+
if "email" in payload:
87+
user_email = payload["email"]
88+
89+
# sub is always present
90+
if "sub" in payload:
91+
user_id = payload["sub"].split("/")[-1]
92+
93+
return user_id, user_email
94+
except Exception as e:
95+
logger.warning(
96+
"user_info_extraction_failed",
97+
error=str(e),
98+
payload=payload,
99+
)
100+
return None, None
101+
102+
103+
async def validate_jwt_token(
104+
credentials: HTTPAuthorizationCredentials = Security(security),
105+
) -> Dict[str, Any]:
106+
"""
107+
Validate the JWT token and extract user information.
108+
Args:
109+
credentials (HTTPAuthorizationCredentials): The HTTP authorization credentials containing the JWT token.
110+
Returns:
111+
Dict[str, Any]: The decoded payload of the JWT token.
112+
Raises:
113+
HTTPException: If the token is invalid, untrusted, or if any other error occurs during validation.
114+
"""
115+
if (
116+
credentials is None
117+
or not credentials.scheme == "Bearer"
118+
or not credentials.credentials
119+
):
120+
raise HTTPException(status_code=401, detail="Missing or invalid token")
121+
token = credentials.credentials
122+
issuer = get_issuer_from_token(token)
123+
if not issuer:
124+
raise HTTPException(status_code=401, detail="Issuer not found in token")
125+
jwks_client = jwks_manager.get_jwks_client(issuer)
126+
if not jwks_client or not jwks_manager.issuer_config:
127+
raise HTTPException(status_code=401, detail="Untrusted or missing token issuer")
128+
cfg = jwks_manager.issuer_config[issuer]
129+
try:
130+
signing_key = jwks_client.get_signing_key_from_jwt(token)
131+
payload = decode(
132+
token,
133+
signing_key.key,
134+
algorithms=cfg["algorithms"],
135+
audience=cfg["audience"],
136+
options={"verify_exp": True},
137+
)
138+
return payload
139+
except (PyJWKClientError, PyJWTError) as e:
140+
logger.warning("jwt_validation_failed", error=str(e), issuer=issuer)
141+
raise HTTPException(status_code=401, detail=f"Invalid token: {str(e)}") from e

0 commit comments

Comments
 (0)