Skip to content

Commit bdfc0e2

Browse files
committed
Add AnsibleBaseCsrfViewMiddleware and update SessionAuthentication to use get_setting
- Add AnsibleBaseCsrfViewMiddleware that reads CSRF_TRUSTED_ORIGINS using ansible_base.lib.utils.settings.get_setting instead of directly from Django settings - Override all three cached properties (csrf_trusted_origins_hosts, allowed_origins_exact, allowed_origin_subdomains) to use get_setting for dynamic configuration - Add AnsibleBaseCSRFCheck class that inherits from AnsibleBaseCsrfViewMiddleware - Modify SessionAuthentication.enforce_csrf to use AnsibleBaseCSRFCheck instead of Django's default CSRFCheck - Add comprehensive tests for both middleware and session authentication CSRF functionality - Enables CSRF_TRUSTED_ORIGINS to be dynamically loaded from various sources via ANSIBLE_BASE_SETTINGS_FUNCTION while maintaining backward compatibility
1 parent d758999 commit bdfc0e2

File tree

3 files changed

+233
-2
lines changed

3 files changed

+233
-2
lines changed

ansible_base/authentication/middleware.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import logging
2+
from collections import defaultdict
3+
from urllib.parse import urlsplit
24

35
from django.contrib.auth import BACKEND_SESSION_KEY
46
from django.core.exceptions import ImproperlyConfigured
7+
from django.middleware.csrf import CsrfViewMiddleware
58
from django.utils.deprecation import MiddlewareMixin
9+
from django.utils.functional import cached_property
610
from social_django.middleware import SocialAuthExceptionMiddleware
711

812
from ansible_base.authentication.authenticator_plugins.utils import get_authenticator_plugins
13+
from ansible_base.lib.utils.settings import get_setting
914

1015
logger = logging.getLogger('ansible_base.authentication.middleware')
1116

@@ -51,3 +56,46 @@ def get_redirect_uri(self, request, exception):
5156
backend_name = getattr(backend, "name", "unknown-backend")
5257
logger.error(f"Auth failure for backend {backend_name} - {repr(exception)}, redirecting to {error_url}")
5358
return error_url
59+
60+
61+
class AnsibleBaseCsrfViewMiddleware(CsrfViewMiddleware):
62+
"""
63+
CsrfViewMiddleware subclass that reads CSRF_TRUSTED_ORIGINS using
64+
ansible_base.lib.utils.settings.get_setting instead of directly from
65+
Django settings.
66+
67+
This allows the setting to be dynamically loaded from various sources
68+
as configured by the ANSIBLE_BASE_SETTINGS_FUNCTION setting.
69+
70+
Overrides all cached properties that access settings.CSRF_TRUSTED_ORIGINS
71+
to use get_setting instead.
72+
"""
73+
74+
@cached_property
75+
def csrf_trusted_origins_hosts(self):
76+
"""
77+
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
78+
"""
79+
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
80+
return [urlsplit(origin).netloc.lstrip("*") for origin in csrf_trusted_origins]
81+
82+
@cached_property
83+
def allowed_origins_exact(self):
84+
"""
85+
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
86+
"""
87+
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
88+
return {origin for origin in csrf_trusted_origins if "*" not in origin}
89+
90+
@cached_property
91+
def allowed_origin_subdomains(self):
92+
"""
93+
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
94+
A mapping of allowed schemes to list of allowed netlocs, where all
95+
subdomains of the netloc are allowed.
96+
"""
97+
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
98+
allowed_origin_subdomains = defaultdict(list)
99+
for parsed in (urlsplit(origin) for origin in csrf_trusted_origins if "*" in origin):
100+
allowed_origin_subdomains[parsed.scheme].append(parsed.netloc.lstrip("*"))
101+
return allowed_origin_subdomains
Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,49 @@
1-
from rest_framework import authentication
1+
from rest_framework import authentication, exceptions
2+
3+
from ansible_base.authentication.middleware import (
4+
AnsibleBaseCsrfViewMiddleware,
5+
)
6+
7+
8+
class AnsibleBaseCSRFCheck(AnsibleBaseCsrfViewMiddleware):
9+
"""
10+
Custom CSRF check class that uses AnsibleBaseCsrfViewMiddleware
11+
instead of Django's CsrfViewMiddleware for CSRF validation.
12+
13+
This ensures that CSRF_TRUSTED_ORIGINS is read using get_setting
14+
instead of directly from Django settings.
15+
"""
16+
17+
def _reject(self, request, reason):
18+
# Return the failure reason instead of an HttpResponse
19+
return reason
220

321

422
class SessionAuthentication(authentication.SessionAuthentication):
523
"""
624
This class allows us to fail with a 401 if the user is not authenticated.
25+
26+
Uses AnsibleBaseCsrfViewMiddleware for CSRF checking instead of Django's
27+
default CsrfViewMiddleware, allowing CSRF_TRUSTED_ORIGINS to be read
28+
dynamically using get_setting.
729
"""
830

931
def authenticate_header(self, request):
1032
return "Session"
33+
34+
def enforce_csrf(self, request):
35+
"""
36+
Enforce CSRF validation for session based authentication using
37+
AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware.
38+
"""
39+
40+
def dummy_get_response(request): # pragma: no cover
41+
return None
42+
43+
check = AnsibleBaseCSRFCheck(dummy_get_response)
44+
# populates request.META['CSRF_COOKIE'], which is used in process_view()
45+
check.process_request(request)
46+
reason = check.process_view(request, None, (), {})
47+
if reason:
48+
# CSRF failed, bail with explicit error message
49+
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)

test_app/tests/authentication/test_middleware.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1+
from unittest.mock import patch
2+
13
from django.conf import settings
24
from social_core.exceptions import AuthException
35

