diff --git a/ansible_base/jwt_consumer/common/auth.py b/ansible_base/jwt_consumer/common/auth.py index 72641cdb6..381958499 100644 --- a/ansible_base/jwt_consumer/common/auth.py +++ b/ansible_base/jwt_consumer/common/auth.py @@ -18,7 +18,9 @@ from ansible_base.lib.logging.runtime import log_excess_runtime from ansible_base.lib.utils.auth import get_user_by_ansible_id from ansible_base.lib.utils.translations import translatableConditionally as _ +from ansible_base.rbac.claims import get_user_claims, get_user_claims_hashable_form, get_claims_hash from ansible_base.resource_registry.models import Resource, ResourceType +from ansible_base.resource_registry.rest_client import get_resource_server_client from ansible_base.resource_registry.signals.handlers import no_reverse_sync logger = logging.getLogger("ansible_base.jwt_consumer.common.auth") @@ -52,6 +54,7 @@ def __init__(self, user_fields=default_mapped_user_fields) -> None: self.cache = JWTCache() self.user = None self.token = None + self.gateway_claims = None # Store claims from gateway @log_excess_runtime(logger, debug_cutoff=0.01) def parse_jwt_token(self, request): @@ -142,10 +145,118 @@ def parse_jwt_token(self, request): resource.service_id = self.token['service_id'] resource.save(update_fields=['ansible_id', 'service_id']) + # Check if claims need to be refreshed from gateway based on claims_hash + user_ansible_id = self.token['sub'] + current_claims_hash = self.token.get('claims_hash') + + if self._should_fetch_claims_from_gateway(user_ansible_id, current_claims_hash): + logger.debug(f"Claims hash changed or not cached, fetching claims from gateway for user {user_ansible_id}") + jwt_claims = self._fetch_jwt_claims_from_gateway(user_ansible_id) + + if jwt_claims: + self.gateway_claims = jwt_claims + self._cache_claims_hash(user_ansible_id, current_claims_hash) + logger.debug(f"Successfully loaded and cached gateway claims for user {user_ansible_id}") + else: + logger.error(f"Failed to fetch claims from gateway for user {user_ansible_id}. RBAC processing will not be available.") + # Note: We don't raise an exception here to allow basic authentication to succeed + # RBAC processing will fail gracefully with appropriate error messages + else: + logger.debug(f"Using cached claims for user {user_ansible_id} (claims_hash unchanged)") setattr(self.user, "resource_api_actions", self.token.get("resource_api_actions", None)) logger.info(f"User {self.user.username} authenticated from JWT auth") + def _should_fetch_claims_from_gateway(self, user_ansible_id, current_claims_hash): + """ + Determine if claims should be fetched from gateway based on claims_hash comparison. + Returns True if claims need to be fetched (hash changed or not cached). + """ + if not current_claims_hash: + logger.debug(f"No claims_hash in token for user {user_ansible_id}, will fetch claims") + return True + + cached_hash = self.cache.get_claims_hash(user_ansible_id) + if cached_hash != current_claims_hash: + # Recalculate hash from local database to verify the mismatch + # It is possible that the cached hash is stale, but the local data is synced to the resource server. + # This is an optimization to avoid fetching claims from the resource server if the local data is synced. + logger.debug(f"Claims hash mismatch for user {user_ansible_id}: cached={cached_hash}, from token={current_claims_hash}") + logger.debug(f"Recalculating hash from local database for user {user_ansible_id}") + + try: + # Get user claims from local database + user_claims = get_user_claims(self.user) + hashable_claims = get_user_claims_hashable_form(user_claims) + recalculated_hash = get_claims_hash(hashable_claims) + + # Compare recalculated hash with current hash from token + if recalculated_hash != current_claims_hash: + logger.debug(f"Claims hash still differs after recalculation for user {user_ansible_id}: local={recalculated_hash}, from token={current_claims_hash}") + return True + else: + logger.debug(f"Recalculated local hash matches token hash for user {user_ansible_id}") + logger.debug(f"Caching claims hash for user {user_ansible_id}: {recalculated_hash}") + self._cache_claims_hash(user_ansible_id, recalculated_hash) + return False + + except Exception as e: + logger.error(f"Failed to recalculate claims hash for user {user_ansible_id}: {e}") + # If recalculation fails, fall back to treating it as a hash mismatch + return True + + # Hash matches cached value, try to get cached claims + cached_claims = self.cache.get_cached_claims(user_ansible_id) + if cached_claims: + self.gateway_claims = cached_claims + return False + else: + logger.debug(f"Claims hash matches but no cached claims found for user {user_ansible_id}") + return True + + def _cache_claims_hash(self, user_ansible_id, claims_hash): + """Cache the claims hash and gateway claims for future comparisons.""" + if claims_hash and self.gateway_claims: + self.cache.set_claims_hash(user_ansible_id, claims_hash) + self.cache.set_cached_claims(user_ansible_id, self.gateway_claims) + + def _fetch_jwt_claims_from_gateway(self, user_ansible_id): + """ + Fetch JWT claims for a user from the gateway service-index API. + Returns None if claims cannot be retrieved. + """ + try: + # Use the full service path from settings, not just "service-index" + # This should be something like "/api/gateway/v1/service-index/" + service_path = getattr(settings, "RESOURCE_SERVICE_PATH", "/api/gateway/v1/service-index/") + client = get_resource_server_client(service_path) + response = client.get_jwt_claims(user_ansible_id) + + if response.status_code == 200: + # Try to parse JSON, but handle empty or invalid responses + try: + claims = response.json() + logger.debug(f"Retrieved JWT claims from gateway for user {user_ansible_id}") + return claims + except ValueError: + # Log the actual response content for debugging + logger.error( + f"Invalid JSON response from gateway for user {user_ansible_id}. " + f"Status: {response.status_code}, " + f"Content-Type: {response.headers.get('Content-Type', 'unknown')}, " + f"Response text: {response.text[:500]}" # First 500 chars to avoid huge logs + ) + return None + else: + logger.warning( + f"Failed to retrieve JWT claims from gateway for user {user_ansible_id}: " + f"Status: {response.status_code}, Response: {response.text[:500]}" + ) + return None + except Exception as e: + logger.error(f"Error fetching JWT claims from gateway for user {user_ansible_id}: {e}") + return None + def log_and_raise(self, conditional_translate_object, expand_values={}, error_code=None): logger.error(conditional_translate_object.not_translated() % expand_values) translated_error_message = conditional_translate_object.translated() % expand_values @@ -226,7 +337,8 @@ def validate_token(self, unencrypted_token, decryption_key, request_id=None): return validated_body def decode_jwt_token(self, unencrypted_token, decryption_key, additional_options={}): - local_required_field = ["sub", "user_data", "exp", "objects", "object_roles", "global_roles", "version"] + # Core required fields - claims_hash is now required to track permission changes + local_required_field = ["sub", "user_data", "exp", "version", "claims_hash"] options = {"require": local_required_field} options.update(additional_options) return jwt.decode( @@ -259,16 +371,23 @@ def get_role_definition(self, name: str) -> Optional[Model]: def process_rbac_permissions(self): """ This is a default process_permissions which should be usable if you are using RBAC from DAB + Uses gateway claims data exclusively - no fallback to JWT token fields """ - if self.token is None or self.user is None: - logger.error("Unable to process rbac permissions because user or token is not defined, please call authenticate first") + if self.user is None: + logger.error("Unable to process rbac permissions because user is not defined, please call authenticate first") + return + + if self.gateway_claims is None: + logger.error("Unable to process rbac permissions because gateway claims are not available. Ensure gateway jwt_claims endpoint is accessible.") return from ansible_base.rbac.models import RoleUserAssignment role_diff = RoleUserAssignment.objects.filter(user=self.user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES) - for system_role_name in self.token.get("global_roles", []): + # Process global roles from gateway claims + global_roles = self.gateway_claims.get("global_roles", []) + for system_role_name in global_roles: logger.debug(f"Processing system role {system_role_name} for {self.user.username}") rd = self.get_role_definition(system_role_name) if rd: @@ -282,7 +401,11 @@ def process_rbac_permissions(self): logger.error(f"Unable to grant {self.user.username} system level role {system_role_name} because it does not exist") continue - for object_role_name in self.token.get('object_roles', {}).keys(): + # Process object roles from gateway claims + object_roles = self.gateway_claims.get('object_roles', {}) + objects = self.gateway_claims.get('objects', {}) + + for object_role_name in object_roles.keys(): rd = self.get_role_definition(object_role_name) if rd is None: logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it does not exist") @@ -291,11 +414,11 @@ def process_rbac_permissions(self): logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it is not a JWT managed role") continue - object_type = self.token['object_roles'][object_role_name]['content_type'] - object_indexes = self.token['object_roles'][object_role_name]['objects'] + object_type = object_roles[object_role_name]['content_type'] + object_indexes = object_roles[object_role_name]['objects'] for index in object_indexes: - object_data = self.token['objects'][object_type][index] + object_data = objects[object_type][index] try: resource, obj = self.get_or_create_resource(object_type, object_data) except IntegrityError as e: @@ -310,7 +433,7 @@ def process_rbac_permissions(self): role_diff = role_diff.exclude(pk=assignment.pk) logger.info(f"Granted user {self.user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}") - # Remove all permissions not authorized by the JWT + # Remove all permissions not authorized by the gateway claims for role_assignment in role_diff: rd = role_assignment.role_definition content_object = role_assignment.content_object @@ -322,9 +445,17 @@ def process_rbac_permissions(self): def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optional[Resource], Optional[Model]]: """ Gets or creates a resource from a content type and its default data + Uses gateway claims exclusively - no fallback to JWT token fields This can only build or get organizations or teams + Args: + content_type: Type of content ('team', 'organization') + data: Resource data dictionary """ + if self.gateway_claims is None: + logger.error("Unable to create resource because gateway claims are not available") + return None, None + object_ansible_id = data['ansible_id'] try: resource = Resource.objects.get(ansible_id=object_ansible_id) @@ -337,7 +468,7 @@ def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optiona if content_type == 'team': # For a team we first have to make sure the org is there org_id = data['org'] - organization_data = self.token['objects']["organization"][org_id] + organization_data = self.gateway_claims['objects']["organization"][org_id] # Now that we have the org we can build a team org_resource, _ = self.get_or_create_resource("organization", organization_data) @@ -359,7 +490,7 @@ def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optiona return resource, resource.content_object else: - logger.error(f"build_resource_stub does not know how to build an object of type {type}") + logger.error(f"build_resource_stub does not know how to build an object of type {content_type}") return None, None diff --git a/ansible_base/jwt_consumer/common/cache.py b/ansible_base/jwt_consumer/common/cache.py index b42a5862d..9bf899c7b 100644 --- a/ansible_base/jwt_consumer/common/cache.py +++ b/ansible_base/jwt_consumer/common/cache.py @@ -46,3 +46,23 @@ def get_key_from_cache(self) -> Optional[str]: def set_key_in_cache(self, key: str) -> None: cache.set(cache_key, key, timeout=self.get_cache_timeout()) + + def get_claims_hash(self, user_ansible_id: str) -> Optional[str]: + """Get cached claims hash for a user.""" + claims_hash_key = f"jwt_claims_hash_{user_ansible_id}" + return cache.get(claims_hash_key, None) + + def set_claims_hash(self, user_ansible_id: str, claims_hash: str) -> None: + """Set cached claims hash for a user.""" + claims_hash_key = f"jwt_claims_hash_{user_ansible_id}" + cache.set(claims_hash_key, claims_hash, timeout=self.get_cache_timeout()) + + def get_cached_claims(self, user_ansible_id: str) -> Optional[dict]: + """Get cached gateway claims for a user.""" + claims_key = f"jwt_gateway_claims_{user_ansible_id}" + return cache.get(claims_key, None) + + def set_cached_claims(self, user_ansible_id: str, claims: dict) -> None: + """Set cached gateway claims for a user.""" + claims_key = f"jwt_gateway_claims_{user_ansible_id}" + cache.set(claims_key, claims, timeout=self.get_cache_timeout()) diff --git a/ansible_base/jwt_consumer/hub/auth.py b/ansible_base/jwt_consumer/hub/auth.py index a855a0a65..aac6f2cc2 100644 --- a/ansible_base/jwt_consumer/hub/auth.py +++ b/ansible_base/jwt_consumer/hub/auth.py @@ -40,10 +40,15 @@ def process_permissions(self): # the teams this user should have a "shared" [!local] assignment to member_teams = [] - for role_name in self.common_auth.token.get('object_roles', {}).keys(): + # Process object roles from gateway claims instead of JWT token + if self.common_auth.gateway_claims is None: + logger.error("Unable to process permissions because gateway claims are not available") + return + + for role_name in self.common_auth.gateway_claims.get('object_roles', {}).keys(): if role_name.startswith('Team'): - for object_index in self.common_auth.token['object_roles'][role_name]['objects']: - team_data = self.common_auth.token['objects']['team'][object_index] + for object_index in self.common_auth.gateway_claims['object_roles'][role_name]['objects']: + team_data = self.common_auth.gateway_claims['objects']['team'][object_index] ansible_id = team_data['ansible_id'] try: team = Resource.objects.get(ansible_id=ansible_id).content_object @@ -83,7 +88,7 @@ def process_permissions(self): roledef.give_permission(self.common_auth.user, team) auditor_roledef = RoleDefinition.objects.get(name='Platform Auditor') - if "Platform Auditor" in self.common_auth.token.get('global_roles', []): + if "Platform Auditor" in self.common_auth.gateway_claims.get('global_roles', []): auditor_roledef.give_global_permission(self.common_auth.user) else: auditor_roledef.remove_global_permission(self.common_auth.user) diff --git a/ansible_base/resource_registry/rest_client.py b/ansible_base/resource_registry/rest_client.py index 5110f99c4..e664efd34 100644 --- a/ansible_base/resource_registry/rest_client.py +++ b/ansible_base/resource_registry/rest_client.py @@ -189,7 +189,10 @@ def list_team_assignments(self, team_ansible_id: Optional[str] = None, filters: return self._make_request("get", "role-team-assignments/", params=params) def sync_assignment(self, assignment): - from ansible_base.rbac.service_api.serializers import ServiceRoleTeamAssignmentSerializer, ServiceRoleUserAssignmentSerializer + from ansible_base.rbac.service_api.serializers import ( + ServiceRoleTeamAssignmentSerializer, + ServiceRoleUserAssignmentSerializer, + ) if assignment._meta.model_name == 'roleuserassignment': serializer = ServiceRoleUserAssignmentSerializer(assignment) @@ -227,3 +230,7 @@ def _sync_assignment(self, data, giving=True): url = f'role-{actor_type}-assignments/{sub_url}/' return self._make_request(method="post", path=url, data=data) + + def get_jwt_claims(self, user_ansible_id): + """Get JWT claims for a user from the gateway service-index.""" + return self._make_request("get", f"jwt_claims/{user_ansible_id}/") diff --git a/test_app/tests/conftest.py b/test_app/tests/conftest.py index 230134046..7d92602c6 100644 --- a/test_app/tests/conftest.py +++ b/test_app/tests/conftest.py @@ -542,9 +542,8 @@ def __init__(self): "email": "noone@redhat.com", "is_superuser": False, }, - "objects": {}, - "object_roles": {}, - "global_roles": [], + # claims_hash is required to track permission changes + "claims_hash": "test_hash_123", } def encrypt_token(self): @@ -595,6 +594,31 @@ def mocked_gateway_view_get_request(self, *args, **kwargs): return MockedHttp() +@pytest.fixture +def mock_gateway_jwt_claims(): + """Mock for gateway JWT claims endpoint.""" + mock_claims = { + "global_roles": ["Platform Auditor"], + "object_roles": {"Organization Admin": {"content_type": "organization", "objects": [0]}}, + "objects": {"organization": [{"ansible_id": "test-org-id", "name": "Test Organization"}], "team": []}, + } + + class MockResponse: + def __init__(self, json_data, status_code=200): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + class MockResourceAPIClient: + def get_jwt_claims(self, user_ansible_id): + return MockResponse(mock_claims) + + with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client', return_value=MockResourceAPIClient()): + yield mock_claims + + @pytest.fixture def system_user(db, no_log_messages): with no_log_messages(): diff --git a/test_app/tests/jwt_consumer/common/test_auth.py b/test_app/tests/jwt_consumer/common/test_auth.py index a6955aa64..bd062d138 100644 --- a/test_app/tests/jwt_consumer/common/test_auth.py +++ b/test_app/tests/jwt_consumer/common/test_auth.py @@ -14,6 +14,7 @@ from ansible_base.jwt_consumer.common.cert import JWTCert, JWTCertException from ansible_base.jwt_consumer.common.exceptions import InvalidTokenException from ansible_base.lib.utils.translations import translatableConditionally as _ +from ansible_base.rbac.claims import get_user_claims, get_user_claims_hashable_form, get_claims_hash from ansible_base.rbac.models import RoleDefinition, RoleUserAssignment from ansible_base.rbac.permission_registry import permission_registry from ansible_base.resource_registry.models import Resource @@ -152,9 +153,7 @@ def test_map_user_fields(self, user_fields, token, should_save, caplog, shut_up_ ("iss", False), ("exp", False), ("aud", False), - ("objects", False), - ("object_roles", False), - ("global_roles", False), + ("claims_hash", False), ], ) def test_validate_token_missing_default_items(self, remove, is_user_data_entry, jwt_token, test_encryption_public_key): @@ -301,7 +300,7 @@ def test_process_rbac_permissions_system_roles( ): authentication = JWTCommonAuth() authentication.user = admin_user - authentication.token = token + authentication.gateway_claims = token # Use gateway_claims instead of token if logs_error: with expected_log(default_logger, 'error', 'Unable to grant'): authentication.process_rbac_permissions() @@ -322,7 +321,7 @@ def test_process_rbac_permissions_system_roles( def test_process_rbac_permissions_object_roles_role_dne(self, expected_log, admin_user): authentication = JWTCommonAuth() authentication.user = admin_user - authentication.token = {'object_roles': {'Junk': ['a']}} + authentication.gateway_claims = {'object_roles': {'Junk': ['a']}} # Use gateway_claims instead of token with expected_log(default_logger, 'error', 'Unable to grant'): authentication.process_rbac_permissions() @@ -338,7 +337,7 @@ def test_process_rbac_permissions_object_role_exists_object_exists( ): authentication = JWTCommonAuth() authentication.user = admin_user - authentication.token = { + authentication.gateway_claims = { # Use gateway_claims instead of token 'objects': {'organization': [{'ansible_id': organization.resource.ansible_id, 'name': organization.name}]}, 'object_roles': object_roles, } @@ -349,7 +348,7 @@ def test_process_rbac_permissions_object_role_exists_object_exists( def test_process_rbac_permissions_org_duplicate_name_error(self, expected_log, admin_user, organization, organization_admin_role): authentication = JWTCommonAuth() authentication.user = admin_user - authentication.token = { + authentication.gateway_claims = { # Use gateway_claims instead of token 'objects': {'organization': [{'ansible_id': str(uuid4()), 'name': organization.name}]}, 'object_roles': {"Organization Admin": {'content_type': 'organization', 'objects': [0]}}, } @@ -368,7 +367,7 @@ def test_process_rbac_permissions_removed_when_removed_from_jwt(self, admin_user authentication = JWTCommonAuth() authentication.user = admin_user - authentication.token = { + authentication.gateway_claims = { # Use gateway_claims instead of token 'objects': {'organization': [{'ansible_id': organization.resource.ansible_id, 'name': organization.name}]}, 'object_roles': {organization_admin_role.name: {'content_type': 'organization', 'objects': [0]}}, 'global_roles': ["Platform Auditor"], @@ -378,7 +377,7 @@ def test_process_rbac_permissions_removed_when_removed_from_jwt(self, admin_user assert RoleUserAssignment.objects.filter(user=admin_user).count() == 2 - authentication.token = {} + authentication.gateway_claims = {} # Use gateway_claims instead of token authentication.process_rbac_permissions() @@ -387,11 +386,13 @@ def test_process_rbac_permissions_removed_when_removed_from_jwt(self, admin_user @pytest.mark.django_db def test_get_or_create_resource_invalid_content_type(self): authentication = JWTCommonAuth() + authentication.gateway_claims = {} # Set gateway_claims assert authentication.get_or_create_resource('junk', {'ansible_id': uuid4()}) == (None, None) @pytest.mark.django_db def test_get_or_create_resource_organization(self): authentication = JWTCommonAuth() + authentication.gateway_claims = {} # Set gateway_claims data = {'ansible_id': uuid4(), 'name': 'Test Organization'} assert not Organization.objects.filter(name=data['name']).exists() assert not Resource.objects.filter(ansible_id=data['ansible_id']).exists() @@ -404,7 +405,7 @@ def test_get_or_create_resource_organization(self): def test_get_or_create_resource_team(self): authentication = JWTCommonAuth() org_name = 'Testing Org Name' - authentication.token = { + authentication.gateway_claims = { # Use gateway_claims instead of token 'objects': { 'organization': [ { @@ -428,6 +429,278 @@ def test_get_or_create_resource_team(self): assert Resource.objects.filter(ansible_id=data['ansible_id']).exists() assert Team.objects.filter(name=data['name']).exists() + @pytest.mark.django_db + def test_fetch_jwt_claims_from_gateway_success(self): + """Test successful fetching of JWT claims from gateway.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + mock_claims = {"objects": {"organization": [], "team": []}, "object_roles": {}, "global_roles": [], "claims_hash": "test_hash_123"} + + # Mock the response from gateway + with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client') as mock_client: + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_claims + + mock_client_instance = mock.Mock() + mock_client_instance.get_jwt_claims.return_value = mock_response + mock_client.return_value = mock_client_instance + + result = authentication._fetch_jwt_claims_from_gateway(user_ansible_id) + assert result == mock_claims + mock_client_instance.get_jwt_claims.assert_called_once_with(user_ansible_id) + + @pytest.mark.django_db + def test_fetch_jwt_claims_from_gateway_invalid_json(self, caplog): + """Test handling of invalid JSON response from gateway.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + # Mock the response from gateway with invalid JSON + with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client') as mock_client: + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.side_effect = ValueError("Invalid JSON") + mock_response.headers = {'Content-Type': 'text/html'} + mock_response.text = "Error page" + + mock_client_instance = mock.Mock() + mock_client_instance.get_jwt_claims.return_value = mock_response + mock_client.return_value = mock_client_instance + + with caplog.at_level(logging.ERROR): + result = authentication._fetch_jwt_claims_from_gateway(user_ansible_id) + assert result is None + assert "Invalid JSON response from gateway" in caplog.text + assert "text/html" in caplog.text + + @pytest.mark.django_db + def test_fetch_jwt_claims_from_gateway_non_200_response(self, caplog): + """Test handling of non-200 status code from gateway.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + # Mock the response from gateway with 404 + with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client') as mock_client: + mock_response = mock.Mock() + mock_response.status_code = 404 + mock_response.text = "Not found" + + mock_client_instance = mock.Mock() + mock_client_instance.get_jwt_claims.return_value = mock_response + mock_client.return_value = mock_client_instance + + with caplog.at_level(logging.WARNING): + result = authentication._fetch_jwt_claims_from_gateway(user_ansible_id) + assert result is None + assert "Failed to retrieve JWT claims from gateway" in caplog.text + assert "404" in caplog.text + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_no_hash(self): + """Test that claims are fetched when no claims_hash in token.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + # No claims_hash means we should fetch + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, None) is True + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_hash_changed(self): + """Test that claims are fetched when claims_hash has changed.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + # Mock cache to return different hash + with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_hash"): + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, "new_hash") is True + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_cached(self): + """Test that cached claims are used when hash hasn't changed.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + cached_claims = {"global_roles": ["test"]} + + # Mock cache to return same hash and cached claims + with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="same_hash"): + with mock.patch.object(authentication.cache, 'get_cached_claims', return_value=cached_claims): + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, "same_hash") is False + assert authentication.gateway_claims == cached_claims + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_cache_miss(self): + """Test that claims are fetched when hash matches but no cached claims exist.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + # Mock cache to return same hash but no cached claims + with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="same_hash"): + with mock.patch.object(authentication.cache, 'get_cached_claims', return_value=None): + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, "same_hash") is True + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_hash_recalculation_matches(self, admin_user): + """Test that claims are not fetched when recalculated hash matches current hash from token.""" + authentication = JWTCommonAuth() + authentication.user = admin_user + user_ansible_id = str(admin_user.resource.ansible_id) + + # Mock claims functions to return predictable data + mock_user_claims = { + 'objects': {'organization': []}, + 'object_roles': {}, + 'global_roles': ['System Auditor'] + } + mock_hashable_claims = { + 'global_roles': ['System Auditor'], + 'object_roles': {} + } + recalculated_hash = "recalculated_hash_value" + + # Mock cache to return different hash (triggering recalculation) + with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_cached_hash"): + # Mock the claims functions to return our test data + with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims', return_value=mock_user_claims): + with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims_hashable_form', return_value=mock_hashable_claims): + with mock.patch('ansible_base.jwt_consumer.common.auth.get_claims_hash', return_value=recalculated_hash): + # When recalculated hash matches current hash, should return False + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, recalculated_hash) is False + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_hash_recalculation_still_differs(self, admin_user): + """Test that claims are fetched when recalculated hash still differs from current hash.""" + authentication = JWTCommonAuth() + authentication.user = admin_user + user_ansible_id = str(admin_user.resource.ansible_id) + + # Mock claims functions to return predictable data + mock_user_claims = { + 'objects': {'organization': []}, + 'object_roles': {}, + 'global_roles': ['System Auditor'] + } + mock_hashable_claims = { + 'global_roles': ['System Auditor'], + 'object_roles': {} + } + recalculated_hash = "recalculated_hash_value" + current_hash = "different_current_hash" + + # Mock cache to return different hash (triggering recalculation) + with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_cached_hash"): + # Mock the claims functions to return our test data + with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims', return_value=mock_user_claims): + with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims_hashable_form', return_value=mock_hashable_claims): + with mock.patch('ansible_base.jwt_consumer.common.auth.get_claims_hash', return_value=recalculated_hash): + # When recalculated hash still differs from current hash, should return True + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, current_hash) is True + + @pytest.mark.django_db + def test_should_fetch_claims_from_gateway_hash_recalculation_exception(self, admin_user): + """Test that claims are fetched when hash recalculation fails with exception.""" + authentication = JWTCommonAuth() + authentication.user = admin_user + user_ansible_id = str(admin_user.resource.ansible_id) + + # Mock cache to return different hash (triggering recalculation) + with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_cached_hash"): + # Mock get_user_claims to raise an exception + with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims', side_effect=Exception("Test exception")): + # When recalculation fails, should return True (fallback behavior) + assert authentication._should_fetch_claims_from_gateway(user_ansible_id, "current_hash") is True + + @pytest.mark.django_db + def test_fetch_jwt_claims_from_gateway_exception(self, caplog): + """Test handling of exceptions when fetching JWT claims from gateway.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + # Mock client to raise an exception + with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client') as mock_client: + mock_client.side_effect = Exception("Connection error") + + with caplog.at_level(logging.ERROR): + result = authentication._fetch_jwt_claims_from_gateway(user_ansible_id) + assert result is None + assert "Error fetching JWT claims from gateway" in caplog.text + assert "Connection error" in caplog.text + + @pytest.mark.django_db + def test_cache_claims_hash_with_valid_data(self): + """Test caching of claims hash and gateway claims.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + claims_hash = "test_hash_456" + gateway_claims = {"global_roles": ["test"], "object_roles": {}} + + authentication.gateway_claims = gateway_claims + + with mock.patch.object(authentication.cache, 'set_claims_hash') as mock_set_hash: + with mock.patch.object(authentication.cache, 'set_cached_claims') as mock_set_claims: + authentication._cache_claims_hash(user_ansible_id, claims_hash) + + mock_set_hash.assert_called_once_with(user_ansible_id, claims_hash) + mock_set_claims.assert_called_once_with(user_ansible_id, gateway_claims) + + @pytest.mark.django_db + def test_cache_claims_hash_with_no_hash(self): + """Test that caching is skipped when no claims_hash is provided.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + authentication.gateway_claims = {"global_roles": ["test"]} + + with mock.patch.object(authentication.cache, 'set_claims_hash') as mock_set_hash: + with mock.patch.object(authentication.cache, 'set_cached_claims') as mock_set_claims: + authentication._cache_claims_hash(user_ansible_id, None) + + mock_set_hash.assert_not_called() + mock_set_claims.assert_not_called() + + @pytest.mark.django_db + def test_cache_claims_hash_with_no_gateway_claims(self): + """Test that caching is skipped when no gateway claims are available.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + authentication.gateway_claims = None + + with mock.patch.object(authentication.cache, 'set_claims_hash') as mock_set_hash: + with mock.patch.object(authentication.cache, 'set_cached_claims') as mock_set_claims: + authentication._cache_claims_hash(user_ansible_id, "test_hash") + + mock_set_hash.assert_not_called() + mock_set_claims.assert_not_called() + + @pytest.mark.django_db + def test_fetch_jwt_claims_uses_resource_service_path(self): + """Test that JWT claims fetching uses RESOURCE_SERVICE_PATH setting.""" + authentication = JWTCommonAuth() + user_ansible_id = str(uuid4()) + + mock_claims = {"global_roles": [], "object_roles": {}} + + with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client') as mock_client: + mock_response = mock.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_claims + + mock_client_instance = mock.Mock() + mock_client_instance.get_jwt_claims.return_value = mock_response + mock_client.return_value = mock_client_instance + + # Mock the settings to verify correct path is used + with mock.patch('ansible_base.jwt_consumer.common.auth.getattr') as mock_getattr: + mock_getattr.return_value = "/custom/api/path/service-index/" + + result = authentication._fetch_jwt_claims_from_gateway(user_ansible_id) + + # Verify that getattr was called to get RESOURCE_SERVICE_PATH + mock_getattr.assert_called_once() + # Verify the client was created with the setting value + mock_client.assert_called_once_with("/custom/api/path/service-index/") + assert result == mock_claims + class TestJWTAuthentication: def test_authenticate(self, jwt_token, django_user_model, mocked_http, test_encryption_public_key): diff --git a/test_app/tests/jwt_consumer/hub/test_auth.py b/test_app/tests/jwt_consumer/hub/test_auth.py index 261f5f232..7125b0061 100644 --- a/test_app/tests/jwt_consumer/hub/test_auth.py +++ b/test_app/tests/jwt_consumer/hub/test_auth.py @@ -102,10 +102,10 @@ def get_resource_content(*args, **kwargs): # Add the user to the org and the team. Galaxy doesn't have # a concept of org&team admin yet so we don't care about those. - auth.common_auth.token = { - "global_roles": { - 'Platform Auditor': {}, - }, + auth.common_auth.gateway_claims = { + "global_roles": [ + 'Platform Auditor', + ], "object_roles": { 'Organization Admin': {'content_type': 'organization', 'objects': [0]}, 'Organization Member': {'content_type': 'organization', 'objects': [0]}, @@ -137,7 +137,7 @@ def get_resource_content(*args, **kwargs): assert RoleUserAssignment.objects.filter(user=testuser, role_definition=platform_auditor_role).count() == 1 # REVOKE EVERYTHING AND RECHECK ... - auth.common_auth.token = {} + auth.common_auth.gateway_claims = {} auth.process_permissions() assert RoleUserAssignment.objects.filter(user=testuser, role_definition=team_member_role, object_id=testteam.pk).count() == 0 assert RoleUserAssignment.objects.filter(user=testuser, role_definition=team_admin_role, object_id=testteam.pk).count() == 0