Skip to content

Commit 77db991

Browse files
committed
Refactor JWT claims handling to use gateway endpoint
JWT claims are now exclusively fetched from the gateway service-index API instead of being included in the JWT token. Deprecated fields (objects, object_roles, global_roles) are removed from token processing and all RBAC logic now relies on gateway claims. Added helper to ResourceAPIClient for fetching claims, and updated tests to reflect the new claims source.
1 parent 28ef3a6 commit 77db991

File tree

5 files changed

+164
-19
lines changed

5 files changed

+164
-19
lines changed

ansible_base/jwt_consumer/common/auth.py

Lines changed: 100 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from ansible_base.lib.utils.auth import get_user_by_ansible_id
2020
from ansible_base.lib.utils.translations import translatableConditionally as _
2121
from ansible_base.resource_registry.models import Resource, ResourceType
22+
from ansible_base.resource_registry.rest_client import get_resource_server_client
2223
from ansible_base.resource_registry.signals.handlers import no_reverse_sync
2324

2425
logger = logging.getLogger("ansible_base.jwt_consumer.common.auth")
@@ -52,6 +53,7 @@ def __init__(self, user_fields=default_mapped_user_fields) -> None:
5253
self.cache = JWTCache()
5354
self.user = None
5455
self.token = None
56+
self.gateway_claims = None # Store claims from gateway
5557

5658
@log_excess_runtime(logger, debug_cutoff=0.01)
5759
def parse_jwt_token(self, request):
@@ -142,10 +144,77 @@ def parse_jwt_token(self, request):
142144
resource.service_id = self.token['service_id']
143145
resource.save(update_fields=['ansible_id', 'service_id'])
144146

147+
# Check if claims need to be refreshed from gateway based on claims_hash
148+
user_ansible_id = self.token['sub']
149+
current_claims_hash = self.token.get('claims_hash')
150+
151+
if self._should_fetch_claims_from_gateway(user_ansible_id, current_claims_hash):
152+
logger.debug(f"Claims hash changed or not cached, fetching claims from gateway for user {user_ansible_id}")
153+
jwt_claims = self._fetch_jwt_claims_from_gateway(user_ansible_id)
154+
155+
if jwt_claims:
156+
self.gateway_claims = jwt_claims
157+
self._cache_claims_hash(user_ansible_id, current_claims_hash)
158+
logger.debug(f"Successfully loaded and cached gateway claims for user {user_ansible_id}")
159+
else:
160+
logger.error(f"Failed to fetch claims from gateway for user {user_ansible_id}. RBAC processing will not be available.")
161+
# Note: We don't raise an exception here to allow basic authentication to succeed
162+
# RBAC processing will fail gracefully with appropriate error messages
163+
else:
164+
logger.debug(f"Using cached claims for user {user_ansible_id} (claims_hash unchanged)")
145165
setattr(self.user, "resource_api_actions", self.token.get("resource_api_actions", None))
146166

147167
logger.info(f"User {self.user.username} authenticated from JWT auth")
148168

