11# type: ignore
2+ # ruff: noqa
23"""Mock OIDC server for demo/experimentation."""
34
4-
55import base64
66import hashlib
77import json
88import os
9+ from dataclasses import dataclass , field
910from datetime import UTC , datetime , timedelta
1011from pathlib import Path
1112from typing import Optional
3839 "REDIRECT_URI" , "http://localhost:8000/docs/oauth2-redirect"
3940)
4041ISSUER = os .environ .get ("ISSUER" , "http://localhost:3000" )
42+ SCOPES = os .environ .get ("SCOPES" , "" )
43+ KEY_ID = "1"
4144
42- # Key paths - determine from current file location
43- APP_DIR = Path (__file__ ).parent
44- PRIVATE_KEY_PATH = APP_DIR / "private_key.pem"
45- JWKS_PATH = APP_DIR / "jwks.json"
4645
46+ @dataclass
47+ class KeyPair :
48+ cache_dir : Path
4749
48- def load_or_generate_keys ():
49- """Load keys from files if they exist, otherwise generate and save them."""
50- # If both files exist, load them
51- if PRIVATE_KEY_PATH .exists () and JWKS_PATH .exists ():
52- private_key = PRIVATE_KEY_PATH .read_text ()
53- jwks = json .loads (JWKS_PATH .read_text ())
54- return private_key , jwks
50+ jwks : dict = field (init = False )
51+ private_key : str = field (init = False )
5552
56- # Otherwise, generate new keys
57- private_key , jwks = generate_key_pair ()
53+ def __post_init__ (self ):
54+ private_key_path = self .cache_dir / "private_key.pem"
55+ jwks_path = self .cache_dir / "jwks.json"
5856
59- # Save the keys
60- PRIVATE_KEY_PATH .write_text (private_key )
61- JWKS_PATH .write_text (json .dumps (jwks , indent = 2 ))
57+ if private_key_path .exists () and jwks_path .exists ():
58+ self .jwks = json .loads (jwks_path .read_text ())
59+ self .private_key = private_key_path .read_text ()
60+ return
6261
63- return private_key , jwks
62+ # Generate keys
63+ private_key = rsa .generate_private_key (public_exponent = 65537 , key_size = 2048 )
64+ private_pem = private_key .private_bytes (
65+ encoding = serialization .Encoding .PEM ,
66+ format = serialization .PrivateFormat .PKCS8 ,
67+ encryption_algorithm = serialization .NoEncryption (),
68+ )
69+ public_key = private_key .public_key ()
70+ public_numbers = public_key .public_numbers ()
6471
72+ self .jwks = {
73+ "keys" : [
74+ {
75+ "kty" : "RSA" ,
76+ "use" : "sig" ,
77+ "kid" : KEY_ID ,
78+ "alg" : "RS256" ,
79+ "n" : int_to_base64url (public_numbers .n ),
80+ "e" : int_to_base64url (public_numbers .e ),
81+ }
82+ ]
83+ }
84+ self .private_key = private_pem .decode ("utf-8" )
6585
66- # Generate RSA key pair
67- def generate_key_pair ():
68- """Generate RSA key pair and return private and public keys."""
69- private_key = rsa .generate_private_key (public_exponent = 65537 , key_size = 2048 )
70- private_pem = private_key .private_bytes (
71- encoding = serialization .Encoding .PEM ,
72- format = serialization .PrivateFormat .PKCS8 ,
73- encryption_algorithm = serialization .NoEncryption (),
74- )
75- public_key = private_key .public_key ()
76- public_numbers = public_key .public_numbers ()
86+ private_key_path .write_text (self .private_key )
87+ jwks_path .write_text (json .dumps (self .jwks , indent = 2 ))
7788
78- # Convert public key components to base64url format
89+ @ staticmethod
7990 def int_to_base64url (value ):
8091 """Convert an integer to base64url format."""
8192 value_hex = format (value , "x" )
@@ -85,33 +96,17 @@ def int_to_base64url(value):
8596 value_bytes = bytes .fromhex (value_hex )
8697 return base64 .urlsafe_b64encode (value_bytes ).rstrip (b"=" ).decode ("ascii" )
8798
88- return (
89- private_pem .decode ("utf-8" ),
90- {
91- "keys" : [
92- {
93- "kty" : "RSA" ,
94- "use" : "sig" ,
95- "kid" : "1" , # Key ID
96- "alg" : "RS256" ,
97- "n" : int_to_base64url (public_numbers .n ),
98- "e" : int_to_base64url (public_numbers .e ),
99- }
100- ]
101- },
102- )
103-
10499
105100# Load or generate key pair on startup
106- PRIVATE_KEY , JWKS = load_or_generate_keys ( )
101+ KEY_PAIR = KeyPair ( Path ( __file__ ). parent )
107102
108103# In-memory storage
109104authorization_codes = {}
110105pkce_challenges = {}
111106access_tokens = {}
112107
113108# Mock client registry
114- clients = {
109+ CLIENT_REGISTRY = {
115110 CLIENT_ID : {
116111 "client_secret" : CLIENT_SECRET ,
117112 "redirect_uris" : [REDIRECT_URI ],
@@ -120,25 +115,17 @@ def int_to_base64url(value):
120115}
121116
122117
123- def generate_token (
124- subject : str , expires_delta : timedelta = timedelta (minutes = 15 )
125- ) -> str :
126- """Generate a JWT token."""
127- now = datetime .now (UTC )
128- claims = {
129- "iss" : ISSUER ,
130- "sub" : subject ,
131- "iat" : now ,
132- "exp" : now + expires_delta ,
133- "scope" : "openid profile" ,
134- "kid" : "1" , # Match the key ID from JWKS
118+ @app .get ("/" )
119+ async def root ():
120+ return {
121+ "message" : "If you're using this in production, you are going to have a bad time."
135122 }
136- return jwt .encode (claims , PRIVATE_KEY , algorithm = "RS256" , headers = {"kid" : "1" })
137123
138124
139125@app .get ("/.well-known/openid-configuration" )
140126async def openid_configuration ():
141127 """Return OpenID Connect configuration."""
128+ scopes_set = set (["openid" , "profile" , * SCOPES .split ("," )])
142129 return {
143130 "issuer" : ISSUER ,
144131 "authorization_endpoint" : f"{ ISSUER } /authorize" ,
@@ -147,7 +134,7 @@ async def openid_configuration():
147134 "response_types_supported" : ["code" ],
148135 "subject_types_supported" : ["public" ],
149136 "id_token_signing_alg_values_supported" : ["RS256" ],
150- "scopes_supported" : [ "openid" , "profile" ] ,
137+ "scopes_supported" : sorted ( scopes_set ) ,
151138 "token_endpoint_auth_methods_supported" : ["client_secret_post" , "none" ],
152139 "claims_supported" : ["sub" , "iss" , "iat" , "exp" ],
153140 "code_challenge_methods_supported" : ["S256" ],
@@ -157,7 +144,7 @@ async def openid_configuration():
157144@app .get ("/.well-known/jwks.json" )
158145async def jwks ():
159146 """Return JWKS (JSON Web Key Set)."""
160- return JWKS
147+ return KEY_PAIR . jwks
161148
162149
163150@app .get ("/authorize" )
@@ -175,11 +162,11 @@ async def authorize(
175162 raise HTTPException (status_code = 400 , detail = "Invalid response type" )
176163
177164 # Validate client
178- if client_id not in clients :
165+ if client_id not in CLIENT_REGISTRY :
179166 raise HTTPException (status_code = 400 , detail = "Invalid client_id" )
180167
181168 # Validate redirect URI
182- if redirect_uri not in clients [client_id ]["redirect_uris" ]:
169+ if redirect_uri not in CLIENT_REGISTRY [client_id ]["redirect_uris" ]:
183170 raise HTTPException (status_code = 400 , detail = "Invalid redirect_uri" )
184171
185172 # Validate PKCE if provided
@@ -194,7 +181,7 @@ async def authorize(
194181 authorization_codes [code ] = {
195182 "client_id" : client_id ,
196183 "redirect_uri" : redirect_uri ,
197- "scope" : scope ,
184+ "scope" : " " . join ( sorted ( set (( "openid profile " + scope ). split ( " " )))) ,
198185 }
199186
200187 # Store PKCE challenge if provided
@@ -252,7 +239,7 @@ async def token(
252239 if not client_secret :
253240 raise HTTPException (status_code = 400 , detail = "Client secret required" )
254241
255- if client_secret != clients [client_id ]["client_secret" ]:
242+ if client_secret != CLIENT_REGISTRY [client_id ]["client_secret" ]:
256243 raise HTTPException (status_code = 400 , detail = "Invalid client secret" )
257244
258245 # Clean up the used code and PKCE challenge
@@ -261,21 +248,37 @@ async def token(
261248 del pkce_challenges [code ]
262249
263250 # Generate access token
264- access_token = generate_token ("user123" )
251+ now = datetime .now (UTC )
252+ expires_delta = timedelta (minutes = 15 )
265253
266- response = JSONResponse (
254+ return JSONResponse (
267255 content = {
268- "access_token" : access_token ,
256+ "access_token" : jwt .encode (
257+ {
258+ "iss" : ISSUER ,
259+ "sub" : "user123" ,
260+ "iat" : now ,
261+ "exp" : now + expires_delta ,
262+ "scope" : auth_details ["scope" ],
263+ "kid" : KEY_ID ,
264+ },
265+ KEY_PAIR .private_key ,
266+ algorithm = "RS256" ,
267+ headers = {"kid" : KEY_ID },
268+ ),
269269 "token_type" : "Bearer" ,
270- "expires_in" : 900 , # 15 minutes
270+ "expires_in" : expires_delta . seconds ,
271271 "scope" : auth_details ["scope" ],
272272 }
273273 )
274274
275- return response
276-
277275
278276if __name__ == "__main__" :
279277 import uvicorn
280278
281- uvicorn .run (app , host = "0.0.0.0" , port = 3000 )
279+ uvicorn .run (
280+ "app:app" ,
281+ host = "0.0.0.0" ,
282+ port = int (os .environ .get ("PORT" , 8888 )),
283+ reload = True ,
284+ )
0 commit comments