diff --git a/ansible_base/authentication/middleware.py b/ansible_base/authentication/middleware.py index 33d011b1e..dad1e21e7 100644 --- a/ansible_base/authentication/middleware.py +++ b/ansible_base/authentication/middleware.py @@ -1,11 +1,16 @@ import logging +from collections import defaultdict +from urllib.parse import urlsplit from django.contrib.auth import BACKEND_SESSION_KEY from django.core.exceptions import ImproperlyConfigured +from django.middleware.csrf import CsrfViewMiddleware from django.utils.deprecation import MiddlewareMixin +from django.utils.functional import cached_property from social_django.middleware import SocialAuthExceptionMiddleware from ansible_base.authentication.authenticator_plugins.utils import get_authenticator_plugins +from ansible_base.lib.utils.settings import get_setting logger = logging.getLogger('ansible_base.authentication.middleware') @@ -51,3 +56,46 @@ def get_redirect_uri(self, request, exception): backend_name = getattr(backend, "name", "unknown-backend") logger.error(f"Auth failure for backend {backend_name} - {repr(exception)}, redirecting to {error_url}") return error_url + + +class AnsibleBaseCsrfViewMiddleware(CsrfViewMiddleware): + """ + CsrfViewMiddleware subclass that reads CSRF_TRUSTED_ORIGINS using + ansible_base.lib.utils.settings.get_setting instead of directly from + Django settings. + + This allows the setting to be dynamically loaded from various sources + as configured by the ANSIBLE_BASE_SETTINGS_FUNCTION setting. + + Overrides all cached properties that access settings.CSRF_TRUSTED_ORIGINS + to use get_setting instead. + """ + + @cached_property + def csrf_trusted_origins_hosts(self): + """ + Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. + """ + csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) + return [urlsplit(origin).netloc.lstrip("*") for origin in csrf_trusted_origins] + + @cached_property + def allowed_origins_exact(self): + """ + Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. + """ + csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) + return {origin for origin in csrf_trusted_origins if "*" not in origin} + + @cached_property + def allowed_origin_subdomains(self): + """ + Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS. + A mapping of allowed schemes to list of allowed netlocs, where all + subdomains of the netloc are allowed. + """ + csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', []) + allowed_origin_subdomains = defaultdict(list) + for parsed in (urlsplit(origin) for origin in csrf_trusted_origins if "*" in origin): + allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip("*")) + return allowed_origin_subdomains diff --git a/ansible_base/authentication/session.py b/ansible_base/authentication/session.py index a9eb63f79..b867cc211 100644 --- a/ansible_base/authentication/session.py +++ b/ansible_base/authentication/session.py @@ -1,10 +1,49 @@ -from rest_framework import authentication +from rest_framework import authentication, exceptions + +from ansible_base.authentication.middleware import ( + AnsibleBaseCsrfViewMiddleware, +) + + +class AnsibleBaseCSRFCheck(AnsibleBaseCsrfViewMiddleware): + """ + Custom CSRF check class that uses AnsibleBaseCsrfViewMiddleware + instead of Django's CsrfViewMiddleware for CSRF validation. + + This ensures that CSRF_TRUSTED_ORIGINS is read using get_setting + instead of directly from Django settings. + """ + + def _reject(self, request, reason): + # Return the failure reason instead of an HttpResponse + return reason class SessionAuthentication(authentication.SessionAuthentication): """ This class allows us to fail with a 401 if the user is not authenticated. + + Uses AnsibleBaseCsrfViewMiddleware for CSRF checking instead of Django's + default CsrfViewMiddleware, allowing CSRF_TRUSTED_ORIGINS to be read + dynamically using get_setting. """ def authenticate_header(self, request): return "Session" + + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication using + AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware. + """ + + def dummy_get_response(request): # pragma: no cover + return None + + check = AnsibleBaseCSRFCheck(dummy_get_response) + # populates request.META['CSRF_COOKIE'], which is used in process_view() + check.process_request(request) + reason = check.process_view(request, None, (), {}) + if reason: + # CSRF failed, bail with explicit error message + raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) diff --git a/test_app/tests/authentication/test_middleware.py b/test_app/tests/authentication/test_middleware.py index 627bee931..932c99d06 100644 --- a/test_app/tests/authentication/test_middleware.py +++ b/test_app/tests/authentication/test_middleware.py @@ -1,7 +1,16 @@ +from unittest.mock import patch + from django.conf import settings from social_core.exceptions import AuthException -from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware +from ansible_base.authentication.middleware import ( + AnsibleBaseCsrfViewMiddleware, + SocialExceptionHandlerMiddleware, +) +from ansible_base.authentication.session import ( + AnsibleBaseCSRFCheck, + SessionAuthentication, +) def test_social_exception_handler_mw(): @@ -21,3 +30,138 @@ def __init__(self): mw = SocialExceptionHandlerMiddleware(None) url = mw.get_redirect_uri(Request(), AuthException("test")) assert url == "/?auth_failed" + + +def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts(): + """Test that csrf_trusted_origins_hosts uses get_setting.""" + test_origins = ['https://example.com', 'https://*.test.com'] + + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = test_origins + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + result = middleware.csrf_trusted_origins_hosts + + mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) + # Should strip * from netloc + assert result == ['example.com', '.test.com'] + + +def test_ansible_base_csrf_view_middleware_allowed_origins_exact(): + """Test that allowed_origins_exact uses get_setting.""" + test_origins = ['https://example.com', 'https://*.test.com'] + + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = test_origins + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + result = middleware.allowed_origins_exact + + mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) + # Should only include origins without * + assert result == {'https://example.com'} + + +def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains(): + """Test that allowed_origin_subdomains uses get_setting.""" + test_origins = ['https://*.example.com', 'http://*.test.com'] + + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = test_origins + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + result = middleware.allowed_origin_subdomains + + mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', []) + # Should group by scheme and strip * + expected = {'https': ['.example.com'], 'http': ['.test.com']} + assert dict(result) == expected + + +def test_ansible_base_csrf_view_middleware_default_value(): + """Test that middleware returns empty/default values when setting is empty.""" + with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting: + mock_get_setting.return_value = [] + + middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None) + + # Test all three properties + assert middleware.csrf_trusted_origins_hosts == [] + assert middleware.allowed_origins_exact == set() + assert dict(middleware.allowed_origin_subdomains) == {} + + # get_setting should be called three times (once for each property) + assert mock_get_setting.call_count == 3 + + +def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware(): + """Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware.""" + csrf_check = AnsibleBaseCSRFCheck(lambda request: None) + assert isinstance(csrf_check, AnsibleBaseCsrfViewMiddleware) + + +def test_ansible_base_csrf_check_reject_method(): + """Test that AnsibleBaseCSRFCheck._reject returns the reason.""" + csrf_check = AnsibleBaseCSRFCheck(lambda request: None) + reason = "Test CSRF failure reason" + result = csrf_check._reject(None, reason) + assert result == reason + + +def test_session_authentication_uses_ansible_base_csrf_check(): + """Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation.""" + from unittest.mock import Mock + + # Create a mock request with an authenticated user + mock_request = Mock() + mock_request._request = Mock() + mock_request._request.user = Mock() + mock_request._request.user.is_active = True + + # Mock the AnsibleBaseCSRFCheck to track its usage + with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class: + mock_csrf_check = Mock() + mock_csrf_check.process_request.return_value = None + mock_csrf_check.process_view.return_value = None # No CSRF error + mock_csrf_check_class.return_value = mock_csrf_check + + # Create SessionAuthentication instance and call enforce_csrf + session_auth = SessionAuthentication() + session_auth.enforce_csrf(mock_request) + + # Verify AnsibleBaseCSRFCheck was instantiated + mock_csrf_check_class.assert_called_once() + + # Verify process_request and process_view were called + mock_csrf_check.process_request.assert_called_once_with(mock_request) + mock_csrf_check.process_view.assert_called_once_with(mock_request, None, (), {}) + + +def test_session_authentication_csrf_failure_raises_permission_denied(): + """Test that SessionAuthentication raises PermissionDenied when CSRF fails.""" + from unittest.mock import Mock + + from rest_framework.exceptions import PermissionDenied + + # Create a mock request with an authenticated user + mock_request = Mock() + mock_request._request = Mock() + mock_request._request.user = Mock() + mock_request._request.user.is_active = True + + # Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason + with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class: + mock_csrf_check = Mock() + mock_csrf_check.process_request.return_value = None + mock_csrf_check.process_view.return_value = "CSRF token missing" # CSRF error + mock_csrf_check_class.return_value = mock_csrf_check + + # Create SessionAuthentication instance and call enforce_csrf + session_auth = SessionAuthentication() + + # Should raise PermissionDenied with the CSRF failure reason + try: + session_auth.enforce_csrf(mock_request) + assert False, "Expected PermissionDenied to be raised" + except PermissionDenied as e: + assert "CSRF Failed: CSRF token missing" in str(e)