Skip to content

Commit 5d4b479

Browse files
authored
Use PyJWT+cryptography instead of jwkest+Cryptodrome (juanifioren#446)
- Switches id_token signing and verification to use PyJWT with cryptography - Removes jwkest and Cryptodome dependencies - `future` dependency previously required by jwkest is no longer needed (it had unfixed security vulnerabilities and seems unmaintained) - adds caching of RSA keys to avoid repeated expensive key loading operations
1 parent 971a7cd commit 5d4b479

File tree

6 files changed

+456
-60
lines changed

6 files changed

+456
-60
lines changed

oidc_provider/lib/utils/token.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,10 @@
22
import uuid
33
from datetime import timedelta
44

5-
from Cryptodome.PublicKey.RSA import importKey
5+
import jwt
6+
from cryptography.hazmat.primitives import serialization
67
from django.utils import dateformat
78
from django.utils import timezone
8-
from jwkest.jwk import RSAKey as jwk_RSAKey
9-
from jwkest.jwk import SYMKey
10-
from jwkest.jws import JWS
11-
from jwkest.jwt import JWT
129

1310
from oidc_provider import settings
1411
from oidc_provider.lib.claims import StandardScopeClaims
@@ -18,6 +15,10 @@
1815
from oidc_provider.models import RSAKey
1916
from oidc_provider.models import Token
2017

18+
# Cache for loaded RSA keys to avoid repeated PEM parsing
19+
# Cache is automatically cleaned of stale entries (keys no longer in DB)
20+
_rsa_key_cache = {}
21+
2122

2223
def create_id_token(token, user, aud, nonce="", at_hash="", request=None, scope=None):
2324
"""
@@ -72,28 +73,56 @@ def create_id_token(token, user, aud, nonce="", at_hash="", request=None, scope=
7273
def encode_id_token(payload, client):
7374
"""
7475
Represent the ID Token as a JSON Web Token (JWT).
75-
Return a hash.
76+
Returns a dict.
7677
"""
7778
keys = get_client_alg_keys(client)
78-
_jws = JWS(payload, alg=client.jwt_alg)
79-
return _jws.sign_compact(keys)
79+
# Use the first key for encoding
80+
# TODO: make key selection more explicit
81+
key_info = keys[0]
82+
83+
headers = {}
84+
if "kid" in key_info:
85+
headers["kid"] = key_info["kid"]
86+
87+
return jwt.encode(payload, key_info["key"], algorithm=key_info["algorithm"], headers=headers)
8088

8189

8290
def decode_id_token(token, client):
8391
"""
8492
Represent the ID Token as a JSON Web Token (JWT).
85-
Return a hash.
93+
Returns a dict.
8694
"""
87-
keys = get_client_alg_keys(client)
88-
return JWS().verify_compact(token, keys=keys)
95+
# Try decoding with each available key
96+
for key in get_client_alg_keys(client):
97+
try:
98+
return jwt.decode(
99+
jwt=token,
100+
# HS256 uses the same key for signing and verifying
101+
key=key["key"] if key["algorithm"] == "HS256" else key["public_key"],
102+
algorithms=[key["algorithm"]],
103+
options={
104+
"verify_signature": True,
105+
"verify_aud": False, # Disable audience validation for compatibility
106+
"verify_exp": False, # Disable expiration validation for compatibility
107+
"verify_iat": False, # Disable issued at validation for compatibility
108+
"verify_nbf": False, # Disable not before validation for compatibility
109+
},
110+
)
111+
except jwt.InvalidTokenError:
112+
continue
113+
114+
# If we get here, none of the keys worked
115+
raise jwt.InvalidTokenError("Token could not be decoded with any available key")
89116

90117

91118
def client_id_from_id_token(id_token):
92119
"""
93120
Extracts the client id from a JSON Web Token (JWT).
121+
Does NOT verify the token signature or expiration.
94122
Returns a string or None.
95123
"""
96-
payload = JWT().unpack(id_token).payload()
124+
# Decode without verification to get the payload
125+
payload = jwt.decode(id_token, options={"verify_signature": False})
97126
aud = payload.get("aud", None)
98127
if aud is None:
99128
return None
@@ -150,16 +179,47 @@ def create_code(
150179
def get_client_alg_keys(client):
151180
"""
152181
Takes a client and returns the set of keys associated with it.
153-
Returns a list of keys.
182+
Returns a list of keys compatible with PyJWT.
154183
"""
155184
if client.jwt_alg == "RS256":
156185
keys = []
186+
current_kids = set()
187+
157188
for rsakey in RSAKey.objects.all():
158-
keys.append(jwk_RSAKey(key=importKey(rsakey.key), kid=rsakey.kid))
189+
cache_key = f"rsa_key_{rsakey.kid}"
190+
current_kids.add(cache_key)
191+
192+
if cache_key not in _rsa_key_cache:
193+
# Load the RSA private key using cryptography (expensive operation)
194+
private_key = serialization.load_pem_private_key(
195+
rsakey.key.encode("utf-8"),
196+
password=None,
197+
)
198+
# Also cache the public key to avoid repeated .public_key() calls
199+
public_key = private_key.public_key()
200+
_rsa_key_cache[cache_key] = {"private_key": private_key, "public_key": public_key}
201+
202+
key_pair = _rsa_key_cache[cache_key]
203+
keys.append(
204+
{
205+
"key": key_pair["private_key"],
206+
"public_key": key_pair["public_key"],
207+
"kid": rsakey.kid,
208+
"algorithm": "RS256",
209+
}
210+
)
211+
212+
# Clean up stale cache entries (keys that no longer exist in DB)
213+
stale_keys = set(_rsa_key_cache.keys()) - current_kids
214+
for stale_key in stale_keys:
215+
del _rsa_key_cache[stale_key]
216+
159217
if not keys:
160218
raise Exception("You must add at least one RSA Key.")
161219
elif client.jwt_alg == "HS256":
162-
keys = [SYMKey(key=client.client_secret, alg=client.jwt_alg)]
220+
# NOTE: HS256 does not have any expensive key parsing, so we don't need the
221+
# same key caching as RS256.
222+
keys = [{"key": client.client_secret, "algorithm": "HS256"}]
163223
else:
164224
raise Exception("Unsupported key algorithm.")
165225

oidc_provider/management/commands/creatersakey.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from Cryptodome.PublicKey import RSA
1+
from cryptography.hazmat.primitives import serialization
2+
from cryptography.hazmat.primitives.asymmetric import rsa
23
from django.core.management.base import BaseCommand
34

45
from oidc_provider.models import RSAKey
@@ -9,8 +10,20 @@ class Command(BaseCommand):
910

1011
def handle(self, *args, **options):
1112
try:
12-
key = RSA.generate(2048)
13-
rsakey = RSAKey(key=key.exportKey("PEM").decode("utf8"))
13+
# Generate a new RSA private key with 2048 bits
14+
private_key = rsa.generate_private_key(
15+
public_exponent=65537,
16+
key_size=2048,
17+
)
18+
19+
# Serialize the private key to PEM format
20+
key_pem = private_key.private_bytes(
21+
encoding=serialization.Encoding.PEM,
22+
format=serialization.PrivateFormat.PKCS8,
23+
encryption_algorithm=serialization.NoEncryption(),
24+
).decode("utf-8")
25+
26+
rsakey = RSAKey(key=key_pem)
1427
rsakey.save()
1528
self.stdout.write("RSA key successfully created with kid: {0}".format(rsakey.kid))
1629
except Exception as e:

oidc_provider/tests/cases/test_authorize_endpoint.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
from django.urls import reverse
2323
except ImportError:
2424
from django.core.urlresolvers import reverse
25+
import jwt
2526
from django.contrib.auth.models import AnonymousUser
2627
from django.core.management import call_command
2728
from django.test import RequestFactory
2829
from django.test import TestCase
2930
from django.test import override_settings
30-
from jwkest.jwt import JWT
3131

3232
from oidc_provider import settings
3333
from oidc_provider.lib.endpoints.authorize import AuthorizeEndpoint
@@ -724,7 +724,7 @@ def test_idtoken_token_at_hash(self):
724724
# obtain `id_token` portion of Location
725725
components = urlsplit(response["Location"])
726726
fragment = parse_qs(components[4])
727-
id_token = JWT().unpack(fragment["id_token"][0].encode("utf-8")).payload()
727+
id_token = jwt.decode(fragment["id_token"][0], options={"verify_signature": False})
728728

729729
self.assertIn("at_hash", id_token)
730730

@@ -750,7 +750,7 @@ def test_idtoken_at_hash(self):
750750
# obtain `id_token` portion of Location
751751
components = urlsplit(response["Location"])
752752
fragment = parse_qs(components[4])
753-
id_token = JWT().unpack(fragment["id_token"][0].encode("utf-8")).payload()
753+
id_token = jwt.decode(fragment["id_token"][0], options={"verify_signature": False})
754754

755755
self.assertNotIn("at_hash", id_token)
756756

0 commit comments

Comments
 (0)