Skip to content

Commit 4071571

Browse files
committed
Add AnsibleBaseCsrfViewMiddleware and update SessionAuthentication to use get_setting
- Add AnsibleBaseCsrfViewMiddleware that reads CSRF_TRUSTED_ORIGINS using ansible_base.lib.utils.settings.get_setting instead of directly from Django settings - Override all three cached properties (csrf_trusted_origins_hosts, allowed_origins_exact, allowed_origin_subdomains) to use get_setting for dynamic configuration - Add AnsibleBaseCSRFCheck class that inherits from AnsibleBaseCsrfViewMiddleware - Modify SessionAuthentication.enforce_csrf to use AnsibleBaseCSRFCheck instead of Django's default CSRFCheck - Add comprehensive tests for both middleware and session authentication CSRF functionality - Enables CSRF_TRUSTED_ORIGINS to be dynamically loaded from various sources via ANSIBLE_BASE_SETTINGS_FUNCTION while maintaining backward compatibility
1 parent d758999 commit 4071571

File tree

3 files changed

+241
-2
lines changed

3 files changed

+241
-2
lines changed

ansible_base/authentication/middleware.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import logging
22

3+
from collections import defaultdict
4+
from urllib.parse import urlsplit
5+
36
from django.contrib.auth import BACKEND_SESSION_KEY
47
from django.core.exceptions import ImproperlyConfigured
8+
from django.middleware.csrf import CsrfViewMiddleware
9+
from django.utils.functional import cached_property
510
from django.utils.deprecation import MiddlewareMixin
611
from social_django.middleware import SocialAuthExceptionMiddleware
712

813
from ansible_base.authentication.authenticator_plugins.utils import get_authenticator_plugins
14+
from ansible_base.lib.utils.settings import get_setting
915

1016
logger = logging.getLogger('ansible_base.authentication.middleware')
1117

@@ -51,3 +57,55 @@ def get_redirect_uri(self, request, exception):
5157
backend_name = getattr(backend, "name", "unknown-backend")
5258
logger.error(f"Auth failure for backend {backend_name} - {repr(exception)}, redirecting to {error_url}")
5359
return error_url
60+
61+
62+
class AnsibleBaseCsrfViewMiddleware(CsrfViewMiddleware):
63+
"""
64+
CsrfViewMiddleware subclass that reads CSRF_TRUSTED_ORIGINS using
65+
ansible_base.lib.utils.settings.get_setting instead of directly from
66+
Django settings.
67+
68+
This allows the setting to be dynamically loaded from various sources
69+
as configured by the ANSIBLE_BASE_SETTINGS_FUNCTION setting.
70+
71+
Overrides all cached properties that access settings.CSRF_TRUSTED_ORIGINS
72+
to use get_setting instead.
73+
"""
74+
75+
@cached_property
76+
def csrf_trusted_origins_hosts(self):
77+
"""
78+
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
79+
"""
80+
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
81+
return [
82+
urlsplit(origin).netloc.lstrip("*")
83+
for origin in csrf_trusted_origins
84+
]
85+
86+
@cached_property
87+
def allowed_origins_exact(self):
88+
"""
89+
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
90+
"""
91+
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
92+
return {origin for origin in csrf_trusted_origins if "*" not in origin}
93+
94+
@cached_property
95+
def allowed_origin_subdomains(self):
96+
"""
97+
Override to use get_setting instead of settings.CSRF_TRUSTED_ORIGINS.
98+
A mapping of allowed schemes to list of allowed netlocs, where all
99+
subdomains of the netloc are allowed.
100+
"""
101+
csrf_trusted_origins = get_setting('CSRF_TRUSTED_ORIGINS', [])
102+
allowed_origin_subdomains = defaultdict(list)
103+
for parsed in (
104+
urlsplit(origin)
105+
for origin in csrf_trusted_origins
106+
if "*" in origin
107+
):
108+
allowed_origin_subdomains[parsed.scheme].append(
109+
parsed.netloc.lstrip("*")
110+
)
111+
return allowed_origin_subdomains
Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,48 @@
1-
from rest_framework import authentication
1+
from rest_framework import authentication, exceptions
2+
3+
from ansible_base.authentication.middleware import (
4+
AnsibleBaseCsrfViewMiddleware,
5+
)
6+
7+
8+
class AnsibleBaseCSRFCheck(AnsibleBaseCsrfViewMiddleware):
9+
"""
10+
Custom CSRF check class that uses AnsibleBaseCsrfViewMiddleware
11+
instead of Django's CsrfViewMiddleware for CSRF validation.
12+
13+
This ensures that CSRF_TRUSTED_ORIGINS is read using get_setting
14+
instead of directly from Django settings.
15+
"""
16+
17+
def _reject(self, request, reason):
18+
# Return the failure reason instead of an HttpResponse
19+
return reason
220

