Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions ansible_base/authentication/middleware.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import logging
from collections import defaultdict
from urllib.parse import urlsplit

from django.contrib.auth import BACKEND_SESSION_KEY
from django.core.exceptions import ImproperlyConfigured
from django.middleware.csrf import CsrfViewMiddleware
from django.utils.deprecation import MiddlewareMixin
from django.utils.functional import cached_property
from social_django.middleware import SocialAuthExceptionMiddleware

from ansible_base.authentication.authenticator_plugins.utils import get_authenticator_plugins
from ansible_base.lib.utils.settings import get_setting

logger = logging.getLogger('ansible_base.authentication.middleware')

Expand Down Expand Up @@ -51,3 +56,46 @@ def get_redirect_uri(self, request, exception):
backend_name = getattr(backend, "name", "unknown-backend")
logger.error(f"Auth failure for backend {backend_name} - {repr(exception)}, redirecting to {error_url}")
return error_url


class AnsibleBaseCsrfViewMiddleware(CsrfViewMiddleware):
"""
CsrfViewMiddleware subclass that reads CSRF_TRUSTED_ORIGINS using
ansible_base.lib.utils.settings.get_setting instead of directly from
Django settings.

This allows the setting to be dynamically loaded from various sources
as configured by the ANSIBLE_BASE_SETTINGS_FUNCTION setting.
Copy link
Member

@rochacbruno rochacbruno Jun 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻 I like the composability (is this a word?) here


Overrides all cached properties that access settings.CSRF_TRUSTED_ORIGINS
to use get_setting instead.
"""

@cached_property
def csrf_trusted_origins_hosts(self):
"""
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
"""
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
return [urlsplit(origin).netloc.lstrip("*") for origin in csrf_trusted_origins]

@cached_property
def allowed_origins_exact(self):
"""
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
"""
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
return {origin for origin in csrf_trusted_origins if "*" not in origin}

@cached_property
def allowed_origin_subdomains(self):
"""
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
A mapping of allowed schemes to list of allowed netlocs, where all
subdomains of the netloc are allowed.
"""
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
allowed_origin_subdomains = defaultdict(list)
for parsed in (urlsplit(origin) for origin in csrf_trusted_origins if "*" in origin):
allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip("*"))
return allowed_origin_subdomains
41 changes: 40 additions & 1 deletion ansible_base/authentication/session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,49 @@
from rest_framework import authentication
from rest_framework import authentication, exceptions

from ansible_base.authentication.middleware import (
AnsibleBaseCsrfViewMiddleware,
)


class AnsibleBaseCSRFCheck(AnsibleBaseCsrfViewMiddleware):
"""
Custom CSRF check class that uses AnsibleBaseCsrfViewMiddleware
instead of Django's CsrfViewMiddleware for CSRF validation.

This ensures that CSRF_TRUSTED_ORIGINS is read using get_setting
instead of directly from Django settings.
"""

def _reject(self, request, reason):
# Return the failure reason instead of an HttpResponse
return reason


class SessionAuthentication(authentication.SessionAuthentication):
"""
This class allows us to fail with a 401 if the user is not authenticated.

Uses AnsibleBaseCsrfViewMiddleware for CSRF checking instead of Django's
default CsrfViewMiddleware, allowing CSRF_TRUSTED_ORIGINS to be read
dynamically using get_setting.
"""

def authenticate_header(self, request):
return "Session"

def enforce_csrf(self, request):
"""
Enforce CSRF validation for session based authentication using
AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware.
"""

def dummy_get_response(request): # pragma: no cover
return None

check = AnsibleBaseCSRFCheck(dummy_get_response)
# populates request.META['CSRF_COOKIE'], which is used in process_view()
check.process_request(request)
reason = check.process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
146 changes: 145 additions & 1 deletion test_app/tests/authentication/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from unittest.mock import patch

from django.conf import settings
from social_core.exceptions import AuthException

from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware
from ansible_base.authentication.middleware import (
AnsibleBaseCsrfViewMiddleware,
SocialExceptionHandlerMiddleware,
)
from ansible_base.authentication.session import (
AnsibleBaseCSRFCheck,
SessionAuthentication,
)


