Skip to content
Open
153 changes: 142 additions & 11 deletions ansible_base/jwt_consumer/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from ansible_base.lib.logging.runtime import log_excess_runtime
from ansible_base.lib.utils.auth import get_user_by_ansible_id
from ansible_base.lib.utils.translations import translatableConditionally as _
from ansible_base.rbac.claims import get_user_claims, get_user_claims_hashable_form, get_claims_hash
from ansible_base.resource_registry.models import Resource, ResourceType
from ansible_base.resource_registry.rest_client import get_resource_server_client
from ansible_base.resource_registry.signals.handlers import no_reverse_sync

logger = logging.getLogger("ansible_base.jwt_consumer.common.auth")
Expand Down Expand Up @@ -52,6 +54,7 @@ def __init__(self, user_fields=default_mapped_user_fields) -> None:
self.cache = JWTCache()
self.user = None
self.token = None
self.gateway_claims = None # Store claims from gateway
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very strongly convinced this shouldn't be on self. Let the de-referenced be garbage collected.


@log_excess_runtime(logger, debug_cutoff=0.01)
def parse_jwt_token(self, request):
Expand Down Expand Up @@ -142,10 +145,118 @@ def parse_jwt_token(self, request):
resource.service_id = self.token['service_id']
resource.save(update_fields=['ansible_id', 'service_id'])

# Check if claims need to be refreshed from gateway based on claims_hash
user_ansible_id = self.token['sub']
current_claims_hash = self.token.get('claims_hash')

if self._should_fetch_claims_from_gateway(user_ansible_id, current_claims_hash):
logger.debug(f"Claims hash changed or not cached, fetching claims from gateway for user {user_ansible_id}")
jwt_claims = self._fetch_jwt_claims_from_gateway(user_ansible_id)

if jwt_claims:
self.gateway_claims = jwt_claims
self._cache_claims_hash(user_ansible_id, current_claims_hash)
logger.debug(f"Successfully loaded and cached gateway claims for user {user_ansible_id}")
else:
logger.error(f"Failed to fetch claims from gateway for user {user_ansible_id}. RBAC processing will not be available.")
# Note: We don't raise an exception here to allow basic authentication to succeed
# RBAC processing will fail gracefully with appropriate error messages
else:
logger.debug(f"Using cached claims for user {user_ansible_id} (claims_hash unchanged)")
setattr(self.user, "resource_api_actions", self.token.get("resource_api_actions", None))

logger.info(f"User {self.user.username} authenticated from JWT auth")

def _should_fetch_claims_from_gateway(self, user_ansible_id, current_claims_hash):
"""
Determine if claims should be fetched from gateway based on claims_hash comparison.
Returns True if claims need to be fetched (hash changed or not cached).
"""
if not current_claims_hash:
logger.debug(f"No claims_hash in token for user {user_ansible_id}, will fetch claims")
return True

cached_hash = self.cache.get_claims_hash(user_ansible_id)
if cached_hash != current_claims_hash:
# Recalculate hash from local database to verify the mismatch
# It is possible that the cached hash is stale, but the local data is synced to the resource server.
# This is an optimization to avoid fetching claims from the resource server if the local data is synced.
logger.debug(f"Claims hash mismatch for user {user_ansible_id}: cached={cached_hash}, from token={current_claims_hash}")
logger.debug(f"Recalculating hash from local database for user {user_ansible_id}")

try:
# Get user claims from local database
user_claims = get_user_claims(self.user)
hashable_claims = get_user_claims_hashable_form(user_claims)
recalculated_hash = get_claims_hash(hashable_claims)

# Compare recalculated hash with current hash from token
if recalculated_hash != current_claims_hash:
logger.debug(f"Claims hash still differs after recalculation for user {user_ansible_id}: local={recalculated_hash}, from token={current_claims_hash}")
return True
else:
logger.debug(f"Recalculated local hash matches token hash for user {user_ansible_id}")
logger.debug(f"Caching claims hash for user {user_ansible_id}: {recalculated_hash}")
self._cache_claims_hash(user_ansible_id, recalculated_hash)
return False

except Exception as e:
logger.error(f"Failed to recalculate claims hash for user {user_ansible_id}: {e}")
# If recalculation fails, fall back to treating it as a hash mismatch
return True

# Hash matches cached value, try to get cached claims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, don't do that.