321

422
class SessionAuthentication(authentication.SessionAuthentication):
523
"""
624
This class allows us to fail with a 401 if the user is not authenticated.
25+
26+
Uses AnsibleBaseCsrfViewMiddleware for CSRF checking instead of Django's
27+
default CsrfViewMiddleware, allowing CSRF_TRUSTED_ORIGINS to be read
28+
dynamically using get_setting.
729
"""
830

931
def authenticate_header(self, request):
1032
return "Session"
33+
34+
def enforce_csrf(self, request):
35+
"""
36+
Enforce CSRF validation for session based authentication using
37+
AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware.
38+
"""
39+
def dummy_get_response(request): # pragma: no cover
40+
return None
41+
42+
check = AnsibleBaseCSRFCheck(dummy_get_response)
43+
# populates request.META['CSRF_COOKIE'], which is used in process_view()
44+
check.process_request(request)
45+
reason = check.process_view(request, None, (), {})
46+
if reason:
47+
# CSRF failed, bail with explicit error message
48+
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)

test_app/tests/authentication/test_middleware.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1+
from unittest.mock import patch
2+
13
from django.conf import settings
24
from social_core.exceptions import AuthException
35

4-
from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware
6+
from ansible_base.authentication.middleware import (
7+
AnsibleBaseCsrfViewMiddleware,
8+
SocialExceptionHandlerMiddleware,
9+
)
10+
from ansible_base.authentication.session import (
11+
AnsibleBaseCSRFCheck,
12+
SessionAuthentication,
13+
)
514

615

