Skip to content

Commit b65752d

Browse files
committed
added audience override if contains project ID and tests
1 parent 4ff6e92 commit b65752d

File tree

2 files changed

+155
-1
lines changed

2 files changed

+155
-1
lines changed

descope/auth.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,12 +656,37 @@ def _validate_token(
656656
"Algorithm signature in JWT header does not match the algorithm signature in the public key",
657657
)
658658

659+
# Check if we need to auto-detect audience from token
660+
validation_audience = audience
661+
if audience is None:
662+
try:
663+
unverified_claims = jwt.decode(
664+
jwt=token,
665+
key=copy_key[0].key,
666+
algorithms=[alg_header],
667+
options={"verify_aud": False}, # Skip audience verification for now
668+
leeway=self.jwt_validation_leeway,
669+
)
670+
token_audience = unverified_claims.get("aud")
671+
672+
# If token has audience claim and it matches our project ID, use it
673+
if token_audience and self.project_id:
674+
if isinstance(token_audience, list):
675+
if self.project_id in token_audience:
676+
validation_audience = self.project_id
677+
else:
678+
if token_audience == self.project_id:
679+
validation_audience = self.project_id
680+
except Exception:
681+
# If we can't decode the token to check audience, proceed with original audience (None)
682+
pass
683+
659684
try:
660685
claims = jwt.decode(
661686
jwt=token,
662687
key=copy_key[0].key,
663688
algorithms=[alg_header],
664-
audience=audience,
689+
audience=validation_audience,
665690
leeway=self.jwt_validation_leeway,
666691
)
667692
except ImmatureSignatureError:

tests/test_auth.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from unittest import mock
77
from unittest.mock import patch
88

9+
import jwt
10+
911
from descope import (
1012
API_RATE_LIMIT_RETRY_AFTER_HEADER,
1113
ERROR_TYPE_API_RATE_LIMIT,
@@ -778,6 +780,133 @@ def test_raise_from_response(self):
778780
"""{"errorCode":"E062108","errorDescription":"User not found","errorMessage":"Cannot find user"}""",
779781
)
780782

783+
def test_validate_session_audience_auto_detection(self):
784+
"""Test that validate_session automatically detects audience when token audience matches project ID"""
785+
auth = Auth(self.dummy_project_id, self.public_key_dict)
786+
787+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
788+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
789+
mock_decode.side_effect = [
790+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999},
791+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}
792+
]
793+
794+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
795+
with patch.object(auth, '_fetch_public_keys'):
796+
result = auth.validate_session("dummy_token")
797+
798+
self.assertEqual(mock_decode.call_count, 2)
799+
first_call = mock_decode.call_args_list[0]
800+
self.assertIn("options", first_call.kwargs)
801+
self.assertIn("verify_aud", first_call.kwargs["options"])
802+
self.assertFalse(first_call.kwargs["options"]["verify_aud"])
803+
second_call = mock_decode.call_args_list[1]
804+
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)
805+
806+
def test_validate_session_audience_auto_detection_list(self):
807+
"""Test that validate_session automatically detects audience when token audience is a list containing project ID"""
808+
auth = Auth(self.dummy_project_id, self.public_key_dict)
809+
810+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
811+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
812+
mock_decode.side_effect = [
813+
{"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999},
814+
{"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999}
815+
]
816+
817+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
818+
with patch.object(auth, '_fetch_public_keys'):
819+
result = auth.validate_session("dummy_token")
820+
821+
self.assertEqual(mock_decode.call_count, 2)
822+
second_call = mock_decode.call_args_list[1]
823+
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)
824+
825+
def test_validate_session_audience_auto_detection_no_match(self):
826+
"""Test that validate_session does not auto-detect audience when token audience doesn't match project ID"""
827+
auth = Auth(self.dummy_project_id, self.public_key_dict)
828+
829+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
830+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
831+
mock_decode.side_effect = [
832+
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999},
833+
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999}
834+
]
835+
836+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
837+
with patch.object(auth, '_fetch_public_keys'):
838+
result = auth.validate_session("dummy_token")
839+
840+
self.assertEqual(mock_decode.call_count, 2)
841+
second_call = mock_decode.call_args_list[1]
842+
self.assertIsNone(second_call.kwargs["audience"])
843+
844+
def test_validate_session_explicit_audience(self):
845+
"""Test that validate_session uses explicit audience parameter instead of auto-detection"""
846+
auth = Auth(self.dummy_project_id, self.public_key_dict)
847+
explicit_audience = "explicit-audience"
848+
849+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
850+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
851+
mock_decode.return_value = {"aud": explicit_audience, "sub": "user123", "exp": 9999999999}
852+
853+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
854+
with patch.object(auth, '_fetch_public_keys'):
855+
result = auth.validate_session("dummy_token", audience=explicit_audience)
856+
857+
self.assertEqual(mock_decode.call_count, 1)
858+
call_args = mock_decode.call_args
859+
self.assertEqual(call_args.kwargs["audience"], explicit_audience)
860+
861+
def test_validate_and_refresh_session_audience_auto_detection(self):
862+
"""Test that validate_and_refresh_session automatically detects audience when token audience matches project ID"""
863+
auth = Auth(self.dummy_project_id, self.public_key_dict)
864+
865+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
866+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
867+
mock_decode.side_effect = [
868+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999},
869+
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}
870+
]
871+
872+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
873+
with patch.object(auth, '_fetch_public_keys'):
874+
with patch("requests.post") as mock_post:
875+
mock_post.return_value.ok = True
876+
mock_post.return_value.json.return_value = {"sessionJwt": "new_token"}
877+
mock_post.return_value.cookies = {}
878+
879+
result = auth.validate_and_refresh_session("dummy_session_token", "dummy_refresh_token")
880+
881+
self.assertEqual(mock_decode.call_count, 2)
882+
first_call = mock_decode.call_args_list[0]
883+
self.assertIn("options", first_call.kwargs)
884+
self.assertIn("verify_aud", first_call.kwargs["options"])
885+
self.assertFalse(first_call.kwargs["options"]["verify_aud"])
886+
second_call = mock_decode.call_args_list[1]
887+
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)
888+
889+
def test_validate_session_audience_mismatch_fails(self):
890+
"""Test that validate_session fails when token audience doesn't match project ID and no explicit audience is provided"""
891+
auth = Auth(self.dummy_project_id, self.public_key_dict)
892+
893+
with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
894+
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
895+
# First call succeeds (for audience detection), second call fails (for validation with None audience)
896+
mock_decode.side_effect = [
897+
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, # First call for audience detection
898+
jwt.InvalidAudienceError("Invalid audience") # Second call fails because audience doesn't match
899+
]
900+
901+
with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
902+
with patch.object(auth, '_fetch_public_keys'):
903+
with self.assertRaises(jwt.InvalidAudienceError) as cm:
904+
auth.validate_session("dummy_token")
905+
906+
# Verify the error is about invalid audience
907+
self.assertIn("Invalid audience", str(cm.exception))
908+
self.assertEqual(mock_decode.call_count, 2)
909+
781910

782911
if __name__ == "__main__":
783912
unittest.main()

0 commit comments

Comments
 (0)