If hashes match, do nothing. Don't do approximately nothing, do exactly nothing. Don't save the claims for later. If the hashes match that means the claims have already been saved.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will update to do approximately nothing 🤣

cached_claims = self.cache.get_cached_claims(user_ansible_id)
if cached_claims:
self.gateway_claims = cached_claims
return False
else:
logger.debug(f"Claims hash matches but no cached claims found for user {user_ansible_id}")
return True

def _cache_claims_hash(self, user_ansible_id, claims_hash):
"""Cache the claims hash and gateway claims for future comparisons."""
if claims_hash and self.gateway_claims:
self.cache.set_claims_hash(user_ansible_id, claims_hash)
self.cache.set_cached_claims(user_ansible_id, self.gateway_claims)

def _fetch_jwt_claims_from_gateway(self, user_ansible_id):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _fetch_jwt_claims_from_gateway(self, user_ansible_id):
def fetch_jwt_claims_from_gateway(self, user_ansible_id) -> dict[str,dict[dict[dict[str,list],list],dict]:

"""
Fetch JWT claims for a user from the gateway service-index API.
Returns None if claims cannot be retrieved.
"""
try:
# Use the full service path from settings, not just "service-index"
# This should be something like "/api/gateway/v1/service-index/"
service_path = getattr(settings, "RESOURCE_SERVICE_PATH", "/api/gateway/v1/service-index/")
client = get_resource_server_client(service_path)
response = client.get_jwt_claims(user_ansible_id)

if response.status_code == 200:
# Try to parse JSON, but handle empty or invalid responses
try:
claims = response.json()
logger.debug(f"Retrieved JWT claims from gateway for user {user_ansible_id}")
return claims
except ValueError:
# Log the actual response content for debugging
logger.error(
f"Invalid JSON response from gateway for user {user_ansible_id}. "
f"Status: {response.status_code}, "
f"Content-Type: {response.headers.get('Content-Type', 'unknown')}, "
f"Response text: {response.text[:500]}" # First 500 chars to avoid huge logs
)
return None
else:
logger.warning(
f"Failed to retrieve JWT claims from gateway for user {user_ansible_id}: "
f"Status: {response.status_code}, Response: {response.text[:500]}"
)
return None
except Exception as e:
logger.error(f"Error fetching JWT claims from gateway for user {user_ansible_id}: {e}")
return None

def log_and_raise(self, conditional_translate_object, expand_values={}, error_code=None):
logger.error(conditional_translate_object.not_translated() % expand_values)
translated_error_message = conditional_translate_object.translated() % expand_values
Expand Down Expand Up @@ -226,7 +337,8 @@ def validate_token(self, unencrypted_token, decryption_key, request_id=None):
return validated_body

def decode_jwt_token(self, unencrypted_token, decryption_key, additional_options={}):
local_required_field = ["sub", "user_data", "exp", "objects", "object_roles", "global_roles", "version"]
# Core required fields - claims_hash is now required to track permission changes
local_required_field = ["sub", "user_data", "exp", "version", "claims_hash"]
options = {"require": local_required_field}
options.update(additional_options)
return jwt.decode(
Expand Down Expand Up @@ -259,16 +371,23 @@ def get_role_definition(self, name: str) -> Optional[Model]:
def process_rbac_permissions(self):
"""
This is a default process_permissions which should be usable if you are using RBAC from DAB
Uses gateway claims data exclusively - no fallback to JWT token fields
"""
if self.token is None or self.user is None:
logger.error("Unable to process rbac permissions because user or token is not defined, please call authenticate first")
if self.user is None:
logger.error("Unable to process rbac permissions because user is not defined, please call authenticate first")
return

if self.gateway_claims is None:
logger.error("Unable to process rbac permissions because gateway claims are not available. Ensure gateway jwt_claims endpoint is accessible.")
return

from ansible_base.rbac.models import RoleUserAssignment

role_diff = RoleUserAssignment.objects.filter(user=self.user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES)

for system_role_name in self.token.get("global_roles", []):
# Process global roles from gateway claims
global_roles = self.gateway_claims.get("global_roles", [])
for system_role_name in global_roles:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're trying too hard to not change the existing code. And in this case it's not good code.

Make a new method, like save_claims(user, claims). See other stuff in ansible_base/rbac/claims.py, it should fit in with that crowd.

It should take the claims, and make them true for that user. You don't have to refactor the logic itself here (although you should, you don't have to), but you do really need to refactor the interface.

This method should be callable from unit tests, and it should be called from unit tests. We shouldn't need to see the JWT auth class within a mile of that logic.

logger.debug(f"Processing system role {system_role_name} for {self.user.username}")
rd = self.get_role_definition(system_role_name)
if rd:
Expand All @@ -282,7 +401,11 @@ def process_rbac_permissions(self):
logger.error(f"Unable to grant {self.user.username} system level role {system_role_name} because it does not exist")
continue

for object_role_name in self.token.get('object_roles', {}).keys():
# Process object roles from gateway claims
object_roles = self.gateway_claims.get('object_roles', {})
objects = self.gateway_claims.get('objects', {})

for object_role_name in object_roles.keys():
rd = self.get_role_definition(object_role_name)
if rd is None:
logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it does not exist")
Expand All @@ -291,11 +414,11 @@ def process_rbac_permissions(self):
logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it is not a JWT managed role")
continue

object_type = self.token['object_roles'][object_role_name]['content_type']
object_indexes = self.token['object_roles'][object_role_name]['objects']
object_type = object_roles[object_role_name]['content_type']
object_indexes = object_roles[object_role_name]['objects']

for index in object_indexes:
object_data = self.token['objects'][object_type][index]
object_data = objects[object_type][index]
try:
resource, obj = self.get_or_create_resource(object_type, object_data)
except IntegrityError as e:
Expand All @@ -310,7 +433,7 @@ def process_rbac_permissions(self):
role_diff = role_diff.exclude(pk=assignment.pk)
logger.info(f"Granted user {self.user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}")

# Remove all permissions not authorized by the JWT
# Remove all permissions not authorized by the gateway claims
for role_assignment in role_diff:
rd = role_assignment.role_definition
content_object = role_assignment.content_object
Expand All @@ -322,9 +445,17 @@ def process_rbac_permissions(self):
def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optional[Resource], Optional[Model]]:
"""
Gets or creates a resource from a content type and its default data
Uses gateway claims exclusively - no fallback to JWT token fields

This can only build or get organizations or teams
Args:
content_type: Type of content ('team', 'organization')
data: Resource data dictionary
"""
if self.gateway_claims is None:
logger.error("Unable to create resource because gateway claims are not available")
return None, None

