Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions okta_jwt_verifier/jwt_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,27 @@ async def verify_id_token(self, token, claims_to_verify=('iss', 'exp'), nonce=No
self.verify_signature(token, okta_jwk)

# verify client_id and nonce
self.verify_client_id(claims['aud'])
if self.client_id:
self.verify_client_id(claims['cid'])
if 'nonce' in claims and claims['nonce'] != nonce:
raise JWTValidationException('Claim "nonce" is invalid.')
except JWTValidationException:
raise
except Exception as err:
raise JWTValidationException(str(err))

def verify_client_id(self, aud):
"""Verify client_id match aud or one of its elements."""
if isinstance(aud, str):
if aud != self.client_id:
raise JWTValidationException('Claim "aud" does not match Client ID.')
elif isinstance(aud, list):
for elem in aud:
def verify_client_id(self, cid):
"""Verify client_id match cid or one of its elements."""
if isinstance(cid, str):
if cid != self.client_id:
raise JWTValidationException('Claim "cid" does not match Client ID.')
elif isinstance(cid, list):
for elem in cid:
if elem == self.client_id:
return
raise JWTValidationException('Claim "aud" does not contain Client ID.')
raise JWTValidationException('Claim "cid" does not contain Client ID.')
else:
raise JWTValidationException('Claim "aud" has unsupported format.')
raise JWTValidationException('Claim "cid" has unsupported format.')

def verify_signature(self, token, okta_jwk):
"""Verify token signature using received jwk."""
Expand Down Expand Up @@ -282,6 +283,7 @@ class AccessTokenVerifier():
def __init__(self,
issuer=None,
audience='api://default',
client_id='client_id_stub',
request_executor=RequestExecutor,
max_retries=MAX_RETRIES,
request_timeout=REQUEST_TIMEOUT,
Expand All @@ -301,7 +303,7 @@ def __init__(self,
cache_jwks: bool, optional
"""
self._jwt_verifier = BaseJWTVerifier(issuer=issuer,
client_id='client_id_stub',
client_id=client_id,
audience=audience,
request_executor=request_executor,
max_retries=max_retries,
Expand Down
30 changes: 15 additions & 15 deletions tests/unit/test_jwt_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,30 +117,30 @@ def test_verify_signature(mocker):

def test_verify_client_id():
"""Check if method verify_client_id works correctly."""
# verify when aud is a string
# verify when cid is a string
client_id = 'test_client_id'
aud = client_id
cid = client_id
jwt_verifier = BaseJWTVerifier('https://test_issuer.com', client_id)
jwt_verifier.verify_client_id(aud)
jwt_verifier.verify_client_id(cid)

# verify when aud is an array
aud = ['test_audience', client_id]
jwt_verifier.verify_client_id(aud)
# verify when cid is an array
cid = ['test_cid', client_id]
jwt_verifier.verify_client_id(cid)

# verify exception is raised when aud is a string
# verify exception is raised when cid is a string
with pytest.raises(JWTValidationException):
aud = 'bad_aud'
jwt_verifier.verify_client_id(aud)
cid = 'bad_cid'
jwt_verifier.verify_client_id(cid)

# verify exception is raised when aud is an array
# verify exception is raised when cid is an array
with pytest.raises(JWTValidationException):
aud = ['bad_aud']
jwt_verifier.verify_client_id(aud)
cid = ['bad_cid']
jwt_verifier.verify_client_id(cid)

# verify exception is raised when aud is not a string or array
# verify exception is raised when cid is not a string or array
with pytest.raises(JWTValidationException):
aud = {'aud': 'bad_aud'}
jwt_verifier.verify_client_id(aud)
cid = {'cid': 'bad_cid'}
jwt_verifier.verify_client_id(cid)


def test_verify_claims():
Expand Down