1414from .exceptions import JWTClaimsError
1515from .exceptions import JWTError
1616from .exceptions import ExpiredSignatureError
17- from .utils import timedelta_total_seconds
17+ from .constants import ALGORITHMS
18+ from .utils import timedelta_total_seconds , calculate_at_hash
1819
1920
20- def encode (claims , key , algorithm = None , headers = None ):
21+ def encode (claims , key , algorithm = ALGORITHMS . HS256 , headers = None , access_token = None ):
2122 """Encodes a claims set and returns a JWT string.
2223
2324 JWTs are JWS signed objects with a few reserved claims.
@@ -30,6 +31,9 @@ def encode(claims, key, algorithm=None, headers=None):
3031 headers (dict, optional): A set of headers that will be added to
3132 the default headers. Any headers that are added as additional
3233 headers will override the default headers.
34+ access_token (str, optional): If present, the 'at_hash' claim will
35+ be calculated and added to the claims present in the 'claims'
36+ parameter.
3337
3438 Returns:
3539 str: The string representation of the header, claims, and signature.
@@ -50,13 +54,15 @@ def encode(claims, key, algorithm=None, headers=None):
5054 if isinstance (claims .get (time_claim ), datetime ):
5155 claims [time_claim ] = timegm (claims [time_claim ].utctimetuple ())
5256
53- if algorithm :
54- return jws .sign (claims , key , headers = headers , algorithm = algorithm )
57+ if access_token :
58+ claims ['at_hash' ] = calculate_at_hash (access_token ,
59+ ALGORITHMS .HASHES [algorithm ])
5560
56- return jws .sign (claims , key , headers = headers )
61+ return jws .sign (claims , key , headers = headers , algorithm = algorithm )
5762
5863
59- def decode (token , key , algorithms = None , options = None , audience = None , issuer = None , subject = None ):
64+ def decode (token , key , algorithms = None , options = None , audience = None ,
65+ issuer = None , subject = None , access_token = None ):
6066 """Verifies a JWT string's signature and validates reserved claims.
6167
6268 Args:
@@ -72,6 +78,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
7278 subject (str): The subject of the token. If the "sub" claim is
7379 included in the claim set, then the subject must be included and must equal
7480 the provided claim.
81+ access_token (str): An access token returned alongside the id_token during
82+ the authorization grant flow. If the "at_hash" claim is included in the
83+ claim set, then the access_token must be included, and it must match
84+ the "at_hash" claim.
7585 options (dict): A dictionary of options for skipping validation steps.
7686
7787 defaults = {
@@ -109,6 +119,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
109119 'verify_iss' : True ,
110120 'verify_sub' : True ,
111121 'verify_jti' : True ,
122+ 'verify_at_hash' : True ,
112123 'leeway' : 0 ,
113124 }
114125
@@ -122,6 +133,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
122133 except JWSError as e :
123134 raise JWTError (e )
124135
136+ # Needed for at_hash verification
137+ algorithm = jws .get_unverified_header (token )['alg' ]
138+
125139 try :
126140 claims = json .loads (payload .decode ('utf-8' ))
127141 except ValueError as e :
@@ -130,7 +144,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
130144 if not isinstance (claims , Mapping ):
131145 raise JWTError ('Invalid payload string: must be a json object' )
132146
133- _validate_claims (claims , audience = audience , issuer = issuer , subject = subject , options = defaults )
147+ _validate_claims (claims , audience = audience , issuer = issuer ,
148+ subject = subject , algorithm = algorithm ,
149+ access_token = access_token ,
150+ options = defaults )
134151
135152 return claims
136153
@@ -384,7 +401,40 @@ def _validate_jti(claims):
384401 raise JWTClaimsError ('JWT ID must be a string.' )
385402
386403
387- def _validate_claims (claims , audience = None , issuer = None , subject = None , options = None ):
404+ def _validate_at_hash (claims , access_token , algorithm ):
405+ """
406+ Validates that the 'at_hash' parameter included in the claims matches
407+ with the access_token returned alongside the id token as part of
408+ the authorization_code flow.
409+
410+ Args:
411+ claims (dict): The claims dictionary to validate.
412+ access_token (str): The access token returned by the OpenID Provider.
413+ algorithm (str): The algorithm used to sign the JWT, as specified by
414+ the token headers.
415+ """
416+ if 'at_hash' not in claims and not access_token :
417+ return
418+ elif 'at_hash' in claims and not access_token :
419+ msg = 'No access_token provided to compare against at_hash claim.'
420+ raise JWTClaimsError (msg )
421+ elif access_token and 'at_hash' not in claims :
422+ msg = 'at_hash claim missing from token.'
423+ raise JWTClaimsError (msg )
424+
425+ try :
426+ expected_hash = calculate_at_hash (access_token ,
427+ ALGORITHMS .HASHES [algorithm ])
428+ except (TypeError , ValueError ):
429+ msg = 'Unable to calculate at_hash to verify against token claims.'
430+ raise JWTClaimsError (msg )
431+
432+ if claims ['at_hash' ] != expected_hash :
433+ raise JWTClaimsError ('at_hash claim does not match access_token.' )
434+
435+
436+ def _validate_claims (claims , audience = None , issuer = None , subject = None ,
437+ algorithm = None , access_token = None , options = None ):
388438
389439 leeway = options .get ('leeway' , 0 )
390440
@@ -414,3 +464,6 @@ def _validate_claims(claims, audience=None, issuer=None, subject=None, options=N
414464
415465 if options .get ('verify_jti' ):
416466 _validate_jti (claims )
467+
468+ if options .get ('verify_at_hash' ):
469+ _validate_at_hash (claims , access_token , algorithm )
0 commit comments