4-
from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware
6+
from ansible_base.authentication.middleware import (
7+
AnsibleBaseCsrfViewMiddleware,
8+
SocialExceptionHandlerMiddleware,
9+
)
10+
from ansible_base.authentication.session import (
11+
AnsibleBaseCSRFCheck,
12+
SessionAuthentication,
13+
)
514

615

716
def test_social_exception_handler_mw():
@@ -21,3 +30,138 @@ def __init__(self):
2130
mw = SocialExceptionHandlerMiddleware(None)
2231
url = mw.get_redirect_uri(Request(), AuthException("test"))
2332
assert url == "/?auth_failed"
33+
34+
35+
def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts():
36+
"""Test that csrf_trusted_origins_hosts uses get_setting."""
37+
test_origins = ['https://example.com', 'https://*.test.com']
38+
39+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
40+
mock_get_setting.return_value = test_origins
41+
42+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
43+
result = middleware.csrf_trusted_origins_hosts
44+
45+
mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
46+
# Should strip * from netloc
47+
assert result == ['example.com', '.test.com']
48+
49+
50+
def test_ansible_base_csrf_view_middleware_allowed_origins_exact():
51+
"""Test that allowed_origins_exact uses get_setting."""
52+
test_origins = ['https://example.com', 'https://*.test.com']
53+
54+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
55+
mock_get_setting.return_value = test_origins
56+
57+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
58+
result = middleware.allowed_origins_exact
59+
60+
mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
61+
# Should only include origins without *
62+
assert result == {'https://example.com'}
63+
64+
65+
def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains():
66+
"""Test that allowed_origin_subdomains uses get_setting."""
67+
test_origins = ['https://*.example.com', 'http://*.test.com']
68+
69+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
70+
mock_get_setting.return_value = test_origins
71+
72+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
73+
result = middleware.allowed_origin_subdomains
74+
75+
mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
76+
# Should group by scheme and strip *
77+
expected = {'https': ['.example.com'], 'http': ['.test.com']}
78+
assert dict(result) == expected
79+
80+
81+
def test_ansible_base_csrf_view_middleware_default_value():
82+
"""Test that middleware returns empty/default values when setting is empty."""
83+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
84+
mock_get_setting.return_value = []
85+
86+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
87+
88+
# Test all three properties
89+
assert middleware.csrf_trusted_origins_hosts == []
90+
assert middleware.allowed_origins_exact == set()
91+
assert dict(middleware.allowed_origin_subdomains) == {}
92+
93+
# get_setting should be called three times (once for each property)
94+
assert mock_get_setting.call_count == 3
95+
96+
97+
def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware():
98+
"""Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware."""
99+
csrf_check = AnsibleBaseCSRFCheck(lambda request: None)
100+
assert isinstance(csrf_check, AnsibleBaseCsrfViewMiddleware)
101+
102+
103+
def test_ansible_base_csrf_check_reject_method():
104+
"""Test that AnsibleBaseCSRFCheck._reject returns the reason."""
105+
csrf_check = AnsibleBaseCSRFCheck(lambda request: None)
106+
reason = "Test CSRF failure reason"
107+
result = csrf_check._reject(None, reason)
108+
assert result == reason
109+
110+
111+
def test_session_authentication_uses_ansible_base_csrf_check():
112+
"""Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation."""
113+
from unittest.mock import Mock
114+
115+
# Create a mock request with an authenticated user
116+
mock_request = Mock()
117+
mock_request._request = Mock()
118+
mock_request._request.user = Mock()
119+
mock_request._request.user.is_active = True
120+
121+
# Mock the AnsibleBaseCSRFCheck to track its usage
122+
with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class:
123+
mock_csrf_check = Mock()
124+
mock_csrf_check.process_request.return_value = None
125+
mock_csrf_check.process_view.return_value = None # No CSRF error
126+
mock_csrf_check_class.return_value = mock_csrf_check
127+
128+
# Create SessionAuthentication instance and call enforce_csrf
129+
session_auth = SessionAuthentication()
130+
session_auth.enforce_csrf(mock_request)
131+
132+
# Verify AnsibleBaseCSRFCheck was instantiated
133+
mock_csrf_check_class.assert_called_once()
134+
135+
# Verify process_request and process_view were called
136+
mock_csrf_check.process_request.assert_called_once_with(mock_request)
137+
mock_csrf_check.process_view.assert_called_once_with(mock_request, None, (), {})
138+
139+
140+
def test_session_authentication_csrf_failure_raises_permission_denied():
141+
"""Test that SessionAuthentication raises PermissionDenied when CSRF fails."""
142+
from unittest.mock import Mock
143+
144+
from rest_framework.exceptions import PermissionDenied
145+
146+
# Create a mock request with an authenticated user
147+
mock_request = Mock()
148+
mock_request._request = Mock()
149+
mock_request._request.user = Mock()
150+
mock_request._request.user.is_active = True
151+
152+
# Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason
153+
with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class:
154+
mock_csrf_check = Mock()
155+
mock_csrf_check.process_request.return_value = None
156+
mock_csrf_check.process_view.return_value = "CSRF token missing" # CSRF error
157+
mock_csrf_check_class.return_value = mock_csrf_check
158+
159+
# Create SessionAuthentication instance and call enforce_csrf
160+
session_auth = SessionAuthentication()
161+
162+
# Should raise PermissionDenied with the CSRF failure reason
163+
try:
164+
session_auth.enforce_csrf(mock_request)
165+
assert False, "Expected PermissionDenied to be raised"
166+
except PermissionDenied as e:
167+
assert "CSRF Failed: CSRF token missing" in str(e)

0 commit comments

Comments
 (0)