diff --git a/descope/auth.py b/descope/auth.py index 59d6c2cc1..831f48224 100644 --- a/descope/auth.py +++ b/descope/auth.py @@ -9,7 +9,6 @@ from typing import Iterable, Optional import jwt -import requests from email_validator import EmailNotValidError, validate_email from jwt import ExpiredSignatureError, ImmatureSignatureError @@ -53,8 +52,7 @@ def __init__( http_client: HTTPClient, ): self.lock_public_keys = Lock() - # validate project id - project_id = project_id or os.getenv("DESCOPE_PROJECT_ID", "") + if not project_id: raise AuthException( 400, @@ -451,7 +449,7 @@ def _validate_token( leeway=self.jwt_validation_leeway, ) token_audience = unverified_claims.get("aud") - + # If token has audience claim and it matches our project ID, use it if token_audience and self.project_id: if isinstance(token_audience, list): diff --git a/descope/descope_client.py b/descope/descope_client.py index c3a10c162..0fc2213ed 100644 --- a/descope/descope_client.py +++ b/descope/descope_client.py @@ -35,6 +35,18 @@ def __init__( auth_management_key: Optional[str] = None, fga_cache_url: Optional[str] = None, ): + # validate project id + project_id = project_id or os.getenv("DESCOPE_PROJECT_ID", "") + if not project_id: + raise AuthException( + 400, + ERROR_TYPE_INVALID_ARGUMENT, + ( + "Unable to init DescopeClient because project_id cannot be empty. " + "Set environment variable DESCOPE_PROJECT_ID or pass your Project ID to the init function." + ), + ) + # Auth Initialization auth_http_client = HTTPClient( project_id=project_id, diff --git a/descope/http_client.py b/descope/http_client.py index 4cd0c0e6e..540812c75 100644 --- a/descope/http_client.py +++ b/descope/http_client.py @@ -21,6 +21,7 @@ from descope.exceptions import ( API_RATE_LIMIT_RETRY_AFTER_HEADER, ERROR_TYPE_API_RATE_LIMIT, + ERROR_TYPE_INVALID_ARGUMENT, ERROR_TYPE_SERVER_ERROR, AuthException, RateLimitException, @@ -55,8 +56,11 @@ def __init__( if not project_id: raise AuthException( 400, - ERROR_TYPE_SERVER_ERROR, - "Project ID is required to initialize HTTP client", + ERROR_TYPE_INVALID_ARGUMENT, + ( + "Project ID is required to initialize HTTP client" + "Set environment variable DESCOPE_PROJECT_ID or pass your Project ID to the init function." + ), ) # Prefer explicitly provided base_url, then env var, then computed default diff --git a/descope/management/common.py b/descope/management/common.py index 31f3cb0f2..ebdc45dc9 100644 --- a/descope/management/common.py +++ b/descope/management/common.py @@ -1,17 +1,20 @@ from enum import Enum from typing import List, Optional + class SessionExpirationUnit(Enum): MINUTES = "minutes" HOURS = "hours" DAYS = "days" WEEKS = "weeks" + class TenantAuthType(Enum): NONE = "none" SAML = "saml" OIDC = "oidc" + class AccessType(Enum): OFFLINE = "offline" ONLINE = "online" @@ -303,6 +306,7 @@ def associated_tenants_to_dict(associated_tenants: List[AssociatedTenant]) -> li ) return associated_tenant_list + class SAMLIDPAttributeMappingInfo: """ Represents a SAML IDP attribute mapping object. use this class for mapping Descope attribute diff --git a/descope/management/tenant.py b/descope/management/tenant.py index e14630b06..b42acbc1d 100644 --- a/descope/management/tenant.py +++ b/descope/management/tenant.py @@ -1,7 +1,7 @@ from typing import Any, List, Optional from descope._http_base import HTTPBase -from descope.management.common import MgmtV1, TenantAuthType, SessionExpirationUnit +from descope.management.common import MgmtV1, SessionExpirationUnit, TenantAuthType class Tenant(HTTPBase): @@ -108,7 +108,7 @@ def update_settings( enable_inactivity: Optional[bool] = None, inactivity_time: Optional[int] = None, inactivity_time_unit: Optional[SessionExpirationUnit] = None, - JITDisabled: Optional[bool] = None + JITDisabled: Optional[bool] = None, ): """ Update an existing tenant's session settings. @@ -150,14 +150,10 @@ def update_settings( "inactivityTimeUnit": inactivity_time_unit, "JITDisabled": JITDisabled, } - + body = {k: v for k, v in body.items() if v is not None} - self._http.post( - MgmtV1.tenant_settings_path, - body=body, - params=None - ) + self._http.post(MgmtV1.tenant_settings_path, body=body, params=None) def delete( self, @@ -201,7 +197,7 @@ def load( params={"id": id}, ) return response.json() - + def load_settings( self, id: str, diff --git a/tests/management/test_tenant.py b/tests/management/test_tenant.py index 89858dfc1..956e90948 100644 --- a/tests/management/test_tenant.py +++ b/tests/management/test_tenant.py @@ -389,7 +389,13 @@ def test_update_settings(self): with patch("requests.post") as mock_post: mock_post.return_value.ok = True self.assertIsNone( - client.mgmt.tenant.update_settings("t1", self_provisioning_domains=["domain1.com"], domains=["domain1.com", "domain2.com"], auth_type="oidc", session_settings_enabled=True) + client.mgmt.tenant.update_settings( + "t1", + self_provisioning_domains=["domain1.com"], + domains=["domain1.com", "domain2.com"], + auth_type="oidc", + session_settings_enabled=True, + ) ) mock_post.assert_called_with( f"{common.DEFAULT_BASE_URL}{MgmtV1.tenant_settings_path}", @@ -403,7 +409,7 @@ def test_update_settings(self): "selfProvisioningDomains": ["domain1.com"], "domains": ["domain1.com", "domain2.com"], "authType": "oidc", - "enabled": True + "enabled": True, }, allow_redirects=False, params=None, diff --git a/tests/test_auth.py b/tests/test_auth.py index 5c9fa7ed6..a3d64ddd7 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,5 +1,4 @@ import json -import os import unittest from enum import Enum from http import HTTPStatus @@ -134,12 +133,7 @@ def test_fetch_public_key(self): mock_get.return_value.text = valid_keys_response self.assertIsNone(auth._fetch_public_keys()) - def test_project_id_from_env(self): - os.environ["DESCOPE_PROJECT_ID"] = self.dummy_project_id - Auth(http_client=self.make_http_client()) - - def test_project_id_from_env_without_env(self): - os.environ["DESCOPE_PROJECT_ID"] = "" + def test_empty_project_id(self): self.assertRaises(AuthException, Auth, http_client=self.make_http_client()) def test_base_url_for_project_id(self): @@ -952,127 +946,233 @@ def test_raise_from_response(self): def test_validate_session_audience_auto_detection(self): """Test that validate_session automatically detects audience when token audience matches project ID""" - auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client()) - - with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: - mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} + auth = Auth( + self.dummy_project_id, + self.public_key_dict, + http_client=self.make_http_client(), + ) + + with patch("jwt.get_unverified_header") as mock_get_header, patch( + "jwt.decode" + ) as mock_decode: + mock_get_header.return_value = { + "alg": "ES384", + "kid": self.public_key_dict["kid"], + } mock_decode.side_effect = [ {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}, - {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999} + {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}, ] - - with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): - with patch.object(auth, '_fetch_public_keys'): - result = auth.validate_session("dummy_token") - + + with patch.object( + auth, + "public_keys", + {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}, + ): + with patch.object(auth, "_fetch_public_keys"): + auth.validate_session("dummy_token") + self.assertEqual(mock_decode.call_count, 2) first_call = mock_decode.call_args_list[0] self.assertIn("options", first_call.kwargs) self.assertIn("verify_aud", first_call.kwargs["options"]) self.assertFalse(first_call.kwargs["options"]["verify_aud"]) second_call = mock_decode.call_args_list[1] - self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id) + self.assertEqual( + second_call.kwargs["audience"], self.dummy_project_id + ) def test_validate_session_audience_auto_detection_list(self): """Test that validate_session automatically detects audience when token audience is a list containing project ID""" - auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client()) - - with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: - mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} + auth = Auth( + self.dummy_project_id, + self.public_key_dict, + http_client=self.make_http_client(), + ) + + with patch("jwt.get_unverified_header") as mock_get_header, patch( + "jwt.decode" + ) as mock_decode: + mock_get_header.return_value = { + "alg": "ES384", + "kid": self.public_key_dict["kid"], + } mock_decode.side_effect = [ - {"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999}, - {"aud": [self.dummy_project_id, "other-audience"], "sub": "user123", "exp": 9999999999} + { + "aud": [self.dummy_project_id, "other-audience"], + "sub": "user123", + "exp": 9999999999, + }, + { + "aud": [self.dummy_project_id, "other-audience"], + "sub": "user123", + "exp": 9999999999, + }, ] - - with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): - with patch.object(auth, '_fetch_public_keys'): - result = auth.validate_session("dummy_token") - + + with patch.object( + auth, + "public_keys", + {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}, + ): + with patch.object(auth, "_fetch_public_keys"): + auth.validate_session("dummy_token") + self.assertEqual(mock_decode.call_count, 2) second_call = mock_decode.call_args_list[1] - self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id) + self.assertEqual( + second_call.kwargs["audience"], self.dummy_project_id + ) def test_validate_session_audience_auto_detection_no_match(self): """Test that validate_session does not auto-detect audience when token audience doesn't match project ID""" - auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client()) - - with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: - mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} + auth = Auth( + self.dummy_project_id, + self.public_key_dict, + http_client=self.make_http_client(), + ) + + with patch("jwt.get_unverified_header") as mock_get_header, patch( + "jwt.decode" + ) as mock_decode: + mock_get_header.return_value = { + "alg": "ES384", + "kid": self.public_key_dict["kid"], + } mock_decode.side_effect = [ {"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, - {"aud": "different-project-id", "sub": "user123", "exp": 9999999999} + {"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, ] - - with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): - with patch.object(auth, '_fetch_public_keys'): - result = auth.validate_session("dummy_token") - + + with patch.object( + auth, + "public_keys", + {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}, + ): + with patch.object(auth, "_fetch_public_keys"): + auth.validate_session("dummy_token") + self.assertEqual(mock_decode.call_count, 2) second_call = mock_decode.call_args_list[1] self.assertIsNone(second_call.kwargs["audience"]) def test_validate_session_explicit_audience(self): """Test that validate_session uses explicit audience parameter instead of auto-detection""" - auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client()) + auth = Auth( + self.dummy_project_id, + self.public_key_dict, + http_client=self.make_http_client(), + ) explicit_audience = "explicit-audience" - - with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: - mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} - mock_decode.return_value = {"aud": explicit_audience, "sub": "user123", "exp": 9999999999} - - with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): - with patch.object(auth, '_fetch_public_keys'): - result = auth.validate_session("dummy_token", audience=explicit_audience) - + + with patch("jwt.get_unverified_header") as mock_get_header, patch( + "jwt.decode" + ) as mock_decode: + mock_get_header.return_value = { + "alg": "ES384", + "kid": self.public_key_dict["kid"], + } + mock_decode.return_value = { + "aud": explicit_audience, + "sub": "user123", + "exp": 9999999999, + } + + with patch.object( + auth, + "public_keys", + {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}, + ): + with patch.object(auth, "_fetch_public_keys"): + auth.validate_session("dummy_token", audience=explicit_audience) + self.assertEqual(mock_decode.call_count, 1) call_args = mock_decode.call_args self.assertEqual(call_args.kwargs["audience"], explicit_audience) def test_validate_and_refresh_session_audience_auto_detection(self): """Test that validate_and_refresh_session automatically detects audience when token audience matches project ID""" - auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client()) - - with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: - mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} + auth = Auth( + self.dummy_project_id, + self.public_key_dict, + http_client=self.make_http_client(), + ) + + with patch("jwt.get_unverified_header") as mock_get_header, patch( + "jwt.decode" + ) as mock_decode: + mock_get_header.return_value = { + "alg": "ES384", + "kid": self.public_key_dict["kid"], + } mock_decode.side_effect = [ {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}, - {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999} + {"aud": self.dummy_project_id, "sub": "user123", "exp": 9999999999}, ] - - with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): - with patch.object(auth, '_fetch_public_keys'): + + with patch.object( + auth, + "public_keys", + {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}, + ): + with patch.object(auth, "_fetch_public_keys"): with patch("requests.post") as mock_post: mock_post.return_value.ok = True - mock_post.return_value.json.return_value = {"sessionJwt": "new_token"} + mock_post.return_value.json.return_value = { + "sessionJwt": "new_token" + } mock_post.return_value.cookies = {} - - result = auth.validate_and_refresh_session("dummy_session_token", "dummy_refresh_token") - + + auth.validate_and_refresh_session( + "dummy_session_token", "dummy_refresh_token" + ) + self.assertEqual(mock_decode.call_count, 2) first_call = mock_decode.call_args_list[0] self.assertIn("options", first_call.kwargs) self.assertIn("verify_aud", first_call.kwargs["options"]) self.assertFalse(first_call.kwargs["options"]["verify_aud"]) second_call = mock_decode.call_args_list[1] - self.assertEqual(second_call.kwargs["audience"], self.dummy_project_id) + self.assertEqual( + second_call.kwargs["audience"], self.dummy_project_id + ) def test_validate_session_audience_mismatch_fails(self): """Test that validate_session fails when token audience doesn't match project ID and no explicit audience is provided""" - auth = Auth(self.dummy_project_id, self.public_key_dict, http_client=self.make_http_client()) - - with patch("jwt.get_unverified_header") as mock_get_header, patch("jwt.decode") as mock_decode: - mock_get_header.return_value = {"alg": "ES384", "kid": self.public_key_dict["kid"]} + auth = Auth( + self.dummy_project_id, + self.public_key_dict, + http_client=self.make_http_client(), + ) + + with patch("jwt.get_unverified_header") as mock_get_header, patch( + "jwt.decode" + ) as mock_decode: + mock_get_header.return_value = { + "alg": "ES384", + "kid": self.public_key_dict["kid"], + } # First call succeeds (for audience detection), second call fails (for validation with None audience) mock_decode.side_effect = [ - {"aud": "different-project-id", "sub": "user123", "exp": 9999999999}, # First call for audience detection - jwt.InvalidAudienceError("Invalid audience") # Second call fails because audience doesn't match + { + "aud": "different-project-id", + "sub": "user123", + "exp": 9999999999, + }, # First call for audience detection + jwt.InvalidAudienceError( + "Invalid audience" + ), # Second call fails because audience doesn't match ] - - with patch.object(auth, 'public_keys', {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}): - with patch.object(auth, '_fetch_public_keys'): + + with patch.object( + auth, + "public_keys", + {self.public_key_dict["kid"]: (mock.Mock(), "ES384")}, + ): + with patch.object(auth, "_fetch_public_keys"): with self.assertRaises(jwt.InvalidAudienceError) as cm: auth.validate_session("dummy_token") - + # Verify the error is about invalid audience self.assertIn("Invalid audience", str(cm.exception)) self.assertEqual(mock_decode.call_count, 2) diff --git a/tests/test_descope_client.py b/tests/test_descope_client.py index 16f06695b..5c7b669ee 100644 --- a/tests/test_descope_client.py +++ b/tests/test_descope_client.py @@ -1,4 +1,5 @@ import json +import os import sys import unittest from copy import deepcopy @@ -69,6 +70,10 @@ def test_descope_client(self): DescopeClient(project_id="dummy", public_key=self.public_key_str) ) + def test_project_id_from_env_without_env(self): + os.environ["DESCOPE_PROJECT_ID"] = "" + self.assertRaises(AuthException, DescopeClient, "") + def test_mgmt(self): client = DescopeClient(self.dummy_project_id, self.public_key_dict) diff --git a/tests/test_jwt_common.py b/tests/test_jwt_common.py index dcfd1217b..a6b9c7add 100644 --- a/tests/test_jwt_common.py +++ b/tests/test_jwt_common.py @@ -5,7 +5,6 @@ REFRESH_SESSION_TOKEN_NAME, SESSION_TOKEN_NAME, decode_token_unverified, - generate_auth_info, generate_jwt_response, )