Skip to content
Open
28 changes: 26 additions & 2 deletions ansible_base/jwt_consumer/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ansible_base.lib.logging.runtime import log_excess_runtime
from ansible_base.lib.utils.auth import get_user_by_ansible_id
from ansible_base.lib.utils.translations import translatableConditionally as _
from ansible_base.rbac.claims import get_user_claims, get_user_claims_hashable_form, get_claims_hash
from ansible_base.resource_registry.models import Resource, ResourceType
from ansible_base.resource_registry.rest_client import get_resource_server_client
from ansible_base.resource_registry.signals.handlers import no_reverse_sync
Expand Down Expand Up @@ -177,8 +178,31 @@ def _should_fetch_claims_from_gateway(self, user_ansible_id, current_claims_hash

cached_hash = self.cache.get_claims_hash(user_ansible_id)
if cached_hash != current_claims_hash:
logger.debug(f"Claims hash changed for user {user_ansible_id}: cached={cached_hash}, current={current_claims_hash}")
return True
# Recalculate hash from local database to verify the mismatch
# It is possible that the cached hash is stale, but the local data is synced to the resource server.
# This is an optimization to avoid fetching claims from the resource server if the local data is synced.
logger.debug(f"Claims hash mismatch for user {user_ansible_id}: cached={cached_hash}, current={current_claims_hash}")
logger.debug(f"Recalculating hash from local database for user {user_ansible_id}")

try:
# Get user claims from local database
user_claims = get_user_claims(self.user)
hashable_claims = get_user_claims_hashable_form(user_claims)
recalculated_hash = get_claims_hash(hashable_claims)

logger.debug(f"Recalculated hash for user {user_ansible_id}: {recalculated_hash}")
# Compare recalculated hash with current hash from token
if recalculated_hash != current_claims_hash:
logger.debug(f"Claims hash still differs after recalculation for user {user_ansible_id}: local={recalculated_hash}, current={current_claims_hash}")
return True
else:
logger.debug(f"Recalculated hash matches current hash for user {user_ansible_id}")
return False

except Exception as e:
logger.error(f"Failed to recalculate claims hash for user {user_ansible_id}: {e}")
# If recalculation fails, fall back to treating it as a hash mismatch
return True

# Hash matches cached value, try to get cached claims
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, don't do that.

If hashes match, do nothing. Don't do approximately nothing, do exactly nothing. Don't save the claims for later. If the hashes match that means the claims have already been saved.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will update to do approximately nothing 🤣

cached_claims = self.cache.get_cached_claims(user_ansible_id)
Expand Down
72 changes: 72 additions & 0 deletions test_app/tests/jwt_consumer/common/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ansible_base.jwt_consumer.common.cert import JWTCert, JWTCertException
from ansible_base.jwt_consumer.common.exceptions import InvalidTokenException
from ansible_base.lib.utils.translations import translatableConditionally as _
from ansible_base.rbac.claims import get_user_claims, get_user_claims_hashable_form, get_claims_hash
from ansible_base.rbac.models import RoleDefinition, RoleUserAssignment
from ansible_base.rbac.permission_registry import permission_registry
from ansible_base.resource_registry.models import Resource
Expand Down Expand Up @@ -539,6 +540,77 @@ def test_should_fetch_claims_from_gateway_cache_miss(self):
with mock.patch.object(authentication.cache, 'get_cached_claims', return_value=None):
assert authentication._should_fetch_claims_from_gateway(user_ansible_id, "same_hash") is True

@pytest.mark.django_db
def test_should_fetch_claims_from_gateway_hash_recalculation_matches(self, admin_user):
"""Test that claims are not fetched when recalculated hash matches current hash from token."""
authentication = JWTCommonAuth()
authentication.user = admin_user
user_ansible_id = str(admin_user.resource.ansible_id)

# Mock claims functions to return predictable data
mock_user_claims = {
'objects': {'organization': []},
'object_roles': {},
'global_roles': ['System Auditor']
}
mock_hashable_claims = {
'global_roles': ['System Auditor'],
'object_roles': {}
}
recalculated_hash = "recalculated_hash_value"

# Mock cache to return different hash (triggering recalculation)
with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_cached_hash"):
# Mock the claims functions to return our test data
with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims', return_value=mock_user_claims):
with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims_hashable_form', return_value=mock_hashable_claims):
with mock.patch('ansible_base.jwt_consumer.common.auth.get_claims_hash', return_value=recalculated_hash):
# When recalculated hash matches current hash, should return False
assert authentication._should_fetch_claims_from_gateway(user_ansible_id, recalculated_hash) is False

@pytest.mark.django_db
def test_should_fetch_claims_from_gateway_hash_recalculation_still_differs(self, admin_user):
"""Test that claims are fetched when recalculated hash still differs from current hash."""
authentication = JWTCommonAuth()
authentication.user = admin_user
user_ansible_id = str(admin_user.resource.ansible_id)

# Mock claims functions to return predictable data
mock_user_claims = {
'objects': {'organization': []},
'object_roles': {},
'global_roles': ['System Auditor']
}
mock_hashable_claims = {
'global_roles': ['System Auditor'],
'object_roles': {}
}
recalculated_hash = "recalculated_hash_value"
current_hash = "different_current_hash"

# Mock cache to return different hash (triggering recalculation)
with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_cached_hash"):
# Mock the claims functions to return our test data
with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims', return_value=mock_user_claims):
with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims_hashable_form', return_value=mock_hashable_claims):
with mock.patch('ansible_base.jwt_consumer.common.auth.get_claims_hash', return_value=recalculated_hash):
# When recalculated hash still differs from current hash, should return True
assert authentication._should_fetch_claims_from_gateway(user_ansible_id, current_hash) is True

@pytest.mark.django_db
def test_should_fetch_claims_from_gateway_hash_recalculation_exception(self, admin_user):
"""Test that claims are fetched when hash recalculation fails with exception."""
authentication = JWTCommonAuth()
authentication.user = admin_user
user_ansible_id = str(admin_user.resource.ansible_id)

# Mock cache to return different hash (triggering recalculation)
with mock.patch.object(authentication.cache, 'get_claims_hash', return_value="old_cached_hash"):
# Mock get_user_claims to raise an exception
with mock.patch('ansible_base.jwt_consumer.common.auth.get_user_claims', side_effect=Exception("Test exception")):
# When recalculation fails, should return True (fallback behavior)
assert authentication._should_fetch_claims_from_gateway(user_ansible_id, "current_hash") is True

@pytest.mark.django_db
def test_fetch_jwt_claims_from_gateway_exception(self, caplog):
"""Test handling of exceptions when fetching JWT claims from gateway."""
Expand Down
Loading