Skip to content

Commit 094a5c9

Browse files
committed
Store the key_id in the token and retrieve public keys using that key_id
1 parent 89d5b99 commit 094a5c9

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

rest_framework_sso/claims.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
# coding: utf-8
22
from __future__ import absolute_import, unicode_literals
33

4+
TYPE = 'typ'
5+
ALGORITHM = 'alg'
6+
KEY_ID = 'kid'
7+
48
ISSUER = 'iss'
59
SUBJECT = 'sub'
610
AUDIENCE = 'aud'

rest_framework_sso/utils.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from django.core.serializers.json import DjangoJSONEncoder
99
from django.utils import six
1010
from django.utils.translation import gettext_lazy as _
11-
from jwt.exceptions import MissingRequiredClaimError, InvalidIssuerError, InvalidTokenError
11+
from jwt.exceptions import MissingRequiredClaimError, InvalidIssuerError, InvalidTokenError, InvalidKeyError
1212
from rest_framework import exceptions
1313

1414
from rest_framework_sso import claims
@@ -66,30 +66,39 @@ def encode_jwt_token(payload):
6666
if payload[claims.ISSUER] not in api_settings.PRIVATE_KEYS:
6767
raise RuntimeError('Private key for specified issuer was not found in settings')
6868

69-
private_key = open(api_settings.PRIVATE_KEYS.get(payload[claims.ISSUER]), 'rt').read()
69+
private_key, key_id = get_private_key_and_key_id(issuer=payload[claims.ISSUER])
70+
71+
headers = {
72+
claims.KEY_ID: key_id,
73+
}
7074

7175
return jwt.encode(
7276
payload=payload,
7377
key=private_key,
7478
algorithm=api_settings.ENCODE_ALGORITHM,
79+
headers=headers,
7580
json_encoder=DjangoJSONEncoder,
7681
).decode('utf-8')
7782

7883

7984
def decode_jwt_token(token):
85+
unverified_header = jwt.get_unverified_header(token)
8086
unverified_claims = jwt.decode(token, verify=False)
8187

88+
if unverified_header.get(claims.KEY_ID):
89+
unverified_key_id = six.text_type(unverified_header.get(claims.KEY_ID))
90+
else:
91+
unverified_key_id = None
92+
8293
if claims.ISSUER not in unverified_claims:
8394
raise MissingRequiredClaimError(claims.ISSUER)
8495

8596
unverified_issuer = six.text_type(unverified_claims[claims.ISSUER])
8697

8798
if api_settings.ACCEPTED_ISSUERS is not None and unverified_issuer not in api_settings.ACCEPTED_ISSUERS:
8899
raise InvalidIssuerError('Invalid issuer')
89-
if unverified_issuer not in api_settings.PUBLIC_KEYS:
90-
raise InvalidIssuerError('Invalid issuer')
91100

92-
public_key = open(api_settings.PUBLIC_KEYS.get(unverified_issuer), 'rt').read()
101+
public_key = get_public_key(issuer=unverified_issuer, key_id=unverified_key_id)
93102

94103
options = {
95104
'verify_exp': api_settings.VERIFY_EXPIRATION,
@@ -116,6 +125,40 @@ def decode_jwt_token(token):
116125
return payload
117126

118127

128+
def get_private_key_and_key_id(issuer, key_id=None):
129+
if not api_settings.PRIVATE_KEYS.get(issuer):
130+
raise InvalidKeyError('No private keys defined for the given issuer')
131+
private_keys_setting = api_settings.PRIVATE_KEYS.get(issuer)
132+
if isinstance(private_keys_setting, (str, six.text_type)):
133+
private_keys_setting = [private_keys_setting]
134+
for pks in private_keys_setting:
135+
if not key_id or key_id == pks:
136+
return open(pks, 'rt').read(), pks
137+
raise InvalidKeyError('No private key matches the given key_id')
138+
139+
140+
def get_private_key(issuer, key_id=None):
141+
private_key, key_id = get_private_key_and_key_id(issuer=issuer, key_id=key_id)
142+
return private_key
143+
144+
145+
def get_public_key_and_key_id(issuer, key_id=None):
146+
if not api_settings.PUBLIC_KEYS.get(issuer):
147+
raise InvalidKeyError('No public keys defined for the given issuer')
148+
public_keys_setting = api_settings.PUBLIC_KEYS.get(issuer)
149+
if isinstance(public_keys_setting, (str, six.text_type)):
150+
public_keys_setting = [public_keys_setting]
151+
for pks in public_keys_setting:
152+
if not key_id or key_id == pks:
153+
return open(pks, 'rt').read(), pks
154+
raise InvalidKeyError('No public key matches the given key_id')
155+
156+
157+
def get_public_key(issuer, key_id=None):
158+
public_key, key_id = get_public_key_and_key_id(issuer=issuer, key_id=key_id)
159+
return public_key
160+
161+
119162
def authenticate_payload(payload):
120163
from rest_framework_sso.models import SessionToken
121164

0 commit comments

Comments
 (0)