|
2 | 2 | import uuid
|
3 | 3 | from datetime import timedelta
|
4 | 4 |
|
5 |
| -from Cryptodome.PublicKey.RSA import importKey |
| 5 | +import jwt |
| 6 | +from cryptography.hazmat.primitives import serialization |
6 | 7 | from django.utils import dateformat
|
7 | 8 | from django.utils import timezone
|
8 |
| -from jwkest.jwk import RSAKey as jwk_RSAKey |
9 |
| -from jwkest.jwk import SYMKey |
10 |
| -from jwkest.jws import JWS |
11 |
| -from jwkest.jwt import JWT |
12 | 9 |
|
13 | 10 | from oidc_provider import settings
|
14 | 11 | from oidc_provider.lib.claims import StandardScopeClaims
|
|
18 | 15 | from oidc_provider.models import RSAKey
|
19 | 16 | from oidc_provider.models import Token
|
20 | 17 |
|
| 18 | +# Cache for loaded RSA keys to avoid repeated PEM parsing |
| 19 | +# Cache is automatically cleaned of stale entries (keys no longer in DB) |
| 20 | +_rsa_key_cache = {} |
| 21 | + |
21 | 22 |
|
22 | 23 | def create_id_token(token, user, aud, nonce="", at_hash="", request=None, scope=None):
|
23 | 24 | """
|
@@ -72,28 +73,56 @@ def create_id_token(token, user, aud, nonce="", at_hash="", request=None, scope=
|
72 | 73 | def encode_id_token(payload, client):
|
73 | 74 | """
|
74 | 75 | Represent the ID Token as a JSON Web Token (JWT).
|
75 |
| - Return a hash. |
| 76 | + Returns a dict. |
76 | 77 | """
|
77 | 78 | keys = get_client_alg_keys(client)
|
78 |
| - _jws = JWS(payload, alg=client.jwt_alg) |
79 |
| - return _jws.sign_compact(keys) |
| 79 | + # Use the first key for encoding |
| 80 | + # TODO: make key selection more explicit |
| 81 | + key_info = keys[0] |
| 82 | + |
| 83 | + headers = {} |
| 84 | + if "kid" in key_info: |
| 85 | + headers["kid"] = key_info["kid"] |
| 86 | + |
| 87 | + return jwt.encode(payload, key_info["key"], algorithm=key_info["algorithm"], headers=headers) |
80 | 88 |
|
81 | 89 |
|
82 | 90 | def decode_id_token(token, client):
|
83 | 91 | """
|
84 | 92 | Represent the ID Token as a JSON Web Token (JWT).
|
85 |
| - Return a hash. |
| 93 | + Returns a dict. |
86 | 94 | """
|
87 |
| - keys = get_client_alg_keys(client) |
88 |
| - return JWS().verify_compact(token, keys=keys) |
| 95 | + # Try decoding with each available key |
| 96 | + for key in get_client_alg_keys(client): |
| 97 | + try: |
| 98 | + return jwt.decode( |
| 99 | + jwt=token, |
| 100 | + # HS256 uses the same key for signing and verifying |
| 101 | + key=key["key"] if key["algorithm"] == "HS256" else key["public_key"], |
| 102 | + algorithms=[key["algorithm"]], |
| 103 | + options={ |
| 104 | + "verify_signature": True, |
| 105 | + "verify_aud": False, # Disable audience validation for compatibility |
| 106 | + "verify_exp": False, # Disable expiration validation for compatibility |
| 107 | + "verify_iat": False, # Disable issued at validation for compatibility |
| 108 | + "verify_nbf": False, # Disable not before validation for compatibility |
| 109 | + }, |
| 110 | + ) |
| 111 | + except jwt.InvalidTokenError: |
| 112 | + continue |
| 113 | + |
| 114 | + # If we get here, none of the keys worked |
| 115 | + raise jwt.InvalidTokenError("Token could not be decoded with any available key") |
89 | 116 |
|
90 | 117 |
|
91 | 118 | def client_id_from_id_token(id_token):
|
92 | 119 | """
|
93 | 120 | Extracts the client id from a JSON Web Token (JWT).
|
| 121 | + Does NOT verify the token signature or expiration. |
94 | 122 | Returns a string or None.
|
95 | 123 | """
|
96 |
| - payload = JWT().unpack(id_token).payload() |
| 124 | + # Decode without verification to get the payload |
| 125 | + payload = jwt.decode(id_token, options={"verify_signature": False}) |
97 | 126 | aud = payload.get("aud", None)
|
98 | 127 | if aud is None:
|
99 | 128 | return None
|
@@ -150,16 +179,47 @@ def create_code(
|
150 | 179 | def get_client_alg_keys(client):
|
151 | 180 | """
|
152 | 181 | Takes a client and returns the set of keys associated with it.
|
153 |
| - Returns a list of keys. |
| 182 | + Returns a list of keys compatible with PyJWT. |
154 | 183 | """
|
155 | 184 | if client.jwt_alg == "RS256":
|
156 | 185 | keys = []
|
| 186 | + current_kids = set() |
| 187 | + |
157 | 188 | for rsakey in RSAKey.objects.all():
|
158 |
| - keys.append(jwk_RSAKey(key=importKey(rsakey.key), kid=rsakey.kid)) |
| 189 | + cache_key = f"rsa_key_{rsakey.kid}" |
| 190 | + current_kids.add(cache_key) |
| 191 | + |
| 192 | + if cache_key not in _rsa_key_cache: |
| 193 | + # Load the RSA private key using cryptography (expensive operation) |
| 194 | + private_key = serialization.load_pem_private_key( |
| 195 | + rsakey.key.encode("utf-8"), |
| 196 | + password=None, |
| 197 | + ) |
| 198 | + # Also cache the public key to avoid repeated .public_key() calls |
| 199 | + public_key = private_key.public_key() |
| 200 | + _rsa_key_cache[cache_key] = {"private_key": private_key, "public_key": public_key} |
| 201 | + |
| 202 | + key_pair = _rsa_key_cache[cache_key] |
| 203 | + keys.append( |
| 204 | + { |
| 205 | + "key": key_pair["private_key"], |
| 206 | + "public_key": key_pair["public_key"], |
| 207 | + "kid": rsakey.kid, |
| 208 | + "algorithm": "RS256", |
| 209 | + } |
| 210 | + ) |
| 211 | + |
| 212 | + # Clean up stale cache entries (keys that no longer exist in DB) |
| 213 | + stale_keys = set(_rsa_key_cache.keys()) - current_kids |
| 214 | + for stale_key in stale_keys: |
| 215 | + del _rsa_key_cache[stale_key] |
| 216 | + |
159 | 217 | if not keys:
|
160 | 218 | raise Exception("You must add at least one RSA Key.")
|
161 | 219 | elif client.jwt_alg == "HS256":
|
162 |
| - keys = [SYMKey(key=client.client_secret, alg=client.jwt_alg)] |
| 220 | + # NOTE: HS256 does not have any expensive key parsing, so we don't need the |
| 221 | + # same key caching as RS256. |
| 222 | + keys = [{"key": client.client_secret, "algorithm": "HS256"}] |
163 | 223 | else:
|
164 | 224 | raise Exception("Unsupported key algorithm.")
|
165 | 225 |
|
|
0 commit comments