Skip to content

Commit 30cad87

Browse files
committed
Update CORS middleware and tests
1 parent e8e0148 commit 30cad87

File tree

6 files changed

+33
-22
lines changed

6 files changed

+33
-22
lines changed

oauth2_provider/middleware.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from django.contrib.auth import authenticate
33
from django.utils.cache import patch_vary_headers
44

5-
from .models import Application
5+
from .models import AbstractApplication, Application
66

77

88
class OAuth2TokenMiddleware:
@@ -45,18 +45,22 @@ def __call__(self, request):
4545
METHODS = ("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS")
4646

4747

48-
class CorsMiddleware(object):
49-
def process_request(self, request):
48+
class CorsMiddleware:
49+
50+
def __init__(self, get_response):
51+
self.get_response = get_response
52+
53+
def __call__(self, request):
5054
"""If this is a preflight-request, we must always return 200"""
5155
if request.method == "OPTIONS" and "HTTP_ACCESS_CONTROL_REQUEST_METHOD" in request.META:
52-
return http.HttpResponse()
53-
return None
56+
response = http.HttpResponse()
57+
else:
58+
response = self.get_response(request)
5459

55-
def process_response(self, request, response):
5660
"""Add cors-headers to request if they can be derived correctly"""
5761
try:
5862
cors_allow_origin = _get_cors_allow_origin_header(request)
59-
except Application.NoSuitableOriginFoundError:
63+
except AbstractApplication.NoSuitableOriginFoundError:
6064
pass
6165
else:
6266
response["Access-Control-Allow-Origin"] = cors_allow_origin

oauth2_provider/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,21 @@ def get_allowed_schemes(self):
221221
"""
222222
return oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES
223223

224+
def get_cors_header(self, origin):
225+
'''Return a proper cors-header for this origin, in the context of this
226+
application.
227+
:param origin: Origin-url from HTTP-request.
228+
:raises: Application.NoSuitableOriginFoundError
229+
'''
230+
parsed_origin = urlparse(origin)
231+
for allowed_uri in self.redirect_uris.split():
232+
parsed_allowed_uri = urlparse(allowed_uri)
233+
if (parsed_allowed_uri.scheme == parsed_origin.scheme and
234+
parsed_allowed_uri.netloc == parsed_origin.netloc and
235+
parsed_allowed_uri.port == parsed_origin.port):
236+
return origin
237+
raise AbstractApplication.NoSuitableOriginFoundError
238+
224239
def allows_grant_type(self, *grant_types):
225240
return self.authorization_grant_type in grant_types
226241

@@ -243,6 +258,9 @@ def jwk_key(self):
243258
raise ImproperlyConfigured("This application does not support signed tokens")
244259

245260

261+
class NoSuitableOriginFoundError(Exception):
262+
pass
263+
246264
class ApplicationManager(models.Manager):
247265
def get_by_natural_key(self, client_id):
248266
return self.get(client_id=client_id)

tests/mig_settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
]
4343

4444
MIDDLEWARE = [
45+
"oauth2_provider.middleware.CorsMiddleware",
4546
"django.middleware.security.SecurityMiddleware",
4647
"django.contrib.sessions.middleware.SessionMiddleware",
4748
"django.middleware.common.CommonMiddleware",

tests/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
]
6464

6565
MIDDLEWARE = (
66+
"oauth2_provider.middleware.CorsMiddleware",
6667
"django.middleware.common.CommonMiddleware",
6768
"django.contrib.sessions.middleware.SessionMiddleware",
6869
"django.middleware.csrf.CsrfViewMiddleware",

tests/test_cors_middleware.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from datetime import timedelta
22

3-
from django.conf.urls import patterns, url
43
from django.contrib.auth import get_user_model
5-
from django.http import HttpResponse
64
from django.test import Client, TestCase, override_settings
75
from django.utils import timezone
8-
from django.views.generic import View
96

107
from oauth2_provider.models import AccessToken, get_application_model
118

@@ -14,19 +11,7 @@
1411
UserModel = get_user_model()
1512

1613

17-
class MockView(View):
18-
def post(self, request):
19-
return HttpResponse()
20-
21-
22-
urlpatterns = patterns(
23-
"",
24-
url(r"^cors-test/$", MockView.as_view()),
25-
)
26-
27-
2814
@override_settings(
29-
ROOT_URLCONF="oauth2_provider.tests.test_cors_middleware",
3015
AUTHENTICATION_BACKENDS=("oauth2_provider.backends.OAuth2Backend",),
3116
MIDDLEWARE_CLASSES=(
3217
"oauth2_provider.middleware.OAuth2TokenMiddleware",

tests/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from django.contrib import admin
22
from django.urls import include, path
3+
from .views import MockView
34

45

56
admin.autodiscover()
@@ -8,4 +9,5 @@
89
urlpatterns = [
910
path("o/", include("oauth2_provider.urls", namespace="oauth2_provider")),
1011
path("admin/", admin.site.urls),
12+
path("cors-test/", MockView.as_view()),
1113
]

0 commit comments

Comments
 (0)