def test_social_exception_handler_mw():
Expand All @@ -21,3 +30,138 @@ def __init__(self):
mw = SocialExceptionHandlerMiddleware(None)
url = mw.get_redirect_uri(Request(), AuthException("test"))
assert url == "/?auth_failed"


def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts():
"""Test that csrf_trusted_origins_hosts uses get_setting."""
test_origins = ['https://example.com', 'https://*.test.com']

with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
mock_get_setting.return_value = test_origins

middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
result = middleware.csrf_trusted_origins_hosts

mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
# Should strip * from netloc
assert result == ['example.com', '.test.com']


def test_ansible_base_csrf_view_middleware_allowed_origins_exact():
"""Test that allowed_origins_exact uses get_setting."""
test_origins = ['https://example.com', 'https://*.test.com']

with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
mock_get_setting.return_value = test_origins

middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
result = middleware.allowed_origins_exact

mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
# Should only include origins without *
assert result == {'https://example.com'}


def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains():
"""Test that allowed_origin_subdomains uses get_setting."""
test_origins = ['https://*.example.com', 'http://*.test.com']

with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
mock_get_setting.return_value = test_origins

middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
result = middleware.allowed_origin_subdomains

mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
# Should group by scheme and strip *
expected = {'https': ['.example.com'], 'http': ['.test.com']}
assert dict(result) == expected


def test_ansible_base_csrf_view_middleware_default_value():
"""Test that middleware returns empty/default values when setting is empty."""
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
mock_get_setting.return_value = []

middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)

# Test all three properties
assert middleware.csrf_trusted_origins_hosts == []
assert middleware.allowed_origins_exact == set()
assert dict(middleware.allowed_origin_subdomains) == {}

# get_setting should be called three times (once for each property)
assert mock_get_setting.call_count == 3


def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware():
"""Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware."""
csrf_check = AnsibleBaseCSRFCheck(lambda request: None)
assert isinstance(csrf_check, AnsibleBaseCsrfViewMiddleware)


def test_ansible_base_csrf_check_reject_method():
"""Test that AnsibleBaseCSRFCheck._reject returns the reason."""
csrf_check = AnsibleBaseCSRFCheck(lambda request: None)
reason = "Test CSRF failure reason"
result = csrf_check._reject(None, reason)
assert result == reason


def test_session_authentication_uses_ansible_base_csrf_check():
"""Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation."""
from unittest.mock import Mock

# Create a mock request with an authenticated user
mock_request = Mock()
mock_request._request = Mock()
mock_request._request.user = Mock()
mock_request._request.user.is_active = True

# Mock the AnsibleBaseCSRFCheck to track its usage
with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class:
mock_csrf_check = Mock()
mock_csrf_check.process_request.return_value = None
mock_csrf_check.process_view.return_value = None # No CSRF error
mock_csrf_check_class.return_value = mock_csrf_check

# Create SessionAuthentication instance and call enforce_csrf
session_auth = SessionAuthentication()
session_auth.enforce_csrf(mock_request)

# Verify AnsibleBaseCSRFCheck was instantiated
mock_csrf_check_class.assert_called_once()

# Verify process_request and process_view were called
mock_csrf_check.process_request.assert_called_once_with(mock_request)
mock_csrf_check.process_view.assert_called_once_with(mock_request, None, (), {})


def test_session_authentication_csrf_failure_raises_permission_denied():
"""Test that SessionAuthentication raises PermissionDenied when CSRF fails."""
from unittest.mock import Mock

from rest_framework.exceptions import PermissionDenied

# Create a mock request with an authenticated user
mock_request = Mock()
mock_request._request = Mock()
mock_request._request.user = Mock()
mock_request._request.user.is_active = True

# Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason
with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class:
mock_csrf_check = Mock()
mock_csrf_check.process_request.return_value = None
mock_csrf_check.process_view.return_value = "CSRF token missing" # CSRF error
mock_csrf_check_class.return_value = mock_csrf_check

# Create SessionAuthentication instance and call enforce_csrf
session_auth = SessionAuthentication()

# Should raise PermissionDenied with the CSRF failure reason
try:
session_auth.enforce_csrf(mock_request)
assert False, "Expected PermissionDenied to be raised"
except PermissionDenied as e:
assert "CSRF Failed: CSRF token missing" in str(e)
Loading