diff --git a/AUTHORS b/AUTHORS index 0c46746e7..d651e3c75 100644 --- a/AUTHORS +++ b/AUTHORS @@ -57,6 +57,7 @@ Pavel Tvrdík Patrick Palacin Peter Carnesciali Petr Dlouhý +Rebecca Claire Murphy Rodney Richardson Rustem Saiargaliev Sandro Rodrigues diff --git a/CHANGELOG.md b/CHANGELOG.md index 7df17aee2..20b19a103 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased] ### Added +* #1150 Automatic CORS Headers based on Application redirect_url. * Support `prompt=login` for the OIDC Authorization Code Flow end user [Authentication Request](https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest). * #1163 Adds French translations. * #1166 Add spanish (es) translations. @@ -24,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed * #1152 `createapplication` management command enhanced to display an auto-generated secret before it gets hashed. + ## [2.0.0] 2022-04-24 This is a major release with **BREAKING** changes. Please make sure to review these changes before upgrading: diff --git a/docs/tutorial/tutorial_01.rst b/docs/tutorial/tutorial_01.rst index f0b8cb3ed..844b04538 100644 --- a/docs/tutorial/tutorial_01.rst +++ b/docs/tutorial/tutorial_01.rst @@ -8,17 +8,16 @@ You want to make your own :term:`Authorization Server` to issue access tokens to Start Your App -------------- During this tutorial you will make an XHR POST from a Heroku deployed app to your localhost instance. -Since the domain that will originate the request (the app on Heroku) is different from the destination domain (your local instance), -you will need to install the `django-cors-headers `_ app. +Since the domain that will originate the request (the app on Heroku) is different than the destination domain (your local instance), you will need to use the cors-middleware that we're providing. These "cross-domain" requests are by default forbidden by web browsers unless you use `CORS `_. -Create a virtualenv and install `django-oauth-toolkit` and `django-cors-headers`: +Create a virtualenv and install `django-oauth-toolkit`: :: - pip install django-oauth-toolkit django-cors-headers + pip install django-oauth-toolkit -Start a Django project, add `oauth2_provider` and `corsheaders` to the installed apps, and enable admin: +Start a Django project, add `oauth2_provider` to the installed apps, and enable admin: .. code-block:: python @@ -26,7 +25,6 @@ Start a Django project, add `oauth2_provider` and `corsheaders` to the installed 'django.contrib.admin', # ... 'oauth2_provider', - 'corsheaders', } Include the Django OAuth Toolkit urls in your `urls.py`, choosing the urlspace you prefer. For example: @@ -49,17 +47,11 @@ CorsMiddleware should be placed as high as possible, especially before any middl MIDDLEWARE = ( # ... - 'corsheaders.middleware.CorsMiddleware', + 'oauth2_provider.middleware.CorsMiddleware', # ... ) -Allow CORS requests from all domains (just for the scope of this tutorial): - -.. code-block:: python - - CORS_ORIGIN_ALLOW_ALL = True - -.. _loginTemplate: +This will allow CORS requests from the redirect uris of your applications. Include the required hidden input in your login template, `registration/login.html`. The ``{{ next }}`` template context variable will be populated with the correct diff --git a/oauth2_provider/middleware.py b/oauth2_provider/middleware.py index 17ba6c35f..3f65165cf 100644 --- a/oauth2_provider/middleware.py +++ b/oauth2_provider/middleware.py @@ -1,6 +1,9 @@ +from django import http from django.contrib.auth import authenticate from django.utils.cache import patch_vary_headers +from .models import AbstractApplication, Application + class OAuth2TokenMiddleware: """ @@ -36,3 +39,44 @@ def __call__(self, request): response = self.get_response(request) patch_vary_headers(response, ("Authorization",)) return response + + +HEADERS = ("x-requested-with", "content-type", "accept", "origin", "authorization", "x-csrftoken") +METHODS = ("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS") + + +class CorsMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + """If this is a preflight-request, we must always return 200""" + if request.method == "OPTIONS" and "HTTP_ACCESS_CONTROL_REQUEST_METHOD" in request.META: + response = http.HttpResponse() + else: + response = self.get_response(request) + + """Add cors-headers to request if they can be derived correctly""" + try: + cors_allow_origin = _get_cors_allow_origin_header(request) + except AbstractApplication.NoSuitableOriginFoundError: + pass + else: + response["Access-Control-Allow-Origin"] = cors_allow_origin + response["Access-Control-Allow-Credentials"] = "true" + if request.method == "OPTIONS": + response["Access-Control-Allow-Headers"] = ", ".join(HEADERS) + response["Access-Control-Allow-Methods"] = ", ".join(METHODS) + return response + + +def _get_cors_allow_origin_header(request): + """Fetch the oauth-application that is responsible for making the + request and return a sutible cors-header, or None + """ + origin = request.META.get("HTTP_ORIGIN") + if origin: + app = Application.objects.filter(redirect_uris__contains=origin).first() + if app is not None: + return app.get_cors_header(origin) + raise AbstractApplication.NoSuitableOriginFoundError() diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 1ded7a4e2..cef815710 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -203,6 +203,24 @@ def get_allowed_schemes(self): """ return oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES + def get_cors_header(self, origin): + """Return a proper cors-header for this origin, in the context of this + application. + + :param origin: Origin-url from HTTP-request. + :raises: Application.NoSuitableOriginFoundError + """ + parsed_origin = urlparse(origin) + for allowed_uri in self.redirect_uris.split(): + parsed_allowed_uri = urlparse(allowed_uri) + if ( + parsed_allowed_uri.scheme == parsed_origin.scheme + and parsed_allowed_uri.netloc == parsed_origin.netloc + and parsed_allowed_uri.port == parsed_origin.port + ): + return origin + raise Application.NoSuitableOriginFoundError + def allows_grant_type(self, *grant_types): return self.authorization_grant_type in grant_types @@ -224,6 +242,9 @@ def jwk_key(self): return jwk.JWK(kty="oct", k=base64url_encode(self.client_secret)) raise ImproperlyConfigured("This application does not support signed tokens") + class NoSuitableOriginFoundError(Exception): + pass + class ApplicationManager(models.Manager): def get_by_natural_key(self, client_id): diff --git a/tests/mig_settings.py b/tests/mig_settings.py index 8f77d1190..2039bf6cd 100644 --- a/tests/mig_settings.py +++ b/tests/mig_settings.py @@ -42,6 +42,7 @@ ] MIDDLEWARE = [ + "oauth2_provider.middleware.CorsMiddleware", "django.middleware.security.SecurityMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.common.CommonMiddleware", diff --git a/tests/settings.py b/tests/settings.py index 9315a6e39..f81d531e3 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -64,6 +64,7 @@ ] MIDDLEWARE = ( + "oauth2_provider.middleware.CorsMiddleware", "django.middleware.common.CommonMiddleware", "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.csrf.CsrfViewMiddleware", diff --git a/tests/test_cors_middleware.py b/tests/test_cors_middleware.py new file mode 100644 index 000000000..f03d1e3e8 --- /dev/null +++ b/tests/test_cors_middleware.py @@ -0,0 +1,72 @@ +from datetime import timedelta + +from django.contrib.auth import get_user_model +from django.test import Client, TestCase, override_settings +from django.utils import timezone + +from oauth2_provider.models import AccessToken, get_application_model + + +Application = get_application_model() +UserModel = get_user_model() + + +@override_settings( + AUTHENTICATION_BACKENDS=("oauth2_provider.backends.OAuth2Backend",), + MIDDLEWARE_CLASSES=( + "oauth2_provider.middleware.OAuth2TokenMiddleware", + "oauth2_provider.middleware.CorsMiddleware", + ), +) +class TestCORSMiddleware(TestCase): + def setUp(self): + self.user = UserModel.objects.create_user("test_user", "test@user.com") + self.application = Application.objects.create( + name="Test Application", + redirect_uris="https://foo.bar", + user=self.user, + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + ) + + self.access_token = AccessToken.objects.create( + user=self.user, + scope="read write", + expires=timezone.now() + timedelta(seconds=300), + token="secret-access-token-key", + application=self.application, + ) + + auth_header = "Bearer {0}".format(self.access_token.token) + self.client = Client(HTTP_AUTHORIZATION=auth_header) + + def test_cors_successful(self): + """Ensure that we get cors-headers according to our oauth-app""" + resp = self.client.post("/cors-test/", HTTP_ORIGIN="https://foo.bar") + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp["Access-Control-Allow-Origin"], "https://foo.bar") + self.assertEqual(resp["Access-Control-Allow-Credentials"], "true") + + def test_cors_no_auth(self): + """Ensure that CORS-headers are sent non-authenticated requests""" + client = Client() + resp = client.post("/cors-test/", HTTP_ORIGIN="https://foo.bar") + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp["Access-Control-Allow-Origin"], "https://foo.bar") + self.assertEqual(resp["Access-Control-Allow-Credentials"], "true") + + def test_cors_wrong_origin(self): + """Ensure that CORS-headers aren't sent to requests from wrong origin""" + resp = self.client.post("/cors-test/", HTTP_ORIGIN="https://bar.foo") + self.assertEqual(resp.status_code, 200) + self.assertFalse(resp.has_header("Access-Control-Allow-Origin")) + + def test_cors_200_preflight(self): + """Ensure that preflight always get 200 responses""" + resp = self.client.options( + "/cors-test/", HTTP_ACCESS_CONTROL_REQUEST_METHOD="GET", HTTP_ORIGIN="https://foo.bar" + ) + self.assertEqual(resp.status_code, 200) + self.assertEqual(resp["Access-Control-Allow-Origin"], "https://foo.bar") + self.assertTrue(resp.has_header("Access-Control-Allow-Headers")) + self.assertTrue(resp.has_header("Access-Control-Allow-Methods")) diff --git a/tests/urls.py b/tests/urls.py index 0661a9336..bd09c815b 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,6 +1,8 @@ from django.contrib import admin from django.urls import include, path +from .views import MockView + admin.autodiscover() @@ -8,4 +10,5 @@ urlpatterns = [ path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")), path("admin/", admin.site.urls), + path("cors-test/", MockView.as_view()), ] diff --git a/tests/views.py b/tests/views.py new file mode 100644 index 000000000..f2f062a36 --- /dev/null +++ b/tests/views.py @@ -0,0 +1,7 @@ +from django.http import HttpResponse +from django.views.generic import View + + +class MockView(View): + def post(self, request): + return HttpResponse()