169+
def _should_fetch_claims_from_gateway(self, user_ansible_id, current_claims_hash):
170+
"""
171+
Determine if claims should be fetched from gateway based on claims_hash comparison.
172+
Returns True if claims need to be fetched (hash changed or not cached).
173+
"""
174+
if not current_claims_hash:
175+
logger.debug(f"No claims_hash in token for user {user_ansible_id}, will fetch claims")
176+
return True
177+
178+
cached_hash = self.cache.get_claims_hash(user_ansible_id)
179+
if cached_hash != current_claims_hash:
180+
logger.debug(f"Claims hash changed for user {user_ansible_id}: cached={cached_hash}, current={current_claims_hash}")
181+
return True
182+
183+
# Hash matches cached value, try to get cached claims
184+
cached_claims = self.cache.get_cached_claims(user_ansible_id)
185+
if cached_claims:
186+
self.gateway_claims = cached_claims
187+
return False
188+
else:
189+
logger.debug(f"Claims hash matches but no cached claims found for user {user_ansible_id}")
190+
return True
191+
192+
def _cache_claims_hash(self, user_ansible_id, claims_hash):
193+
"""Cache the claims hash and gateway claims for future comparisons."""
194+
if claims_hash and self.gateway_claims:
195+
self.cache.set_claims_hash(user_ansible_id, claims_hash)
196+
self.cache.set_cached_claims(user_ansible_id, self.gateway_claims)
197+
198+
def _fetch_jwt_claims_from_gateway(self, user_ansible_id):
199+
"""
200+
Fetch JWT claims for a user from the gateway service-index API.
201+
Returns None if claims cannot be retrieved.
202+
"""
203+
try:
204+
client = get_resource_server_client("service-index")
205+
response = client.get_jwt_claims(user_ansible_id)
206+
207+
if response.status_code == 200:
208+
claims = response.json()
209+
logger.debug(f"Retrieved JWT claims from gateway for user {user_ansible_id}")
210+
return claims
211+
else:
212+
logger.warning(f"Failed to retrieve JWT claims from gateway for user {user_ansible_id}: " f"{response.status_code}")
213+
return None
214+
except Exception as e:
215+
logger.error(f"Error fetching JWT claims from gateway for user {user_ansible_id}: {e}")
216+
return None
217+
149218
def log_and_raise(self, conditional_translate_object, expand_values={}, error_code=None):
150219
logger.error(conditional_translate_object.not_translated() % expand_values)
151220
translated_error_message = conditional_translate_object.translated() % expand_values
@@ -226,7 +295,8 @@ def validate_token(self, unencrypted_token, decryption_key, request_id=None):
226295
return validated_body
227296

228297
def decode_jwt_token(self, unencrypted_token, decryption_key, additional_options={}):
229-
local_required_field = ["sub", "user_data", "exp", "objects", "object_roles", "global_roles", "version"]
298+
# Core required fields - claims_hash is now required to track permission changes
299+
local_required_field = ["sub", "user_data", "exp", "version", "claims_hash"]
230300
options = {"require": local_required_field}
231301
options.update(additional_options)
232302
return jwt.decode(
@@ -259,16 +329,23 @@ def get_role_definition(self, name: str) -> Optional[Model]:
259329
def process_rbac_permissions(self):
260330
"""
261331
This is a default process_permissions which should be usable if you are using RBAC from DAB
332+
Uses gateway claims data exclusively - no fallback to JWT token fields
262333
"""
263-
if self.token is None or self.user is None:
264-
logger.error("Unable to process rbac permissions because user or token is not defined, please call authenticate first")
334+
if self.user is None:
335+
logger.error("Unable to process rbac permissions because user is not defined, please call authenticate first")
336+
return
337+
338+
if self.gateway_claims is None:
339+
logger.error("Unable to process rbac permissions because gateway claims are not available. Ensure gateway jwt_claims endpoint is accessible.")
265340
return
266341

267342
from ansible_base.rbac.models import RoleUserAssignment
268343

269344
role_diff = RoleUserAssignment.objects.filter(user=self.user, role_definition__name__in=settings.ANSIBLE_BASE_JWT_MANAGED_ROLES)
270345

271-
for system_role_name in self.token.get("global_roles", []):
346+
# Process global roles from gateway claims
347+
global_roles = self.gateway_claims.get("global_roles", [])
348+
for system_role_name in global_roles:
272349
logger.debug(f"Processing system role {system_role_name} for {self.user.username}")
273350
rd = self.get_role_definition(system_role_name)
274351
if rd:
@@ -282,7 +359,11 @@ def process_rbac_permissions(self):
282359
logger.error(f"Unable to grant {self.user.username} system level role {system_role_name} because it does not exist")
283360
continue
284361

285-
for object_role_name in self.token.get('object_roles', {}).keys():
362+
# Process object roles from gateway claims
363+
object_roles = self.gateway_claims.get('object_roles', {})
364+
objects = self.gateway_claims.get('objects', {})
365+
366+
for object_role_name in object_roles.keys():
286367
rd = self.get_role_definition(object_role_name)
287368
if rd is None:
288369
logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it does not exist")
@@ -291,11 +372,11 @@ def process_rbac_permissions(self):
291372
logger.error(f"Unable to grant {self.user.username} object role {object_role_name} because it is not a JWT managed role")
292373
continue
293374

294-
object_type = self.token['object_roles'][object_role_name]['content_type']
295-
object_indexes = self.token['object_roles'][object_role_name]['objects']
375+
object_type = object_roles[object_role_name]['content_type']
376+
object_indexes = object_roles[object_role_name]['objects']
296377

297378
for index in object_indexes:
298-
object_data = self.token['objects'][object_type][index]
379+
object_data = objects[object_type][index]
299380
try:
300381
resource, obj = self.get_or_create_resource(object_type, object_data)
301382
except IntegrityError as e:
@@ -310,7 +391,7 @@ def process_rbac_permissions(self):
310391
role_diff = role_diff.exclude(pk=assignment.pk)
311392
logger.info(f"Granted user {self.user.username} role {object_role_name} to object {obj.name} with ansible_id {object_data['ansible_id']}")
312393

313-
# Remove all permissions not authorized by the JWT
394+
# Remove all permissions not authorized by the gateway claims
314395
for role_assignment in role_diff:
315396
rd = role_assignment.role_definition
316397
content_object = role_assignment.content_object
@@ -322,9 +403,17 @@ def process_rbac_permissions(self):
322403
def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optional[Resource], Optional[Model]]:
323404
"""
324405
Gets or creates a resource from a content type and its default data
406+
Uses gateway claims exclusively - no fallback to JWT token fields
325407
326408
This can only build or get organizations or teams
409+
Args:
410+
content_type: Type of content ('team', 'organization')
411+
data: Resource data dictionary
327412
"""
413+
if self.gateway_claims is None:
414+
logger.error("Unable to create resource because gateway claims are not available")
415+
return None, None
416+
328417
object_ansible_id = data['ansible_id']
329418
try:
330419
resource = Resource.objects.get(ansible_id=object_ansible_id)
@@ -337,7 +426,7 @@ def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optiona
337426
if content_type == 'team':
338427
# For a team we first have to make sure the org is there
339428
org_id = data['org']
340-
organization_data = self.token['objects']["organization"][org_id]
429+
organization_data = self.gateway_claims['objects']["organization"][org_id]
341430

