Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion descope/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,12 +439,37 @@ def _validate_token(
"Algorithm signature in JWT header does not match the algorithm signature in the public key",
)

# Check if we need to auto-detect audience from token
validation_audience = audience
if audience is None:
try:
unverified_claims = jwt.decode(
jwt=token,
key=copy_key[0].key,
algorithms=[alg_header],
options={"verify_aud": False}, # Skip audience verification for now
leeway=self.jwt_validation_leeway,
)
token_audience = unverified_claims.get("aud")

# If token has audience claim and it matches our project ID, use it
if token_audience and self.project_id:
if isinstance(token_audience, list):
if self.project_id in token_audience:
validation_audience = self.project_id
else:
if token_audience == self.project_id:
validation_audience = self.project_id
except Exception:
# If we can't decode the token to check audience, proceed with original audience (None)
pass

try:
claims = jwt.decode(
jwt=token,
key=copy_key[0].key,
algorithms=[alg_header],
audience=audience,
audience=validation_audience,
leeway=self.jwt_validation_leeway,
)
except ImmatureSignatureError:
Expand Down
129 changes: 129 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from unittest import mock
from unittest.mock import patch

import jwt

from descope import (
API_RATE_LIMIT_RETRY_AFTER_HEADER,
ERROR_TYPE_API_RATE_LIMIT,
Expand Down Expand Up @@ -948,6 +950,133 @@ def test_raise_from_response(self):
"""{"errorCode":"E062108","errorDescription":"User not found","errorMessage":"Cannot find user"}""",
)

def test_validate_session_audience_auto_detection(self):
"""Test that validate_session automatically detects audience when token audience matches project ID"""
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())

with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
mock_decode.side_effect = [
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999},
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}
]

with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
with patch.object(auth, '_fetch_public_keys'):
result = auth.validate_session("dummy_token")

self.assertEqual(mock_decode.call_count, 2)
first_call = mock_decode.call_args_list[0]
self.assertIn("options", first_call.kwargs)
self.assertIn("verify_aud", first_call.kwargs["options"])
self.assertFalse(first_call.kwargs["options"]["verify_aud"])
second_call = mock_decode.call_args_list[1]
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)

def test_validate_session_audience_auto_detection_list(self):
"""Test that validate_session automatically detects audience when token audience is a list containing project ID"""
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())

with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
mock_decode.side_effect = [
{"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999},
{"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999}
]

with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
with patch.object(auth, '_fetch_public_keys'):
result = auth.validate_session("dummy_token")

self.assertEqual(mock_decode.call_count, 2)
second_call = mock_decode.call_args_list[1]
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)

def test_validate_session_audience_auto_detection_no_match(self):
"""Test that validate_session does not auto-detect audience when token audience doesn't match project ID"""
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())

with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
mock_decode.side_effect = [
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999},
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999}
]

with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
with patch.object(auth, '_fetch_public_keys'):
result = auth.validate_session("dummy_token")

self.assertEqual(mock_decode.call_count, 2)
second_call = mock_decode.call_args_list[1]
self.assertIsNone(second_call.kwargs["audience"])

def test_validate_session_explicit_audience(self):
"""Test that validate_session uses explicit audience parameter instead of auto-detection"""
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())
explicit_audience = "explicit-audience"

with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
mock_decode.return_value = {"aud": explicit_audience, "sub": "user123", "exp": 9999999999}

with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
with patch.object(auth, '_fetch_public_keys'):
result = auth.validate_session("dummy_token", audience=explicit_audience)

self.assertEqual(mock_decode.call_count, 1)
call_args = mock_decode.call_args
self.assertEqual(call_args.kwargs["audience"], explicit_audience)

def test_validate_and_refresh_session_audience_auto_detection(self):
"""Test that validate_and_refresh_session automatically detects audience when token audience matches project ID"""
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())

with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
mock_decode.side_effect = [
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999},
{"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}
]

with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
with patch.object(auth, '_fetch_public_keys'):
with patch("requests.post") as mock_post:
mock_post.return_value.ok = True
mock_post.return_value.json.return_value = {"sessionJwt": "new_token"}
mock_post.return_value.cookies = {}

result = auth.validate_and_refresh_session("dummy_session_token", "dummy_refresh_token")

self.assertEqual(mock_decode.call_count, 2)
first_call = mock_decode.call_args_list[0]
self.assertIn("options", first_call.kwargs)
self.assertIn("verify_aud", first_call.kwargs["options"])
self.assertFalse(first_call.kwargs["options"]["verify_aud"])
second_call = mock_decode.call_args_list[1]
self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id)

def test_validate_session_audience_mismatch_fails(self):
"""Test that validate_session fails when token audience doesn't match project ID and no explicit audience is provided"""
auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client())

with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode:
mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]}
# First call succeeds (for audience detection), second call fails (for validation with None audience)
mock_decode.side_effect = [
{"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, # First call for audience detection
jwt.InvalidAudienceError("Invalid audience") # Second call fails because audience doesn't match
]

with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}):
with patch.object(auth, '_fetch_public_keys'):
with self.assertRaises(jwt.InvalidAudienceError) as cm:
auth.validate_session("dummy_token")

# Verify the error is about invalid audience
self.assertIn("Invalid audience", str(cm.exception))
self.assertEqual(mock_decode.call_count, 2)

def test_http_client_authorization_header_variants(self):
# Base client without management key
client = self.make_http_client()
Expand Down
Loading