88from django .core .serializers .json import DjangoJSONEncoder
99from django .utils import six
1010from django .utils .translation import gettext_lazy as _
11- from jwt .exceptions import MissingRequiredClaimError , InvalidIssuerError , InvalidTokenError
11+ from jwt .exceptions import MissingRequiredClaimError , InvalidIssuerError , InvalidTokenError , InvalidKeyError
1212from rest_framework import exceptions
1313
1414from rest_framework_sso import claims
@@ -66,30 +66,39 @@ def encode_jwt_token(payload):
6666 if payload [claims .ISSUER ] not in api_settings .PRIVATE_KEYS :
6767 raise RuntimeError ('Private key for specified issuer was not found in settings' )
6868
69- private_key = open (api_settings .PRIVATE_KEYS .get (payload [claims .ISSUER ]), 'rt' ).read ()
69+ private_key , key_id = get_private_key_and_key_id (issuer = payload [claims .ISSUER ])
70+
71+ headers = {
72+ claims .KEY_ID : key_id ,
73+ }
7074
7175 return jwt .encode (
7276 payload = payload ,
7377 key = private_key ,
7478 algorithm = api_settings .ENCODE_ALGORITHM ,
79+ headers = headers ,
7580 json_encoder = DjangoJSONEncoder ,
7681 ).decode ('utf-8' )
7782
7883
7984def decode_jwt_token (token ):
85+ unverified_header = jwt .get_unverified_header (token )
8086 unverified_claims = jwt .decode (token , verify = False )
8187
88+ if unverified_header .get (claims .KEY_ID ):
89+ unverified_key_id = six .text_type (unverified_header .get (claims .KEY_ID ))
90+ else :
91+ unverified_key_id = None
92+
8293 if claims .ISSUER not in unverified_claims :
8394 raise MissingRequiredClaimError (claims .ISSUER )
8495
8596 unverified_issuer = six .text_type (unverified_claims [claims .ISSUER ])
8697
8798 if api_settings .ACCEPTED_ISSUERS is not None and unverified_issuer not in api_settings .ACCEPTED_ISSUERS :
8899 raise InvalidIssuerError ('Invalid issuer' )
89- if unverified_issuer not in api_settings .PUBLIC_KEYS :
90- raise InvalidIssuerError ('Invalid issuer' )
91100
92- public_key = open ( api_settings . PUBLIC_KEYS . get ( unverified_issuer ), 'rt' ). read ( )
101+ public_key = get_public_key ( issuer = unverified_issuer , key_id = unverified_key_id )
93102
94103 options = {
95104 'verify_exp' : api_settings .VERIFY_EXPIRATION ,
@@ -116,6 +125,40 @@ def decode_jwt_token(token):
116125 return payload
117126
118127
128+ def get_private_key_and_key_id (issuer , key_id = None ):
129+ if not api_settings .PRIVATE_KEYS .get (issuer ):
130+ raise InvalidKeyError ('No private keys defined for the given issuer' )
131+ private_keys_setting = api_settings .PRIVATE_KEYS .get (issuer )
132+ if isinstance (private_keys_setting , (str , six .text_type )):
133+ private_keys_setting = [private_keys_setting ]
134+ for pks in private_keys_setting :
135+ if not key_id or key_id == pks :
136+ return open (pks , 'rt' ).read (), pks
137+ raise InvalidKeyError ('No private key matches the given key_id' )
138+
139+
140+ def get_private_key (issuer , key_id = None ):
141+ private_key , key_id = get_private_key_and_key_id (issuer = issuer , key_id = key_id )
142+ return private_key
143+
144+
145+ def get_public_key_and_key_id (issuer , key_id = None ):
146+ if not api_settings .PUBLIC_KEYS .get (issuer ):
147+ raise InvalidKeyError ('No public keys defined for the given issuer' )
148+ public_keys_setting = api_settings .PUBLIC_KEYS .get (issuer )
149+ if isinstance (public_keys_setting , (str , six .text_type )):
150+ public_keys_setting = [public_keys_setting ]
151+ for pks in public_keys_setting :
152+ if not key_id or key_id == pks :
153+ return open (pks , 'rt' ).read (), pks
154+ raise InvalidKeyError ('No public key matches the given key_id' )
155+
156+
157+ def get_public_key (issuer , key_id = None ):
158+ public_key , key_id = get_public_key_and_key_id (issuer = issuer , key_id = key_id )
159+ return public_key
160+
161+
119162def authenticate_payload (payload ):
120163 from rest_framework_sso .models import SessionToken
121164
0 commit comments