342431
# Now that we have the org we can build a team
343432
org_resource, _ = self.get_or_create_resource("organization", organization_data)
@@ -359,7 +448,7 @@ def get_or_create_resource(self, content_type: str, data: dict) -> Tuple[Optiona
359448

360449
return resource, resource.content_object
361450
else:
362-
logger.error(f"build_resource_stub does not know how to build an object of type {type}")
451+
logger.error(f"build_resource_stub does not know how to build an object of type {content_type}")
363452
return None, None
364453

365454

ansible_base/jwt_consumer/common/cache.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,23 @@ def get_key_from_cache(self) -> Optional[str]:
4646

4747
def set_key_in_cache(self, key: str) -> None:
4848
cache.set(cache_key, key, timeout=self.get_cache_timeout())
49+
50+
def get_claims_hash(self, user_ansible_id: str) -> Optional[str]:
51+
"""Get cached claims hash for a user."""
52+
claims_hash_key = f"jwt_claims_hash_{user_ansible_id}"
53+
return cache.get(claims_hash_key, None)
54+
55+
def set_claims_hash(self, user_ansible_id: str, claims_hash: str) -> None:
56+
"""Set cached claims hash for a user."""
57+
claims_hash_key = f"jwt_claims_hash_{user_ansible_id}"
58+
cache.set(claims_hash_key, claims_hash, timeout=self.get_cache_timeout())
59+
60+
def get_cached_claims(self, user_ansible_id: str) -> Optional[dict]:
61+
"""Get cached gateway claims for a user."""
62+
claims_key = f"jwt_gateway_claims_{user_ansible_id}"
63+
return cache.get(claims_key, None)
64+
65+
def set_cached_claims(self, user_ansible_id: str, claims: dict) -> None:
66+
"""Set cached gateway claims for a user."""
67+
claims_key = f"jwt_gateway_claims_{user_ansible_id}"
68+
cache.set(claims_key, claims, timeout=self.get_cache_timeout())

ansible_base/jwt_consumer/hub/auth.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,15 @@ def process_permissions(self):
4040
# the teams this user should have a "shared" [!local] assignment to
4141
member_teams = []
4242

43-
for role_name in self.common_auth.token.get('object_roles', {}).keys():
43+
# Process object roles from gateway claims instead of JWT token
44+
if not self.common_auth.gateway_claims:
45+
logger.error("Unable to process permissions because gateway claims are not available")
46+
return
47+
48+
for role_name in self.common_auth.gateway_claims.get('object_roles', {}).keys():
4449
if role_name.startswith('Team'):
45-
for object_index in self.common_auth.token['object_roles'][role_name]['objects']:
46-
team_data = self.common_auth.token['objects']['team'][object_index]
50+
for object_index in self.common_auth.gateway_claims['object_roles'][role_name]['objects']:
51+
team_data = self.common_auth.gateway_claims['objects']['team'][object_index]
4752
ansible_id = team_data['ansible_id']
4853
try:
4954
team = Resource.objects.get(ansible_id=ansible_id).content_object
@@ -83,7 +88,7 @@ def process_permissions(self):
8388
roledef.give_permission(self.common_auth.user, team)
8489

8590
auditor_roledef = RoleDefinition.objects.get(name='Platform Auditor')
86-
if "Platform Auditor" in self.common_auth.token.get('global_roles', []):
91+
if "Platform Auditor" in self.common_auth.gateway_claims.get('global_roles', []):
8792
auditor_roledef.give_global_permission(self.common_auth.user)
8893
else:
8994
auditor_roledef.remove_global_permission(self.common_auth.user)

ansible_base/resource_registry/rest_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,10 @@ def list_team_assignments(self, team_ansible_id: Optional[str] = None, filters:
189189
return self._make_request("get", "role-team-assignments/", params=params)
190190

191191
def sync_assignment(self, assignment):
192-
from ansible_base.rbac.service_api.serializers import ServiceRoleTeamAssignmentSerializer, ServiceRoleUserAssignmentSerializer
192+
from ansible_base.rbac.service_api.serializers import (
193+
ServiceRoleTeamAssignmentSerializer,
194+
ServiceRoleUserAssignmentSerializer,
195+
)
193196

194197
if assignment._meta.model_name == 'roleuserassignment':
195198
serializer = ServiceRoleUserAssignmentSerializer(assignment)
@@ -227,3 +230,7 @@ def _sync_assignment(self, data, giving=True):
227230
url = f'role-{actor_type}-assignments/{sub_url}/'
228231

229232
return self._make_request(method="post", path=url, data=data)
233+
234+
def get_jwt_claims(self, user_ansible_id):
235+
"""Get JWT claims for a user from the gateway service-index."""
236+
return self._make_request("get", f"jwt_claims/{user_ansible_id}/")

test_app/tests/conftest.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,9 +542,8 @@ def __init__(self):
542542
"email": "[email protected]",
543543
"is_superuser": False,
544544
},
545-
"objects": {},
546-
"object_roles": {},
547-
"global_roles": [],
545+
# claims_hash is required to track permission changes
546+
"claims_hash": "test_hash_123",
548547
}
549548

