Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions ansible_base/authentication/session.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from rest_framework import authentication

from ansible_base.lib.utils.settings import replace_trusted_origins


class SessionAuthentication(authentication.SessionAuthentication):
"""
This class allows us to fail with a 401 if the user is not authenticated.

Allows CSRF_TRUSTED_ORIGINS to be read dynamically using get_setting.
Reverting the value of CSRF_TRUSTED_ORIGINS afterwards.
"""

def authenticate_header(self, request):
return "Session"

@replace_trusted_origins
def enforce_csrf(self, request):
"""
Enforce CSRF validation for session based authentication using
AnsibleBaseCsrfViewMiddleware instead of Django's CsrfViewMiddleware.
"""
return super().enforce_csrf(request)
18 changes: 18 additions & 0 deletions ansible_base/lib/utils/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ def get_function_from_setting(setting_name: str) -> Any:
return None


def replace_trusted_origins(func):
"""Decorator for patching the CSRF_TRUSTED_ORIGINS django setting using the potentially different value in get_setting for the duration of a
function call
"""

def override_setting(*args, **kwargs):
csrf_trusted_origins = settings.CSRF_TRUSTED_ORIGINS
try:
# Temporarily patch the setting
settings.CSRF_TRUSTED_ORIGINS = get_setting("CSRF_TRUSTED_ORIGINS", csrf_trusted_origins)
return func(*args, **kwargs)
finally:
# Revert setting after this is done
settings.CSRF_TRUSTED_ORIGINS = csrf_trusted_origins

return override_setting


def get_from_import(module_name, attr):
"Thin wrapper around importlib.import_module, mostly exists so that we can safely mock this in tests"
module = importlib.import_module(module_name, package=attr)
Expand Down
4 changes: 3 additions & 1 deletion test_app/tests/authentication/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from django.conf import settings
from social_core.exceptions import AuthException

from ansible_base.authentication.middleware import SocialExceptionHandlerMiddleware
from ansible_base.authentication.middleware import (
SocialExceptionHandlerMiddleware,
)


def test_social_exception_handler_mw():
Expand Down