Skip to content

Commit 5abe950

Browse files
[AAP-47811] Update jwt_consumer to load user claims from gateway, if needed (#796)
Adapts JWT RBAC processing to use claims hash instead of full claims data: - JWT tokens will no longer contain full claims data, only a claims hash - Implement claims hash caching to track user permission state - On hash mismatch, first compute hash from user's local permissions before gateway fallback - Add gateway fallback to fetch fresh claims only when local permissions don't match - Fail authentication when claims validation cannot be completed This reduces JWT token size while maintaining secure RBAC permission handling through efficient caching and validation mechanisms.
1 parent 4438f72 commit 5abe950

File tree

7 files changed

+372
-88
lines changed

7 files changed

+372
-88
lines changed

ansible_base/jwt_consumer/common/auth.py

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import uuid
23
from datetime import datetime
34
from typing import Optional, Tuple
45

@@ -18,7 +19,9 @@
1819
from ansible_base.lib.logging.runtime import log_excess_runtime
1920
from ansible_base.lib.utils.auth import get_user_by_ansible_id
2021
from ansible_base.lib.utils.translations import translatableConditionally as _
22+
from ansible_base.rbac.claims import get_claims_hash, get_user_claims, get_user_claims_hashable_form
2123
from ansible_base.resource_registry.models import Resource, ResourceType
24+
from ansible_base.resource_registry.rest_client import get_resource_server_client
2225
from ansible_base.resource_registry.signals.handlers import no_reverse_sync
2326

2427
logger = logging.getLogger("ansible_base.jwt_consumer.common.auth")
@@ -226,7 +229,7 @@ def validate_token(self, unencrypted_token, decryption_key, request_id=None):
226229
return validated_body
227230

228231
def decode_jwt_token(self, unencrypted_token, decryption_key, additional_options={}):
229-
local_required_field = ["sub", "user_data", "exp", "objects", "object_roles", "global_roles", "version"]
232+
local_required_field = ["sub", "user_data", "exp", "claims_hash", "version"]
230233
options = {"require": local_required_field}
231234
options.update(additional_options)
232235
return jwt.decode(
@@ -258,17 +261,98 @@ def get_role_definition(self, name: str) -> Optional[Model]:
258261

259262
def process_rbac_permissions(self):
260263
"""
261-
This is a default process_permissions which should be usable if you are using RBAC from DAB
264+
Process RBAC permissions using claims hash logic
262265
"""
263266
if self.token is None or self.user is None:
264-
logger.error("Unable to process rbac permissions because user or token is not defined, please call authenticate first")
267+
logger.error("Unable to process rbac permissions because user or token is not defined")
265268
return
266269

270+
jwt_claims_hash = self.token.get("claims_hash")
271+
if not jwt_claims_hash:
272+
logger.error("No claims_hash found in JWT token")
273+
return
274+
275+
user_ansible_id = self.token.get("sub")
276+
if not user_ansible_id:
277+
logger.error("No subject (sub) found in JWT token")
278+
return
279+
280+
# Validate UUID format (consistent with rest of codebase)
281+
try:
282+
uuid.UUID(user_ansible_id)
283+
except (ValueError, TypeError):
284+
logger.error(f"Invalid UUID format for user_ansible_id: {user_ansible_id}")
285+
return
286+
287+
# Check cached claims hash
288+
cached_claims_hash = self.cache.get_cached_claims_hash(user_ansible_id)
289+
290+
if cached_claims_hash == jwt_claims_hash:
291+
logger.debug(f"Claims hash matches cached value for user {user_ansible_id}")
292+
return
293+
294+
# Calculate local claims hash
295+
local_claims = get_user_claims(self.user)
296+
local_hashable_claims = get_user_claims_hashable_form(local_claims)
297+
local_claims_hash = get_claims_hash(local_hashable_claims)
298+
299+
if local_claims_hash == jwt_claims_hash:
300+
logger.debug(f"Claims hash matches local calculation for user {user_ansible_id}")
301+
# Update cache with the correct hash
302+
self.cache.cache_claims_hash(user_ansible_id, jwt_claims_hash)
303+
return
304+
305+
# Claims hash mismatch - fetch from gateway
306+
logger.info(f"Claims hash mismatch for user {user_ansible_id}. JWT: {jwt_claims_hash}, Local: {local_claims_hash}. Fetching from gateway.")
307+
gateway_claims = self._fetch_jwt_claims_from_gateway(user_ansible_id)
308+
309+
if gateway_claims:
310+
# Extract claims structure from gateway response
311+
objects = gateway_claims.get('objects', {})
312+
object_roles = gateway_claims.get('object_roles', {})
313+
global_roles = gateway_claims.get('global_roles', [])
314+
315+
# Process the RBAC permissions with the gateway claims
316+
self._apply_rbac_permissions(objects, object_roles, global_roles)
317+
318+
# Update cache with the new hash
319+
self.cache.cache_claims_hash(user_ansible_id, jwt_claims_hash)
320+
else:
321+
self.log_and_raise(
322+
_("Unable to validate user permissions - gateway claims fetch failed for user %(user_ansible_id)s"), {"user_ansible_id": user_ansible_id}
323+
)
324+
325+
def _fetch_jwt_claims_from_gateway(self, user_ansible_id: str) -> Optional[dict]:
326+
"""
327+
Fetch JWT claims from the gateway endpoint using resource server client
328+
"""
329+
try:
330+
# Use the resource server client to make the request
331+
client = get_resource_server_client(service_path="api/gateway/v1")
332+
333+
logger.debug(f"Fetching claims from gateway for user {user_ansible_id}")
334+
response = client._make_request("GET", f"jwt_claims/{user_ansible_id}/")
335+
336+
if response.status_code == 200:
337+
claims_data = response.json()
338+
return claims_data
339+
else:
340+
logger.error(f"Gateway request failed with status {response.status_code}")
341+
return None
342+
343+
except Exception as e:
344+
logger.error(f"Error fetching claims from gateway: {e}")
345+
return None
346+
347+
def _apply_rbac_permissions(self, objects, object_roles, global_roles):
348+
"""
349+
Apply RBAC permissions from claims data
350+
"""
267351
from ansible_base.rbac.models import RoleUserAssignment
268352

269353
role_diff = RoleUserAssignment.objects.filter(user=self.user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES)
270354

271-
for system_role_name in self.token.get("global_roles", []):
355+
for system_role_name in global_roles:
272356
logger.debug(f"Processing system role {system_role_name} for {self.user.username}")
273357
rd = self.get_role_definition(system_role_name)
274358
if rd:
@@ -282,7 +366,7 @@ def process_rbac_permissions(self):
282366
logger.error(f"Unable to grant {self.user.username} system level role {system_role_name} because it does not exist")
283367
continue
284368

285-
for object_role_name in self.token.get('object_roles', {}).keys():
369+
for object_role_name in object_roles.keys():
286370
rd = self.get_role_definition(object_role_name)
287371
if rd is None:
288372
logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it does not exist")
@@ -291,11 +375,11 @@ def process_rbac_permissions(self):
291375
logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it is not a JWT managed role")
292376
continue
293377

294-
object_type = self.token['object_roles'][object_role_name]['content_type']
295-
object_indexes = self.token['object_roles'][object_role_name]['objects']
378+
object_type = object_roles[object_role_name]['content_type']
379+
object_indexes = object_roles[object_role_name]['objects']
296380

297381
for index in object_indexes:
298-
object_data = self.token['objects'][object_type][index]
382+
object_data = objects[object_type][index]
299383
try:
300384
resource, obj = self.get_or_create_resource(object_type, object_data)
301385
except IntegrityError as e:

ansible_base/jwt_consumer/common/cache.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ def check_user_in_cache(self, validated_body: dict) -> Tuple[bool, dict]:
3838
cache.set(validated_body["sub"], expected_cache_value, timeout=self.get_cache_timeout())
3939
return False, expected_cache_value
4040

41+
def cache_claims_hash(self, user_ansible_id: str, claims_hash: str) -> None:
42+
"""Cache the claims hash for a user"""
43+
cache_key = f"claims_hash_{user_ansible_id}"
44+
cache.set(cache_key, claims_hash, timeout=self.get_cache_timeout())
45+
logger.debug(f"Cached claims hash for user {user_ansible_id}: {claims_hash}")
46+
47+
def get_cached_claims_hash(self, user_ansible_id: str) -> Optional[str]:
48+
"""Get cached claims hash for a user"""
49+
cache_key = f"claims_hash_{user_ansible_id}"
50+
cached_hash = cache.get(cache_key, None)
51+
logger.debug(f"Retrieved cached claims hash for user {user_ansible_id}: {cached_hash}")
52+
return cached_hash
53+
4154
def get_key_from_cache(self) -> Optional[str]:
4255
# If we are not ignoring the cache (forcing a reload of the key), check it
4356
key = cache.get(cache_key, None)

ansible_base/jwt_consumer/hub/auth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def get_galaxy_models(self):
2222

2323
return Organization, Team
2424

25-
def process_permissions(self):
25+
def _apply_rbac_permissions(self, objects, object_roles, global_roles):
2626
# Map teams in the JWT to Automation Hub groups.
2727
Organization, Team = self.get_galaxy_models()
2828
self.team_content_type = ContentType.objects.get_for_model(Team)
@@ -40,10 +40,10 @@ def process_permissions(self):
4040
# the teams this user should have a "shared" [!local] assignment to
4141
member_teams = []
4242

43-
for role_name in self.common_auth.token.get('object_roles', {}).keys():
43+
for role_name in object_roles.keys():
4444
if role_name.startswith('Team'):
45-
for object_index in self.common_auth.token['object_roles'][role_name]['objects']:
46-
team_data = self.common_auth.token['objects']['team'][object_index]
45+
for object_index in object_roles[role_name]['objects']:
46+
team_data = objects['team'][object_index]
4747
ansible_id = team_data['ansible_id']
4848
try:
4949
team = Resource.objects.get(ansible_id=ansible_id).content_object
@@ -83,7 +83,7 @@ def process_permissions(self):
8383
roledef.give_permission(self.common_auth.user, team)
8484

8585
auditor_roledef = RoleDefinition.objects.get(name='Platform Auditor')
86-
if "Platform Auditor" in self.common_auth.token.get('global_roles', []):
86+
if "Platform Auditor" in global_roles:
8787
auditor_roledef.give_global_permission(self.common_auth.user)
8888
else:
8989
auditor_roledef.remove_global_permission(self.common_auth.user)

test_app/tests/conftest.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,7 @@ def __init__(self):
542542
"email": "[email protected]",
543543
"is_superuser": False,
544544
},
545-
"objects": {},
546-
"object_roles": {},
547-
"global_roles": [],
545+
"claims_hash": "abc123def456",
548546
}
549547

550548
def encrypt_token(self):

0 commit comments

Comments
 (0)