|
| 1 | +# type: ignore |
| 2 | +# ruff: noqa |
| 3 | +"""Mock OIDC server for demo/experimentation.""" |
| 4 | + |
| 5 | +import base64 |
| 6 | +import hashlib |
| 7 | +import json |
| 8 | +import os |
| 9 | +from dataclasses import dataclass, field |
| 10 | +from datetime import UTC, datetime, timedelta |
| 11 | +from pathlib import Path |
| 12 | +from typing import Optional |
| 13 | +from urllib.parse import urlencode |
| 14 | + |
| 15 | +from cryptography.hazmat.primitives import serialization |
| 16 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 17 | +from fastapi import FastAPI, Form, HTTPException, Request |
| 18 | +from fastapi.middleware.cors import CORSMiddleware |
| 19 | +from fastapi.responses import JSONResponse, RedirectResponse |
| 20 | +from fastapi.templating import Jinja2Templates |
| 21 | +from jose import jwt |
| 22 | + |
| 23 | +app = FastAPI() |
| 24 | + |
| 25 | +# Configure templates |
| 26 | +templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates")) |
| 27 | + |
| 28 | +# Configure CORS |
| 29 | +app.add_middleware( |
| 30 | + CORSMiddleware, |
| 31 | + allow_origins=["*"], # In production, replace with specific origins |
| 32 | + allow_credentials=True, |
| 33 | + allow_methods=["*"], |
| 34 | + allow_headers=["*"], |
| 35 | + expose_headers=["Content-Type"], |
| 36 | + max_age=86400, # 24 hours |
| 37 | +) |
| 38 | + |
| 39 | +# Configuration |
| 40 | +ISSUER = os.environ.get("ISSUER", "http://localhost:3000") |
| 41 | +AVAILABLE_SCOPES = os.environ.get("SCOPES", "") |
| 42 | +KEY_ID = "1" |
| 43 | + |
| 44 | + |
| 45 | +@dataclass |
| 46 | +class KeyPair: |
| 47 | + cache_dir: Path |
| 48 | + |
| 49 | + jwks: dict = field(init=False) |
| 50 | + private_key: str = field(init=False) |
| 51 | + |
| 52 | + def __post_init__(self): |
| 53 | + private_key_path = self.cache_dir / "private_key.pem" |
| 54 | + jwks_path = self.cache_dir / "jwks.json" |
| 55 | + |
| 56 | + if private_key_path.exists() and jwks_path.exists(): |
| 57 | + self.jwks = json.loads(jwks_path.read_text()) |
| 58 | + self.private_key = private_key_path.read_text() |
| 59 | + return |
| 60 | + |
| 61 | + # Generate keys |
| 62 | + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 63 | + private_pem = private_key.private_bytes( |
| 64 | + encoding=serialization.Encoding.PEM, |
| 65 | + format=serialization.PrivateFormat.PKCS8, |
| 66 | + encryption_algorithm=serialization.NoEncryption(), |
| 67 | + ) |
| 68 | + public_key = private_key.public_key() |
| 69 | + public_numbers = public_key.public_numbers() |
| 70 | + |
| 71 | + self.jwks = { |
| 72 | + "keys": [ |
| 73 | + { |
| 74 | + "kty": "RSA", |
| 75 | + "use": "sig", |
| 76 | + "kid": KEY_ID, |
| 77 | + "alg": "RS256", |
| 78 | + "n": int_to_base64url(public_numbers.n), |
| 79 | + "e": int_to_base64url(public_numbers.e), |
| 80 | + } |
| 81 | + ] |
| 82 | + } |
| 83 | + self.private_key = private_pem.decode("utf-8") |
| 84 | + |
| 85 | + private_key_path.write_text(self.private_key) |
| 86 | + jwks_path.write_text(json.dumps(self.jwks, indent=2)) |
| 87 | + |
| 88 | + @staticmethod |
| 89 | + def int_to_base64url(value): |
| 90 | + """Convert an integer to base64url format.""" |
| 91 | + value_hex = format(value, "x") |
| 92 | + # Ensure even length |
| 93 | + if len(value_hex) % 2 == 1: |
| 94 | + value_hex = "0" + value_hex |
| 95 | + value_bytes = bytes.fromhex(value_hex) |
| 96 | + return base64.urlsafe_b64encode(value_bytes).rstrip(b"=").decode("ascii") |
| 97 | + |
| 98 | + |
| 99 | +# Load or generate key pair on startup |
| 100 | +KEY_PAIR = KeyPair(Path(__file__).parent) |
| 101 | + |
| 102 | +# In-memory storage |
| 103 | +authorization_codes = {} |
| 104 | +pkce_challenges = {} |
| 105 | +access_tokens = {} |
| 106 | +auth_requests = {} |
| 107 | + |
| 108 | + |
| 109 | +@app.get("/") |
| 110 | +async def root(): |
| 111 | + return { |
| 112 | + "message": "If you're using this in production, you are going to have a bad time." |
| 113 | + } |
| 114 | + |
| 115 | + |
| 116 | +@app.get("/.well-known/openid-configuration") |
| 117 | +async def openid_configuration(): |
| 118 | + """Return OpenID Connect configuration.""" |
| 119 | + scopes_set = set(["openid", "profile", *AVAILABLE_SCOPES.split(",")]) |
| 120 | + return { |
| 121 | + "issuer": ISSUER, |
| 122 | + "authorization_endpoint": f"{ISSUER}/authorize", |
| 123 | + "token_endpoint": f"{ISSUER}/token", |
| 124 | + "jwks_uri": f"{ISSUER}/.well-known/jwks.json", |
| 125 | + "response_types_supported": ["code"], |
| 126 | + "subject_types_supported": ["public"], |
| 127 | + "id_token_signing_alg_values_supported": ["RS256"], |
| 128 | + "scopes_supported": sorted(scopes_set), |
| 129 | + "token_endpoint_auth_methods_supported": ["client_secret_post", "none"], |
| 130 | + "claims_supported": ["sub", "iss", "iat", "exp"], |
| 131 | + "code_challenge_methods_supported": ["S256"], |
| 132 | + } |
| 133 | + |
| 134 | + |
| 135 | +@app.get("/.well-known/jwks.json") |
| 136 | +async def jwks(): |
| 137 | + """Return JWKS (JSON Web Key Set).""" |
| 138 | + return KEY_PAIR.jwks |
| 139 | + |
| 140 | + |
| 141 | +@app.get("/authorize") |
| 142 | +async def authorize( |
| 143 | + request: Request, |
| 144 | + response_type: str, |
| 145 | + client_id: str, |
| 146 | + redirect_uri: str, |
| 147 | + state: str, |
| 148 | + scope: str = "", |
| 149 | + code_challenge: Optional[str] = None, |
| 150 | + code_challenge_method: Optional[str] = None, |
| 151 | +): |
| 152 | + """Handle authorization request.""" |
| 153 | + if response_type != "code": |
| 154 | + raise HTTPException(status_code=400, detail="Invalid response type") |
| 155 | + |
| 156 | + # Validate PKCE if provided |
| 157 | + if code_challenge is not None: |
| 158 | + if code_challenge_method != "S256": |
| 159 | + raise HTTPException(status_code=400, detail="Only S256 PKCE is supported") |
| 160 | + |
| 161 | + # Store the auth request details |
| 162 | + request_id = os.urandom(16).hex() |
| 163 | + auth_requests[request_id] = { |
| 164 | + "client_id": client_id, |
| 165 | + "redirect_uri": redirect_uri, |
| 166 | + "state": state, |
| 167 | + "scope": scope, |
| 168 | + "code_challenge": code_challenge, |
| 169 | + "code_challenge_method": code_challenge_method, |
| 170 | + } |
| 171 | + |
| 172 | + # Show login page |
| 173 | + scopes = sorted(set(("openid profile " + scope).split())) |
| 174 | + return templates.TemplateResponse( |
| 175 | + "login.html", |
| 176 | + { |
| 177 | + "request": request, |
| 178 | + "request_id": request_id, |
| 179 | + "client_id": client_id, |
| 180 | + "scopes": scopes, |
| 181 | + }, |
| 182 | + ) |
| 183 | + |
| 184 | + |
| 185 | +@app.post("/login") |
| 186 | +async def login(request_id: str = Form(...)): |
| 187 | + """Handle login form submission.""" |
| 188 | + # Retrieve the stored auth request |
| 189 | + if request_id not in auth_requests: |
| 190 | + raise HTTPException(status_code=400, detail="Invalid request") |
| 191 | + |
| 192 | + auth_request = auth_requests.pop(request_id) |
| 193 | + |
| 194 | + # Generate authorization code |
| 195 | + code = os.urandom(32).hex() |
| 196 | + |
| 197 | + # Store authorization details |
| 198 | + authorization_codes[code] = { |
| 199 | + "client_id": auth_request["client_id"], |
| 200 | + "redirect_uri": auth_request["redirect_uri"], |
| 201 | + "scope": " ".join( |
| 202 | + sorted(set(("openid profile " + auth_request["scope"]).split(" "))) |
| 203 | + ), |
| 204 | + } |
| 205 | + |
| 206 | + # Store PKCE challenge if provided |
| 207 | + if auth_request["code_challenge"]: |
| 208 | + pkce_challenges[code] = auth_request["code_challenge"] |
| 209 | + |
| 210 | + # Redirect back to client with the code |
| 211 | + params = {"code": code, "state": auth_request["state"]} |
| 212 | + return RedirectResponse( |
| 213 | + url=f"{auth_request['redirect_uri']}?{urlencode(params)}", status_code=303 |
| 214 | + ) |
| 215 | + |
| 216 | + |
| 217 | +@app.post("/token") |
| 218 | +async def token( |
| 219 | + grant_type: str = Form(...), |
| 220 | + code: str = Form(...), |
| 221 | + redirect_uri: str = Form(...), |
| 222 | + client_id: str = Form(...), |
| 223 | + client_secret: Optional[str] = Form(None), |
| 224 | + code_verifier: Optional[str] = Form(None), |
| 225 | +): |
| 226 | + """Handle token request.""" |
| 227 | + if grant_type != "authorization_code": |
| 228 | + raise HTTPException(status_code=400, detail="Invalid grant type") |
| 229 | + |
| 230 | + # Verify the authorization code exists |
| 231 | + if code not in authorization_codes: |
| 232 | + raise HTTPException(status_code=400, detail="Invalid authorization code") |
| 233 | + |
| 234 | + auth_details = authorization_codes[code] |
| 235 | + |
| 236 | + # Verify client_id matches the stored one |
| 237 | + if client_id != auth_details["client_id"]: |
| 238 | + raise HTTPException(status_code=400, detail="Client ID mismatch") |
| 239 | + |
| 240 | + # Verify redirect_uri matches the stored one |
| 241 | + if redirect_uri != auth_details["redirect_uri"]: |
| 242 | + raise HTTPException(status_code=400, detail="Redirect URI mismatch") |
| 243 | + |
| 244 | + # Check if PKCE was used in the authorization request |
| 245 | + if code in pkce_challenges: |
| 246 | + if not code_verifier: |
| 247 | + raise HTTPException(status_code=400, detail="Code verifier required") |
| 248 | + |
| 249 | + # Verify the code verifier |
| 250 | + code_challenge = pkce_challenges[code] |
| 251 | + computed_challenge = hashlib.sha256(code_verifier.encode()).digest() |
| 252 | + computed_challenge = ( |
| 253 | + base64.urlsafe_b64encode(computed_challenge).decode().rstrip("=") |
| 254 | + ) |
| 255 | + |
| 256 | + if computed_challenge != code_challenge: |
| 257 | + raise HTTPException(status_code=400, detail="Invalid code verifier") |
| 258 | + |
| 259 | + # Clean up the used code and PKCE challenge |
| 260 | + del authorization_codes[code] |
| 261 | + if code in pkce_challenges: |
| 262 | + del pkce_challenges[code] |
| 263 | + |
| 264 | + # Generate access token |
| 265 | + now = datetime.now(UTC) |
| 266 | + expires_delta = timedelta(minutes=15) |
| 267 | + |
| 268 | + return JSONResponse( |
| 269 | + content={ |
| 270 | + "access_token": jwt.encode( |
| 271 | + { |
| 272 | + "iss": ISSUER, |
| 273 | + "sub": "user123", |
| 274 | + "iat": now, |
| 275 | + "exp": now + expires_delta, |
| 276 | + "scope": auth_details["scope"], |
| 277 | + "kid": KEY_ID, |
| 278 | + }, |
| 279 | + KEY_PAIR.private_key, |
| 280 | + algorithm="RS256", |
| 281 | + headers={"kid": KEY_ID}, |
| 282 | + ), |
| 283 | + "token_type": "Bearer", |
| 284 | + "expires_in": expires_delta.seconds, |
| 285 | + "scope": auth_details["scope"], |
| 286 | + } |
| 287 | + ) |
| 288 | + |
| 289 | + |
| 290 | +if __name__ == "__main__": |
| 291 | + import uvicorn |
| 292 | + |
| 293 | + uvicorn.run( |
| 294 | + "app:app", |
| 295 | + host="0.0.0.0", |
| 296 | + port=int(os.environ.get("PORT", 8888)), |
| 297 | + reload=True, |
| 298 | + ) |
0 commit comments