1
+ from unittest .mock import patch
2
+
1
3
from django .conf import settings
2
4
from social_core .exceptions import AuthException
3
5
4
- from ansible_base .authentication .middleware import SocialExceptionHandlerMiddleware
6
+ from ansible_base .authentication .middleware import (
7
+ AnsibleBaseCsrfViewMiddleware ,
8
+ SocialExceptionHandlerMiddleware ,
9
+ )
10
+ from ansible_base .authentication .session import (
11
+ AnsibleBaseCSRFCheck ,
12
+ SessionAuthentication ,
13
+ )
5
14
6
15
7
16
def test_social_exception_handler_mw ():
@@ -21,3 +30,138 @@ def __init__(self):
21
30
mw = SocialExceptionHandlerMiddleware (None )
22
31
url = mw .get_redirect_uri (Request (), AuthException ("test" ))
23
32
assert url == "/?auth_failed"
33
+
34
+
35
+ def test_ansible_base_csrf_view_middleware_csrf_trusted_origins_hosts ():
36
+ """Test that csrf_trusted_origins_hosts uses get_setting."""
37
+ test_origins = ['https://example.com' , 'https://*.test.com' ]
38
+
39
+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
40
+ mock_get_setting .return_value = test_origins
41
+
42
+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
43
+ result = middleware .csrf_trusted_origins_hosts
44
+
45
+ mock_get_setting .assert_called_once_with ('CSRF_TRUSTED_ORIGINS' , [])
46
+ # Should strip * from netloc
47
+ assert result == ['example.com' , '.test.com' ]
48
+
49
+
50
+ def test_ansible_base_csrf_view_middleware_allowed_origins_exact ():
51
+ """Test that allowed_origins_exact uses get_setting."""
52
+ test_origins = ['https://example.com' , 'https://*.test.com' ]
53
+
54
+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
55
+ mock_get_setting .return_value = test_origins
56
+
57
+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
58
+ result = middleware .allowed_origins_exact
59
+
60
+ mock_get_setting .assert_called_once_with ('CSRF_TRUSTED_ORIGINS' , [])
61
+ # Should only include origins without *
62
+ assert result == {'https://example.com' }
63
+
64
+
65
+ def test_ansible_base_csrf_view_middleware_allowed_origin_subdomains ():
66
+ """Test that allowed_origin_subdomains uses get_setting."""
67
+ test_origins = ['https://*.example.com' , 'http://*.test.com' ]
68
+
69
+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
70
+ mock_get_setting .return_value = test_origins
71
+
72
+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
73
+ result = middleware .allowed_origin_subdomains
74
+
75
+ mock_get_setting .assert_called_once_with ('CSRF_TRUSTED_ORIGINS' , [])
76
+ # Should group by scheme and strip *
77
+ expected = {'https' : ['.example.com' ], 'http' : ['.test.com' ]}
78
+ assert dict (result ) == expected
79
+
80
+
81
+ def test_ansible_base_csrf_view_middleware_default_value ():
82
+ """Test that middleware returns empty/default values when setting is empty."""
83
+ with patch ('ansible_base.authentication.middleware.get_setting' ) as mock_get_setting :
84
+ mock_get_setting .return_value = []
85
+
86
+ middleware = AnsibleBaseCsrfViewMiddleware (lambda request : None )
87
+
88
+ # Test all three properties
89
+ assert middleware .csrf_trusted_origins_hosts == []
90
+ assert middleware .allowed_origins_exact == set ()
91
+ assert dict (middleware .allowed_origin_subdomains ) == {}
92
+
93
+ # get_setting should be called three times (once for each property)
94
+ assert mock_get_setting .call_count == 3
95
+
96
+
97
+ def test_ansible_base_csrf_check_inherits_from_ansible_base_csrf_view_middleware ():
98
+ """Test that AnsibleBaseCSRFCheck inherits from AnsibleBaseCsrfViewMiddleware."""
99
+ csrf_check = AnsibleBaseCSRFCheck (lambda request : None )
100
+ assert isinstance (csrf_check , AnsibleBaseCsrfViewMiddleware )
101
+
102
+
103
+ def test_ansible_base_csrf_check_reject_method ():
104
+ """Test that AnsibleBaseCSRFCheck._reject returns the reason."""
105
+ csrf_check = AnsibleBaseCSRFCheck (lambda request : None )
106
+ reason = "Test CSRF failure reason"
107
+ result = csrf_check ._reject (None , reason )
108
+ assert result == reason
109
+
110
+
111
+ def test_session_authentication_uses_ansible_base_csrf_check ():
112
+ """Test that SessionAuthentication uses AnsibleBaseCSRFCheck for CSRF validation."""
113
+ from unittest .mock import Mock
114
+
115
+ # Create a mock request with an authenticated user
116
+ mock_request = Mock ()
117
+ mock_request ._request = Mock ()
118
+ mock_request ._request .user = Mock ()
119
+ mock_request ._request .user .is_active = True
120
+
121
+ # Mock the AnsibleBaseCSRFCheck to track its usage
122
+ with patch ('ansible_base.authentication.session.AnsibleBaseCSRFCheck' ) as mock_csrf_check_class :
123
+ mock_csrf_check = Mock ()
124
+ mock_csrf_check .process_request .return_value = None
125
+ mock_csrf_check .process_view .return_value = None # No CSRF error
126
+ mock_csrf_check_class .return_value = mock_csrf_check
127
+
128
+ # Create SessionAuthentication instance and call enforce_csrf
129
+ session_auth = SessionAuthentication ()
130
+ session_auth .enforce_csrf (mock_request )
131
+
132
+ # Verify AnsibleBaseCSRFCheck was instantiated
133
+ mock_csrf_check_class .assert_called_once ()
134
+
135
+ # Verify process_request and process_view were called
136
+ mock_csrf_check .process_request .assert_called_once_with (mock_request )
137
+ mock_csrf_check .process_view .assert_called_once_with (mock_request , None , (), {})
138
+
139
+
140
+ def test_session_authentication_csrf_failure_raises_permission_denied ():
141
+ """Test that SessionAuthentication raises PermissionDenied when CSRF fails."""
142
+ from unittest .mock import Mock
143
+
144
+ from rest_framework .exceptions import PermissionDenied
145
+
146
+ # Create a mock request with an authenticated user
147
+ mock_request = Mock ()
148
+ mock_request ._request = Mock ()
149
+ mock_request ._request .user = Mock ()
150
+ mock_request ._request .user .is_active = True
151
+
152
+ # Mock the AnsibleBaseCSRFCheck to return a CSRF failure reason
153
+ with patch ('ansible_base.authentication.session.AnsibleBaseCSRFCheck' ) as mock_csrf_check_class :
154
+ mock_csrf_check = Mock ()
155
+ mock_csrf_check .process_request .return_value = None
156
+ mock_csrf_check .process_view .return_value = "CSRF token missing" # CSRF error
157
+ mock_csrf_check_class .return_value = mock_csrf_check
158
+
159
+ # Create SessionAuthentication instance and call enforce_csrf
160
+ session_auth = SessionAuthentication ()
161
+
162
+ # Should raise PermissionDenied with the CSRF failure reason
163
+ try :
164
+ session_auth .enforce_csrf (mock_request )
165
+ assert False , "Expected PermissionDenied to be raised"
166
+ except PermissionDenied as e :
167
+ assert "CSRF Failed: CSRF token missing" in str (e )
0 commit comments