diff --git a/src/corsheaders/conf.py b/src/corsheaders/conf.py index 226b5325..1299e5a3 100644 --- a/src/corsheaders/conf.py +++ b/src/corsheaders/conf.py @@ -1,13 +1,12 @@ from __future__ import annotations -from typing import cast -from typing import List +import re +from dataclasses import dataclass +from typing import Any from typing import Pattern from typing import Sequence -from typing import Tuple -from typing import Union -from django.conf import settings +from django.conf import settings as _django_settings from corsheaders.defaults import default_headers from corsheaders.defaults import default_methods @@ -15,63 +14,44 @@ # Kept here for backwards compatibility +@dataclass class Settings: + CORS_ALLOW_HEADERS: Sequence[str] = default_headers + CORS_ALLOW_METHODS: Sequence[str] = default_methods + CORS_ALLOW_CREDENTIALS: bool = False + CORS_ALLOW_PRIVATE_NETWORK: bool = False + CORS_PREFLIGHT_MAX_AGE: int = 86400 + CORS_ALLOW_ALL_ORIGINS: bool = False + CORS_ALLOWED_ORIGINS: list[str] | tuple[str] = () # type: ignore + CORS_ALLOWED_ORIGIN_REGEXES: Sequence[str | Pattern[str]] = () + CORS_EXPOSE_HEADERS: Sequence[str] = () + CORS_URLS_REGEX: str | Pattern[str] = re.compile(r"^.*$") + + +_RENAMED_SETTINGS = { + # New name -> Old name + "CORS_ALLOW_ALL_ORIGINS": "CORS_ORIGIN_ALLOW_ALL", + "CORS_ALLOWED_ORIGINS": "CORS_ORIGIN_WHITELIST", + "CORS_ALLOWED_ORIGIN_REGEXES": "CORS_ORIGIN_REGEX_WHITELIST", +} + + +class DjangoConfig(Settings): """ - Shadow Django's settings with a little logic - """ - - @property - def CORS_ALLOW_HEADERS(self) -> Sequence[str]: - return getattr(settings, "CORS_ALLOW_HEADERS", default_headers) - - @property - def CORS_ALLOW_METHODS(self) -> Sequence[str]: - return getattr(settings, "CORS_ALLOW_METHODS", default_methods) - - @property - def CORS_ALLOW_CREDENTIALS(self) -> bool: - return getattr(settings, "CORS_ALLOW_CREDENTIALS", False) - - @property - def CORS_ALLOW_PRIVATE_NETWORK(self) -> bool: - return getattr(settings, "CORS_ALLOW_PRIVATE_NETWORK", False) + A version of Settings that prefers to read from Django's settings. - @property - def CORS_PREFLIGHT_MAX_AGE(self) -> int: - return getattr(settings, "CORS_PREFLIGHT_MAX_AGE", 86400) - - @property - def CORS_ALLOW_ALL_ORIGINS(self) -> bool: - return getattr( - settings, - "CORS_ALLOW_ALL_ORIGINS", - getattr(settings, "CORS_ORIGIN_ALLOW_ALL", False), - ) - - @property - def CORS_ALLOWED_ORIGINS(self) -> list[str] | tuple[str]: - value = getattr( - settings, - "CORS_ALLOWED_ORIGINS", - getattr(settings, "CORS_ORIGIN_WHITELIST", ()), - ) - return cast(Union[List[str], Tuple[str]], value) - - @property - def CORS_ALLOWED_ORIGIN_REGEXES(self) -> Sequence[str | Pattern[str]]: - return getattr( - settings, - "CORS_ALLOWED_ORIGIN_REGEXES", - getattr(settings, "CORS_ORIGIN_REGEX_WHITELIST", ()), - ) - - @property - def CORS_EXPOSE_HEADERS(self) -> Sequence[str]: - return getattr(settings, "CORS_EXPOSE_HEADERS", ()) + Falls back to its own values if the setting is not configured + in Django. + """ - @property - def CORS_URLS_REGEX(self) -> str | Pattern[str]: - return getattr(settings, "CORS_URLS_REGEX", r"^.*$") + def __getattribute__(self, name: str) -> Any: + default = object.__getattribute__(self, name) + if name in _RENAMED_SETTINGS: + # Renamed settings are used if the new setting + # is not configured in Django, + old_name = _RENAMED_SETTINGS[name] + default = getattr(_django_settings, old_name, default) + return getattr(_django_settings, name, default) -conf = Settings() +conf = DjangoConfig() diff --git a/src/corsheaders/decorators.py b/src/corsheaders/decorators.py new file mode 100644 index 00000000..578a733e --- /dev/null +++ b/src/corsheaders/decorators.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import asyncio +import functools +from typing import Any +from typing import Callable +from typing import cast +from typing import TypeVar + +from django.http import HttpRequest +from django.http.response import HttpResponseBase + +from corsheaders.conf import conf as _conf +from corsheaders.conf import Settings +from corsheaders.middleware import CorsMiddleware + +F = TypeVar("F", bound=Callable[..., HttpResponseBase]) + + +def cors(func: F | None = None, *, conf: Settings = _conf) -> F | Callable[[F], F]: + if func is None: + return cast(Callable[[F], F], functools.partial(cors, conf=conf)) + + assert callable(func) + + if asyncio.iscoroutinefunction(func): + + async def inner( + _request: HttpRequest, *args: Any, **kwargs: Any + ) -> HttpResponseBase: + async def get_response(request: HttpRequest) -> HttpResponseBase: + return await func(request, *args, **kwargs) + + return await CorsMiddleware(get_response, conf=conf)(_request) + + else: + + def inner(_request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase: + def get_response(request: HttpRequest) -> HttpResponseBase: + return func(request, *args, **kwargs) + + return CorsMiddleware(get_response, conf=conf)(_request) + + wrapper = functools.wraps(func)(inner) + wrapper._skip_cors_middleware = True # type: ignore [attr-defined] + return cast(F, wrapper) diff --git a/src/corsheaders/middleware.py b/src/corsheaders/middleware.py index e23b9692..12c870a6 100644 --- a/src/corsheaders/middleware.py +++ b/src/corsheaders/middleware.py @@ -2,6 +2,7 @@ import asyncio import re +from typing import Any from typing import Awaitable from typing import Callable from urllib.parse import SplitResult @@ -13,6 +14,7 @@ from django.utils.cache import patch_vary_headers from corsheaders.conf import conf +from corsheaders.conf import Settings from corsheaders.signals import check_request_enabled ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin" @@ -35,8 +37,10 @@ def __init__( Callable[[HttpRequest], HttpResponseBase] | Callable[[HttpRequest], Awaitable[HttpResponseBase]] ), + conf: Settings = conf, ) -> None: self.get_response = get_response + self.conf = conf if asyncio.iscoroutinefunction(self.get_response): # Mark the class as async-capable, but do the actual switch # inside __call__ to avoid swapping out dunder methods @@ -51,22 +55,40 @@ def __call__( ) -> HttpResponseBase | Awaitable[HttpResponseBase]: if self._is_coroutine: return self.__acall__(request) - response: HttpResponseBase | None = self.check_preflight(request) - if response is None: - result = self.get_response(request) - assert isinstance(result, HttpResponseBase) - response = result - self.add_response_headers(request, response) - return response + result = self.get_response(request) + assert isinstance(result, HttpResponseBase) + response = result + if getattr(response, "_cors_processing_done", False): + return response + else: + # Request wasn't processed (e.g. because of a 404) + return self.add_response_headers( + request, self.check_preflight(request) or response + ) async def __acall__(self, request: HttpRequest) -> HttpResponseBase: - response = self.check_preflight(request) - if response is None: - result = self.get_response(request) - assert not isinstance(result, HttpResponseBase) - response = await result - self.add_response_headers(request, response) - return response + result = self.get_response(request) + assert not isinstance(result, HttpResponseBase) + response = await result + if getattr(response, "_cors_processing_done", False): + return response + else: + # View wasn't processed (e.g. because of a 404) + return self.add_response_headers( + request, self.check_preflight(request) or response + ) + + def process_view( + self, + request: HttpRequest, + callback: Callable[[HttpRequest], HttpResponseBase], + callback_args: Any, + callback_kwargs: Any, + ) -> HttpResponseBase | None: + if getattr(callback, "_skip_cors_middleware", False): + # View is decorated and will add CORS headers itself + return None + return self.check_preflight(request) def check_preflight(self, request: HttpRequest) -> HttpResponseBase | None: """ @@ -87,6 +109,7 @@ def add_response_headers( """ Add the respective CORS headers """ + response._cors_processing_done = True enabled = getattr(request, "_cors_enabled", None) if enabled is None: enabled = self.is_enabled(request) @@ -105,34 +128,38 @@ def add_response_headers( except ValueError: return response - if conf.CORS_ALLOW_CREDENTIALS: + if self.conf.CORS_ALLOW_CREDENTIALS: response[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true" if ( - not conf.CORS_ALLOW_ALL_ORIGINS + not self.conf.CORS_ALLOW_ALL_ORIGINS and not self.origin_found_in_white_lists(origin, url) and not self.check_signal(request) ): return response - if conf.CORS_ALLOW_ALL_ORIGINS and not conf.CORS_ALLOW_CREDENTIALS: + if self.conf.CORS_ALLOW_ALL_ORIGINS and not self.conf.CORS_ALLOW_CREDENTIALS: response[ACCESS_CONTROL_ALLOW_ORIGIN] = "*" else: response[ACCESS_CONTROL_ALLOW_ORIGIN] = origin - if len(conf.CORS_EXPOSE_HEADERS): + if len(self.conf.CORS_EXPOSE_HEADERS): response[ACCESS_CONTROL_EXPOSE_HEADERS] = ", ".join( - conf.CORS_EXPOSE_HEADERS + self.conf.CORS_EXPOSE_HEADERS ) if request.method == "OPTIONS": - response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join(conf.CORS_ALLOW_HEADERS) - response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join(conf.CORS_ALLOW_METHODS) - if conf.CORS_PREFLIGHT_MAX_AGE: - response[ACCESS_CONTROL_MAX_AGE] = str(conf.CORS_PREFLIGHT_MAX_AGE) + response[ACCESS_CONTROL_ALLOW_HEADERS] = ", ".join( + self.conf.CORS_ALLOW_HEADERS + ) + response[ACCESS_CONTROL_ALLOW_METHODS] = ", ".join( + self.conf.CORS_ALLOW_METHODS + ) + if self.conf.CORS_PREFLIGHT_MAX_AGE: + response[ACCESS_CONTROL_MAX_AGE] = str(self.conf.CORS_PREFLIGHT_MAX_AGE) if ( - conf.CORS_ALLOW_PRIVATE_NETWORK + self.conf.CORS_ALLOW_PRIVATE_NETWORK and request.headers.get(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK) == "true" ): response[ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK] = "true" @@ -141,7 +168,7 @@ def add_response_headers( def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool: return ( - (origin == "null" and origin in conf.CORS_ALLOWED_ORIGINS) + (origin == "null" and origin in self.conf.CORS_ALLOWED_ORIGINS) or self._url_in_whitelist(url) or self.regex_domain_match(origin) ) @@ -149,12 +176,12 @@ def origin_found_in_white_lists(self, origin: str, url: SplitResult) -> bool: def regex_domain_match(self, origin: str) -> bool: return any( re.match(domain_pattern, origin) - for domain_pattern in conf.CORS_ALLOWED_ORIGIN_REGEXES + for domain_pattern in self.conf.CORS_ALLOWED_ORIGIN_REGEXES ) def is_enabled(self, request: HttpRequest) -> bool: return bool( - re.match(conf.CORS_URLS_REGEX, request.path_info) + re.match(self.conf.CORS_URLS_REGEX, request.path_info) ) or self.check_signal(request) def check_signal(self, request: HttpRequest) -> bool: @@ -162,7 +189,7 @@ def check_signal(self, request: HttpRequest) -> bool: return any(return_value for function, return_value in signal_responses) def _url_in_whitelist(self, url: SplitResult) -> bool: - origins = [urlsplit(o) for o in conf.CORS_ALLOWED_ORIGINS] + origins = [urlsplit(o) for o in self.conf.CORS_ALLOWED_ORIGINS] return any( origin.scheme == url.scheme and origin.netloc == url.netloc for origin in origins diff --git a/tests/test_conf.py b/tests/test_conf.py index 87c0c446..63cf38c2 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -7,6 +7,11 @@ class ConfTests(SimpleTestCase): + @override_settings(SECRET_KEY="foo") + def test_other_setting(self): + # Only proxy settings that are defined in the Settings class. + self.assertRaises(AttributeError, getattr, conf, "SECRET_KEY") + @override_settings(CORS_ALLOW_HEADERS=["foo"]) def test_can_override(self): assert conf.CORS_ALLOW_HEADERS == ["foo"] diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 00000000..34666a0b --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from django.test import TestCase +from django.test.utils import modify_settings +from django.test.utils import override_settings + +from corsheaders.middleware import ACCESS_CONTROL_ALLOW_ORIGIN + + +@modify_settings( + MIDDLEWARE={ + "remove": "corsheaders.middleware.CorsMiddleware", + } +) +@override_settings(CORS_ALLOWED_ORIGINS=["https://example.com"]) +class CorsDecoratorsTestCase(TestCase): + def test_get_no_origin(self): + resp = self.client.get("/decorated/hello/") + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated: hello" + + def test_get_not_in_allowed_origins(self): + resp = self.client.get( + "/decorated/hello/", + HTTP_ORIGIN="https://example.net", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated: hello" + + def test_get_in_allowed_origins_preflight(self): + resp = self.client.options( + "/decorated/hello/", + HTTP_ORIGIN="https://example.com", + HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + def test_get_in_allowed_origins(self): + resp = self.client.get( + "/decorated/hello/", + HTTP_ORIGIN="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Decorated: hello" + + async def test_async_get_not_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated/hello/", + origin="https://example.org", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Async Decorated: hello" + + async def test_async_get_in_allowed_origins_preflight(self): + resp = await self.async_client.options( + "/async-decorated/hello/", + origin="https://example.com", + access_control_request_method="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + async def test_async_get_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated/hello/", + origin="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Async Decorated: hello" + + +class CorsDecoratorsWithConfTestCase(TestCase): + def test_get_no_origin(self): + resp = self.client.get("/decorated-with-conf/hello/") + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated (with conf): hello" + + def test_get_not_in_allowed_origins(self): + resp = self.client.get( + "/decorated-with-conf/hello/", + HTTP_ORIGIN="https://example.net", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Decorated (with conf): hello" + + def test_get_in_allowed_origins_preflight(self): + resp = self.client.options( + "/decorated-with-conf/hello/", + HTTP_ORIGIN="https://example.com", + HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + def test_get_in_allowed_origins(self): + resp = self.client.get( + "/decorated-with-conf/hello/", + HTTP_ORIGIN="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Decorated (with conf): hello" + + async def test_async_get_not_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated-with-conf/hello/", + origin="https://example.org", + ) + assert ACCESS_CONTROL_ALLOW_ORIGIN not in resp + assert resp.content == b"Async Decorated (with conf): hello" + + async def test_async_get_in_allowed_origins_preflight(self): + resp = await self.async_client.options( + "/async-decorated-with-conf/hello/", + origin="https://example.com", + access_control_request_method="GET", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"" + + async def test_async_get_in_allowed_origins(self): + resp = await self.async_client.get( + "/async-decorated-with-conf/hello/", + origin="https://example.com", + ) + assert resp[ACCESS_CONTROL_ALLOW_ORIGIN] == "https://example.com" + assert resp.content == b"Async Decorated (with conf): hello" diff --git a/tests/urls.py b/tests/urls.py index a790fb1d..c3f64407 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -7,6 +7,10 @@ urlpatterns = [ path("", views.index), path("async/", views.async_), + path("decorated//", views.decorated), + path("decorated-with-conf//", views.decorated_with_conf), + path("async-decorated//", views.async_decorated), + path("async-decorated-with-conf//", views.async_decorated_with_conf), path("unauthorized/", views.unauthorized), path("delete-enabled/", views.delete_enabled_attribute), ] diff --git a/tests/views.py b/tests/views.py index 06e257a5..ecfc5761 100644 --- a/tests/views.py +++ b/tests/views.py @@ -5,6 +5,9 @@ from django.http import HttpResponse from django.views.decorators.http import require_GET +from corsheaders.conf import Settings +from corsheaders.decorators import cors + @require_GET def index(request): @@ -15,6 +18,26 @@ async def async_(request): return HttpResponse("Asynchronous") +@cors +def decorated(request, slug): + return HttpResponse(f"Decorated: {slug}") + + +@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"])) +def decorated_with_conf(request, slug): + return HttpResponse(f"Decorated (with conf): {slug}") + + +@cors +async def async_decorated(request, slug): + return HttpResponse(f"Async Decorated: {slug}") + + +@cors(conf=Settings(CORS_ALLOWED_ORIGINS=["https://example.com"])) +async def async_decorated_with_conf(request, slug): + return HttpResponse(f"Async Decorated (with conf): {slug}") + + def unauthorized(request): return HttpResponse("Unauthorized", status=HTTPStatus.UNAUTHORIZED)