Skip to content

Commit cc402f8

Browse files
author
Michael Davis
authored
Merge pull request #30 from bjmc/at_hash
Adds support for at_hash verification
2 parents b7d3871 + 95fb84a commit cc402f8

File tree

4 files changed

+123
-9
lines changed

4 files changed

+123
-9
lines changed

jose/constants.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
import hashlib
22

33
class ALGORITHMS(object):
44
NONE = 'none'
@@ -19,3 +19,15 @@ class ALGORITHMS(object):
1919
SUPPORTED = HMAC + RSA + EC
2020

2121
ALL = SUPPORTED + (NONE, )
22+
23+
HASHES = {
24+
HS256: hashlib.sha256,
25+
HS384: hashlib.sha384,
26+
HS512: hashlib.sha512,
27+
RS256: hashlib.sha256,
28+
RS384: hashlib.sha384,
29+
RS512: hashlib.sha512,
30+
ES256: hashlib.sha256,
31+
ES384: hashlib.sha384,
32+
ES512: hashlib.sha512,
33+
}

jose/jwt.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@
1414
from .exceptions import JWTClaimsError
1515
from .exceptions import JWTError
1616
from .exceptions import ExpiredSignatureError
17-
from .utils import timedelta_total_seconds
17+
from .constants import ALGORITHMS
18+
from .utils import timedelta_total_seconds, calculate_at_hash
1819

1920

20-
def encode(claims, key, algorithm=None, headers=None):
21+
def encode(claims, key, algorithm=ALGORITHMS.HS256, headers=None, access_token=None):
2122
"""Encodes a claims set and returns a JWT string.
2223
2324
JWTs are JWS signed objects with a few reserved claims.
@@ -30,6 +31,9 @@ def encode(claims, key, algorithm=None, headers=None):
3031
headers (dict, optional): A set of headers that will be added to
3132
the default headers. Any headers that are added as additional
3233
headers will override the default headers.
34+
access_token (str, optional): If present, the 'at_hash' claim will
35+
be calculated and added to the claims present in the 'claims'
36+
parameter.
3337
3438
Returns:
3539
str: The string representation of the header, claims, and signature.
@@ -50,13 +54,15 @@ def encode(claims, key, algorithm=None, headers=None):
5054
if isinstance(claims.get(time_claim), datetime):
5155
claims[time_claim] = timegm(claims[time_claim].utctimetuple())
5256

53-
if algorithm:
54-
return jws.sign(claims, key, headers=headers, algorithm=algorithm)
57+
if access_token:
58+
claims['at_hash'] = calculate_at_hash(access_token,
59+
ALGORITHMS.HASHES[algorithm])
5560

56-
return jws.sign(claims, key, headers=headers)
61+
return jws.sign(claims, key, headers=headers, algorithm=algorithm)
5762

5863

59-
def decode(token, key, algorithms=None, options=None, audience=None, issuer=None, subject=None):
64+
def decode(token, key, algorithms=None, options=None, audience=None,
65+
issuer=None, subject=None, access_token=None):
6066
"""Verifies a JWT string's signature and validates reserved claims.
6167
6268
Args:
@@ -72,6 +78,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
7278
subject (str): The subject of the token. If the "sub" claim is
7379
included in the claim set, then the subject must be included and must equal
7480
the provided claim.
81+
access_token (str): An access token returned alongside the id_token during
82+
the authorization grant flow. If the "at_hash" claim is included in the
83+
claim set, then the access_token must be included, and it must match
84+
the "at_hash" claim.
7585
options (dict): A dictionary of options for skipping validation steps.
7686
7787
defaults = {
@@ -109,6 +119,7 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
109119
'verify_iss': True,
110120
'verify_sub': True,
111121
'verify_jti': True,
122+
'verify_at_hash': True,
112123
'leeway': 0,
113124
}
114125

@@ -122,6 +133,9 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
122133
except JWSError as e:
123134
raise JWTError(e)
124135

