diff --git a/ansible_base/authentication/session.py b/ansible_base/authentication/session.py index a9eb63f79..24ffd89a9 100644 --- a/ansible_base/authentication/session.py +++ b/ansible_base/authentication/session.py @@ -1,10 +1,23 @@ from rest_framework import authentication +from ansible_base.lib.utils.settings import replace_trusted_origins + class SessionAuthentication(authentication.SessionAuthentication): """ This class allows us to fail with a 401 if the user is not authenticated. + + Allows CSRF_TRUSTED_ORIGINS to be read dynamically using get_setting. + Reverting the value of CSRF_TRUSTED_ORIGINS afterwards. """ def authenticate_header(self, request): return "Session" + + @replace_trusted_origins + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication using + AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware. + """ + return super().enforce_csrf(request) diff --git a/ansible_base/lib/utils/settings.py b/ansible_base/lib/utils/settings.py index e71bd96ee..b79d8a93e 100644 --- a/ansible_base/lib/utils/settings.py +++ b/ansible_base/lib/utils/settings.py @@ -49,6 +49,24 @@ def get_function_from_setting(setting_name: str) -> Any: return None +def replace_trusted_origins(func): + """Decorator for patching the CSRF_TRUSTED_ORIGINS django setting using the potentially different value in get_setting for the duration of a + function call + """ + + def override_setting(*args, **kwargs): + csrf_trusted_origins = settings.CSRF_TRUSTED_ORIGINS + try: + # Temporarily patch the setting + settings.CSRF_TRUSTED_ORIGINS = get_setting("CSRF_TRUSTED_ORIGINS", csrf_trusted_origins) + return func(*args, **kwargs) + finally: + # Revert setting after this is done + settings.CSRF_TRUSTED_ORIGINS = csrf_trusted_origins + + return override_setting + + def get_from_import(module_name, attr): "Thin wrapper around importlib.import_module, mostly exists so that we can safely mock this in tests" module = importlib.import_module(module_name, package=attr) diff --git a/test_app/tests/authentication/test_middleware.py b/test_app/tests/authentication/test_middleware.py index 627bee931..744afc3e4 100644 --- a/test_app/tests/authentication/test_middleware.py +++ b/test_app/tests/authentication/test_middleware.py @@ -1,7 +1,9 @@ from django.conf import settings from social_core.exceptions import AuthException -from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware +from ansible_base.authentication.middleware import ( + SocialExceptionHandlerMiddleware, +) def test_social_exception_handler_mw():