object_ansible_id = data['ansible_id']
try:
resource = Resource.objects.get(ansible_id=object_ansible_id)
Expand All @@ -337,7 +468,7 @@ def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optiona
if content_type == 'team':
# For a team we first have to make sure the org is there
org_id = data['org']
organization_data = self.token['objects']["organization"][org_id]
organization_data = self.gateway_claims['objects']["organization"][org_id]

# Now that we have the org we can build a team
org_resource, _ = self.get_or_create_resource("organization", organization_data)
Expand All @@ -359,7 +490,7 @@ def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optiona

return resource, resource.content_object
else:
logger.error(f"build_resource_stub does not know how to build an object of type {type}")
logger.error(f"build_resource_stub does not know how to build an object of type {content_type}")
Copy link
Preview

Copilot AI Aug 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error message references 'build_resource_stub' but this method is actually called 'get_or_create_resource'. The error message should be updated to reflect the correct method name for clarity.

Suggested change
logger.error(f"build_resource_stub does not know how to build an object of type {content_type}")
logger.error(f"get_or_create_resource does not know how to build an object of type {content_type}")

Copilot uses AI. Check for mistakes.

return None, None


Expand Down
20 changes: 20 additions & 0 deletions ansible_base/jwt_consumer/common/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,23 @@ def get_key_from_cache(self) -> Optional[str]:

def set_key_in_cache(self, key: str) -> None:
cache.set(cache_key, key, timeout=self.get_cache_timeout())

def get_claims_hash(self, user_ansible_id: str) -> Optional[str]:
"""Get cached claims hash for a user."""
claims_hash_key = f"jwt_claims_hash_{user_ansible_id}"
return cache.get(claims_hash_key, None)

def set_claims_hash(self, user_ansible_id: str, claims_hash: str) -> None:
"""Set cached claims hash for a user."""
claims_hash_key = f"jwt_claims_hash_{user_ansible_id}"
cache.set(claims_hash_key, claims_hash, timeout=self.get_cache_timeout())