550549
def encrypt_token(self):
@@ -595,6 +594,31 @@ def mocked_gateway_view_get_request(self, *args, **kwargs):
595594
return MockedHttp()
596595

597596

597+
@pytest.fixture
598+
def mock_gateway_jwt_claims():
599+
"""Mock for gateway JWT claims endpoint."""
600+
mock_claims = {
601+
"global_roles": ["Platform Auditor"],
602+
"object_roles": {"Organization Admin": {"content_type": "organization", "objects": [0]}},
603+
"objects": {"organization": [{"ansible_id": "test-org-id", "name": "Test Organization"}], "team": []},
604+
}
605+
606+
class MockResponse:
607+
def __init__(self, json_data, status_code=200):
608+
self.json_data = json_data
609+
self.status_code = status_code
610+
611+
def json(self):
612+
return self.json_data
613+
614+
class MockResourceAPIClient:
615+
def get_jwt_claims(self, user_ansible_id):
616+
return MockResponse(mock_claims)
617+
618+
with mock.patch('ansible_base.jwt_consumer.common.auth.get_resource_server_client', return_value=MockResourceAPIClient()):
619+
yield mock_claims
620+
621+
598622
@pytest.fixture
599623
def system_user(db, no_log_messages):
600624
with no_log_messages():

0 commit comments

Comments
 (0)