716
def test_social_exception_handler_mw():
@@ -21,3 +30,137 @@ def __init__(self):
2130
mw = SocialExceptionHandlerMiddleware(None)
2231
url = mw.get_redirect_uri(Request(), AuthException("test"))
2332
assert url == "/?auth_failed"
33+
34+
35+
def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts():
36+
"""Test that csrf_trusted_origins_hosts uses get_setting."""
37+
test_origins = ['https://example.com', 'https://*.test.com']
38+
39+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
40+
mock_get_setting.return_value = test_origins
41+
42+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
43+
result = middleware.csrf_trusted_origins_hosts
44+
45+
mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
46+
# Should strip * from netloc
47+
assert result == ['example.com', '.test.com']
48+
49+
50+
def test_ansible_base_csrf_view_middleware_allowed_origins_exact():
51+
"""Test that allowed_origins_exact uses get_setting."""
52+
test_origins = ['https://example.com', 'https://*.test.com']
53+
54+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
55+
mock_get_setting.return_value = test_origins
56+
57+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
58+
result = middleware.allowed_origins_exact
59+
60+
mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
61+
# Should only include origins without *
62+
assert result == {'https://example.com'}
63+
64+
65+
def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains():
66+
"""Test that allowed_origin_subdomains uses get_setting."""
67+
test_origins = ['https://*.example.com', 'http://*.test.com']
68+
69+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
70+
mock_get_setting.return_value = test_origins
71+
72+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
73+
result = middleware.allowed_origin_subdomains
74+
75+
mock_get_setting.assert_called_once_with('CSRF_TRUSTED_ORIGINS', [])
76+
# Should group by scheme and strip *
77+
expected = {'https': ['.example.com'], 'http': ['.test.com']}
78+
assert dict(result) == expected
79+
80+
81+
def test_ansible_base_csrf_view_middleware_default_value():
82+
"""Test that middleware returns empty/default values when setting is empty."""
83+
with patch('ansible_base.authentication.middleware.get_setting') as mock_get_setting:
84+
mock_get_setting.return_value = []
85+
86+
middleware = AnsibleBaseCsrfViewMiddleware(lambda request: None)
87+
88+
# Test all three properties
89+
assert middleware.csrf_trusted_origins_hosts == []
90+
assert middleware.allowed_origins_exact == set()
91+
assert dict(middleware.allowed_origin_subdomains) == {}
92+
93+
# get_setting should be called three times (once for each property)
94+
assert mock_get_setting.call_count == 3
95+
96+
97+
def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware():
98+
"""Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware."""
99+
csrf_check = AnsibleBaseCSRFCheck(lambda request: None)
100+
assert isinstance(csrf_check, AnsibleBaseCsrfViewMiddleware)
101+
102+
103+
def test_ansible_base_csrf_check_reject_method():
104+
"""Test that AnsibleBaseCSRFCheck._reject returns the reason."""
105+
csrf_check = AnsibleBaseCSRFCheck(lambda request: None)
106+
reason = "Test CSRF failure reason"
107+
result = csrf_check._reject(None, reason)
108+
assert result == reason
109+
110+
111+
def test_session_authentication_uses_ansible_base_csrf_check():
112+
"""Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation."""
113+
from unittest.mock import MagicMock, Mock
114+
115+
# Create a mock request with an authenticated user
116+
mock_request = Mock()
117+
mock_request._request = Mock()
118+
mock_request._request.user = Mock()
119+
mock_request._request.user.is_active = True
120+
121+
# Mock the AnsibleBaseCSRFCheck to track its usage
122+
with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class:
123+
mock_csrf_check = Mock()
124+
mock_csrf_check.process_request.return_value = None
125+
mock_csrf_check.process_view.return_value = None # No CSRF error
126+
mock_csrf_check_class.return_value = mock_csrf_check
127+
128+
# Create SessionAuthentication instance and call enforce_csrf
129+
session_auth = SessionAuthentication()
130+
session_auth.enforce_csrf(mock_request)
131+
132+
# Verify AnsibleBaseCSRFCheck was instantiated
133+
mock_csrf_check_class.assert_called_once()
134+
135+
# Verify process_request and process_view were called
136+
mock_csrf_check.process_request.assert_called_once_with(mock_request)
137+
mock_csrf_check.process_view.assert_called_once_with(mock_request, None, (), {})
138+
139+
140+
def test_session_authentication_csrf_failure_raises_permission_denied():
141+
"""Test that SessionAuthentication raises PermissionDenied when CSRF fails."""
142+
from rest_framework.exceptions import PermissionDenied
143+
from unittest.mock import Mock
144+
145+
# Create a mock request with an authenticated user
146+
mock_request = Mock()
147+
mock_request._request = Mock()
148+
mock_request._request.user = Mock()
149+
mock_request._request.user.is_active = True
150+
151+
# Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason
152+
with patch('ansible_base.authentication.session.AnsibleBaseCSRFCheck') as mock_csrf_check_class:
153+
mock_csrf_check = Mock()
154+
mock_csrf_check.process_request.return_value = None
155+
mock_csrf_check.process_view.return_value = "CSRF token missing" # CSRF error
156+
mock_csrf_check_class.return_value = mock_csrf_check
157+
158+
# Create SessionAuthentication instance and call enforce_csrf
159+
session_auth = SessionAuthentication()
160+
161+
# Should raise PermissionDenied with the CSRF failure reason
162+
try:
163+
session_auth.enforce_csrf(mock_request)
164+
assert False, "Expected PermissionDenied to be raised"
165+
except PermissionDenied as e:
166+
assert "CSRF Failed: CSRF token missing" in str(e)

0 commit comments

Comments
 (0)