def get_cached_claims(self, user_ansible_id: str) -> Optional[dict]:
"""Get cached gateway claims for a user."""
claims_key = f"jwt_gateway_claims_{user_ansible_id}"
return cache.get(claims_key, None)

def set_cached_claims(self, user_ansible_id: str, claims: dict) -> None:
"""Set cached gateway claims for a user."""
claims_key = f"jwt_gateway_claims_{user_ansible_id}"
cache.set(claims_key, claims, timeout=self.get_cache_timeout())
13 changes: 9 additions & 4 deletions ansible_base/jwt_consumer/hub/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,15 @@ def process_permissions(self):
# the teams this user should have a "shared" [!local] assignment to
member_teams = []

for role_name in self.common_auth.token.get('object_roles', {}).keys():
# Process object roles from gateway claims instead of JWT token
if self.common_auth.gateway_claims is None:
logger.error("Unable to process permissions because gateway claims are not available")
return

for role_name in self.common_auth.gateway_claims.get('object_roles', {}).keys():
if role_name.startswith('Team'):
for object_index in self.common_auth.token['object_roles'][role_name]['objects']:
team_data = self.common_auth.token['objects']['team'][object_index]
for object_index in self.common_auth.gateway_claims['object_roles'][role_name]['objects']:
team_data = self.common_auth.gateway_claims['objects']['team'][object_index]
ansible_id = team_data['ansible_id']
try:
team = Resource.objects.get(ansible_id=ansible_id).content_object
Expand Down Expand Up @@ -83,7 +88,7 @@ def process_permissions(self):
roledef.give_permission(self.common_auth.user, team)

auditor_roledef = RoleDefinition.objects.get(name='Platform Auditor')
if "Platform Auditor" in self.common_auth.token.get('global_roles', []):
if "Platform Auditor" in self.common_auth.gateway_claims.get('global_roles', []):
auditor_roledef.give_global_permission(self.common_auth.user)
else:
auditor_roledef.remove_global_permission(self.common_auth.user)
9 changes: 8 additions & 1 deletion ansible_base/resource_registry/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ def list_team_assignments(self, team_ansible_id: Optional[str] = None, filters:
return self._make_request("get", "role-team-assignments/", params=params)

def sync_assignment(self, assignment):
from ansible_base.rbac.service_api.serializers import ServiceRoleTeamAssignmentSerializer, ServiceRoleUserAssignmentSerializer
from ansible_base.rbac.service_api.serializers import (
ServiceRoleTeamAssignmentSerializer,
ServiceRoleUserAssignmentSerializer,
)

if assignment._meta.model_name == 'roleuserassignment':
serializer = ServiceRoleUserAssignmentSerializer(assignment)
Expand Down Expand Up @@ -227,3 +230,7 @@ def _sync_assignment(self, data, giving=True):
url = f'role-{actor_type}-assignments/{sub_url}/'

return self._make_request(method="post", path=url, data=data)

def get_jwt_claims(self, user_ansible_id):
"""Get JWT claims for a user from the gateway service-index."""
return self._make_request("get", f"jwt_claims/{user_ansible_id}/")
30 changes: 27 additions & 3 deletions test_app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,8 @@ def __init__(self):
"email": "[email protected]",
"is_superuser": False,
},
"objects": {},
"object_roles": {},
"global_roles": [],
# claims_hash is required to track permission changes
"claims_hash": "test_hash_123",
}

def encrypt_token(self):
Expand Down Expand Up @@ -595,6 +594,31 @@ def mocked_gateway_view_get_request(self, *args, **kwargs):
return MockedHttp()


@pytest.fixture
def mock_gateway_jwt_claims():
"""Mock for gateway JWT claims endpoint."""
mock_claims = {
"global_roles": ["Platform Auditor"],
"object_roles": {"Organization Admin": {"content_type": "organization", "objects": [0]}},
"objects": {"organization": [{"ansible_id": "test-org-id", "name": "Test Organization"}], "team": []},
}

class MockResponse:
def __init__(self, json_data, status_code=200):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

class MockResourceAPIClient:
def get_jwt_claims(self, user_ansible_id):
return MockResponse(mock_claims)

with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client', return_value=MockResourceAPIClient()):
yield mock_claims


@pytest.fixture
def system_user(db, no_log_messages):
with no_log_messages():
Expand Down
Loading
Loading