|
6 | 6 | from unittest import mock |
7 | 7 | from unittest.mock import patch |
8 | 8 |
|
| 9 | +import jwt |
| 10 | + |
9 | 11 | from descope import ( |
10 | 12 | API_RATE_LIMIT_RETRY_AFTER_HEADER, |
11 | 13 | ERROR_TYPE_API_RATE_LIMIT, |
@@ -778,6 +780,133 @@ def test_raise_from_response(self): |
778 | 780 | """{"errorCode":"E062108","errorDescription":"User not found","errorMessage":"Cannot find user"}""", |
779 | 781 | ) |
780 | 782 |
|
| 783 | + def test_validate_session_audience_auto_detection(self): |
| 784 | + """Test that validate_session automatically detects audience when token audience matches project ID""" |
| 785 | + auth = Auth(self.dummy_project_id, self.public_key_dict) |
| 786 | + |
| 787 | + with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: |
| 788 | + mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} |
| 789 | + mock_decode.side_effect = [ |
| 790 | + {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}, |
| 791 | + {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999} |
| 792 | + ] |
| 793 | + |
| 794 | + with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): |
| 795 | + with patch.object(auth, '_fetch_public_keys'): |
| 796 | + result = auth.validate_session("dummy_token") |
| 797 | + |
| 798 | + self.assertEqual(mock_decode.call_count, 2) |
| 799 | + first_call = mock_decode.call_args_list[0] |
| 800 | + self.assertIn("options", first_call.kwargs) |
| 801 | + self.assertIn("verify_aud", first_call.kwargs["options"]) |
| 802 | + self.assertFalse(first_call.kwargs["options"]["verify_aud"]) |
| 803 | + second_call = mock_decode.call_args_list[1] |
| 804 | + self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id) |
| 805 | + |
| 806 | + def test_validate_session_audience_auto_detection_list(self): |
| 807 | + """Test that validate_session automatically detects audience when token audience is a list containing project ID""" |
| 808 | + auth = Auth(self.dummy_project_id, self.public_key_dict) |
| 809 | + |
| 810 | + with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: |
| 811 | + mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} |
| 812 | + mock_decode.side_effect = [ |
| 813 | + {"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999}, |
| 814 | + {"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999} |
| 815 | + ] |
| 816 | + |
| 817 | + with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): |
| 818 | + with patch.object(auth, '_fetch_public_keys'): |
| 819 | + result = auth.validate_session("dummy_token") |
| 820 | + |
| 821 | + self.assertEqual(mock_decode.call_count, 2) |
| 822 | + second_call = mock_decode.call_args_list[1] |
| 823 | + self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id) |
| 824 | + |
| 825 | + def test_validate_session_audience_auto_detection_no_match(self): |
| 826 | + """Test that validate_session does not auto-detect audience when token audience doesn't match project ID""" |
| 827 | + auth = Auth(self.dummy_project_id, self.public_key_dict) |
| 828 | + |
| 829 | + with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: |
| 830 | + mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} |
| 831 | + mock_decode.side_effect = [ |
| 832 | + {"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, |
| 833 | + {"aud": "different-project-id", "sub": "user123", "exp": 9999999999} |
| 834 | + ] |
| 835 | + |
| 836 | + with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): |
| 837 | + with patch.object(auth, '_fetch_public_keys'): |
| 838 | + result = auth.validate_session("dummy_token") |
| 839 | + |
| 840 | + self.assertEqual(mock_decode.call_count, 2) |
| 841 | + second_call = mock_decode.call_args_list[1] |
| 842 | + self.assertIsNone(second_call.kwargs["audience"]) |
| 843 | + |
| 844 | + def test_validate_session_explicit_audience(self): |
| 845 | + """Test that validate_session uses explicit audience parameter instead of auto-detection""" |
| 846 | + auth = Auth(self.dummy_project_id, self.public_key_dict) |
| 847 | + explicit_audience = "explicit-audience" |
| 848 | + |
| 849 | + with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: |
| 850 | + mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} |
| 851 | + mock_decode.return_value = {"aud": explicit_audience, "sub": "user123", "exp": 9999999999} |
| 852 | + |
| 853 | + with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): |
| 854 | + with patch.object(auth, '_fetch_public_keys'): |
| 855 | + result = auth.validate_session("dummy_token", audience=explicit_audience) |
| 856 | + |
| 857 | + self.assertEqual(mock_decode.call_count, 1) |
| 858 | + call_args = mock_decode.call_args |
| 859 | + self.assertEqual(call_args.kwargs["audience"], explicit_audience) |
| 860 | + |
| 861 | + def test_validate_and_refresh_session_audience_auto_detection(self): |
| 862 | + """Test that validate_and_refresh_session automatically detects audience when token audience matches project ID""" |
| 863 | + auth = Auth(self.dummy_project_id, self.public_key_dict) |
| 864 | + |
| 865 | + with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: |
| 866 | + mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} |
| 867 | + mock_decode.side_effect = [ |
| 868 | + {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}, |
| 869 | + {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999} |
| 870 | + ] |
| 871 | + |
| 872 | + with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): |
| 873 | + with patch.object(auth, '_fetch_public_keys'): |
| 874 | + with patch("requests.post") as mock_post: |
| 875 | + mock_post.return_value.ok = True |
| 876 | + mock_post.return_value.json.return_value = {"sessionJwt": "new_token"} |
| 877 | + mock_post.return_value.cookies = {} |
| 878 | + |
| 879 | + result = auth.validate_and_refresh_session("dummy_session_token", "dummy_refresh_token") |
| 880 | + |
| 881 | + self.assertEqual(mock_decode.call_count, 2) |
| 882 | + first_call = mock_decode.call_args_list[0] |
| 883 | + self.assertIn("options", first_call.kwargs) |
| 884 | + self.assertIn("verify_aud", first_call.kwargs["options"]) |
| 885 | + self.assertFalse(first_call.kwargs["options"]["verify_aud"]) |
| 886 | + second_call = mock_decode.call_args_list[1] |
| 887 | + self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id) |
| 888 | + |
| 889 | + def test_validate_session_audience_mismatch_fails(self): |
| 890 | + """Test that validate_session fails when token audience doesn't match project ID and no explicit audience is provided""" |
| 891 | + auth = Auth(self.dummy_project_id, self.public_key_dict) |
| 892 | + |
| 893 | + with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: |
| 894 | + mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} |
| 895 | + # First call succeeds (for audience detection), second call fails (for validation with None audience) |
| 896 | + mock_decode.side_effect = [ |
| 897 | + {"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, # First call for audience detection |
| 898 | + jwt.InvalidAudienceError("Invalid audience") # Second call fails because audience doesn't match |
| 899 | + ] |
| 900 | + |
| 901 | + with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): |
| 902 | + with patch.object(auth, '_fetch_public_keys'): |
| 903 | + with self.assertRaises(jwt.InvalidAudienceError) as cm: |
| 904 | + auth.validate_session("dummy_token") |
| 905 | + |
| 906 | + # Verify the error is about invalid audience |
| 907 | + self.assertIn("Invalid audience", str(cm.exception)) |
| 908 | + self.assertEqual(mock_decode.call_count, 2) |
| 909 | + |
781 | 910 |
|
782 | 911 | if __name__ == "__main__": |
783 | 912 | unittest.main() |
0 commit comments