136+
# Needed for at_hash verification
137+
algorithm = jws.get_unverified_header(token)['alg']
138+
125139
try:
126140
claims = json.loads(payload.decode('utf-8'))
127141
except ValueError as e:
@@ -130,7 +144,10 @@ def decode(token, key, algorithms=None, options=None, audience=None, issuer=None
130144
if not isinstance(claims, Mapping):
131145
raise JWTError('Invalid payload string: must be a json object')
132146

133-
_validate_claims(claims, audience=audience, issuer=issuer, subject=subject, options=defaults)
147+
_validate_claims(claims, audience=audience, issuer=issuer,
148+
subject=subject, algorithm=algorithm,
149+
access_token=access_token,
150+
options=defaults)
134151

135152
return claims
136153

@@ -384,7 +401,40 @@ def _validate_jti(claims):
384401
raise JWTClaimsError('JWT ID must be a string.')
385402

386403

387-
def _validate_claims(claims, audience=None, issuer=None, subject=None, options=None):
404+
def _validate_at_hash(claims, access_token, algorithm):
405+
"""
406+
Validates that the 'at_hash' parameter included in the claims matches
407+
with the access_token returned alongside the id token as part of
408+
the authorization_code flow.
409+
410+
Args:
411+
claims (dict): The claims dictionary to validate.
412+
access_token (str): The access token returned by the OpenID Provider.
413+
algorithm (str): The algorithm used to sign the JWT, as specified by
414+
the token headers.
415+
"""
416+
if 'at_hash' not in claims and not access_token:
417+
return
418+
elif 'at_hash' in claims and not access_token:
419+
msg = 'No access_token provided to compare against at_hash claim.'
420+
raise JWTClaimsError(msg)
421+
elif access_token and 'at_hash' not in claims:
422+
msg = 'at_hash claim missing from token.'
423+
raise JWTClaimsError(msg)
424+
425+
try:
426+
expected_hash = calculate_at_hash(access_token,
427+
ALGORITHMS.HASHES[algorithm])
428+
except (TypeError, ValueError):
429+
msg = 'Unable to calculate at_hash to verify against token claims.'
430+
raise JWTClaimsError(msg)
431+
432+
if claims['at_hash'] != expected_hash:
433+
raise JWTClaimsError('at_hash claim does not match access_token.')
434+
435+
436+
def _validate_claims(claims, audience=None, issuer=None, subject=None,
437+
algorithm=None, access_token=None, options=None):
388438

389439
leeway = options.get('leeway', 0)
390440

@@ -414,3 +464,6 @@ def _validate_claims(claims, audience=None, issuer=None, subject=None, options=N
414464

415465
if options.get('verify_jti'):
416466
_validate_jti(claims)
467+
468+
if options.get('verify_at_hash'):
469+
_validate_at_hash(claims, access_token, algorithm)

jose/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,29 @@
22
import base64
33

44

5+
def calculate_at_hash(access_token, hash_alg):
6+
"""Helper method for calculating an access token
7+
hash, as described in http://openid.net/specs/openid-connect-core-1_0.html#CodeIDToken
8+
9+
Its value is the base64url encoding of the left-most half of the hash of the octets
10+
of the ASCII representation of the access_token value, where the hash algorithm
11+
used is the hash algorithm used in the alg Header Parameter of the ID Token's JOSE
12+
Header. For instance, if the alg is RS256, hash the access_token value with SHA-256,
13+
then take the left-most 128 bits and base64url encode them. The at_hash value is a
14+
case sensitive string.
15+
16+
Args:
17+
access_token (str): An access token string.
18+
hash_alg (callable): A callable returning a hash object, e.g. hashlib.sha256
19+
20+
"""
21+
hash_digest = hash_alg(access_token.encode('utf-8')).digest()
22+
cut_at = int(len(hash_digest) / 2)
23+
truncated = hash_digest[:cut_at]
24+
at_hash = base64url_encode(truncated)
25+
return at_hash.decode('utf-8')
26+
27+
528
def base64url_decode(input):
629
"""Helper method to base64url_decode a string.
730

tests/test_jwt.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,32 @@ def test_jti_invalid(self, key):
428428
with pytest.raises(JWTError):
429429
jwt.decode(token, key)
430430

431+
def test_at_hash(self, claims, key):
432+
access_token = '<ACCESS_TOKEN>'
433+
token = jwt.encode(claims, key, access_token=access_token)
434+
payload = jwt.decode(token, key, access_token=access_token)
435+
assert 'at_hash' in payload
436+
437+
def test_at_hash_invalid(self, claims, key):
438+
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
439+
with pytest.raises(JWTError):
440+
jwt.decode(token, key, access_token='<OTHER_TOKEN>')
441+
442+
def test_at_hash_missing_access_token(self, claims, key):
443+
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
444+
with pytest.raises(JWTError):
445+
jwt.decode(token, key)
446+
447+
def test_at_hash_missing_claim(self, claims, key):
448+
token = jwt.encode(claims, key)
449+
with pytest.raises(JWTError):
450+
jwt.decode(token, key, access_token='<ACCESS_TOKEN>')
451+
452+
def test_at_hash_unable_to_calculate(self, claims, key):
453+
token = jwt.encode(claims, key, access_token='<ACCESS_TOKEN>')
454+
with pytest.raises(JWTError):
455+
jwt.decode(token, key, access_token='\xe2')
456+
431457
def test_unverified_claims_string(self):
432458
token = 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.aW52YWxpZCBjbGFpbQ.iOJ5SiNfaNO_pa2J4Umtb3b3zmk5C18-mhTCVNsjnck'
433459
with pytest.raises(JWTError):

0 commit comments

Comments
 (0)