Skip to content

Commit 24ee096

Browse files
committed
Approach for adding a CORS-middleware
1 parent 7cba4a5 commit 24ee096

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed

oauth2_provider/middleware.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from django import http
12
from django.contrib.auth import authenticate
23
from django.utils.cache import patch_vary_headers
34

5+
from .models import Application
6+
47

58
class OAuth2TokenMiddleware(object):
69
"""
@@ -32,3 +35,46 @@ def process_request(self, request):
3235
def process_response(self, request, response):
3336
patch_vary_headers(response, ('Authorization',))
3437
return response
38+
39+
HEADERS = ('x-requested-with', 'content-type', 'accept', 'origin',
40+
'authorization', 'x-csrftoken')
41+
METHODS = ('GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS')
42+
43+
44+
class CorsMiddleware(object):
45+
def process_request(self, request):
46+
'''If this is a preflight-request, we must always return 200'''
47+
if (request.method == 'OPTIONS' and
48+
'HTTP_ACCESS_CONTROL_REQUEST_METHOD' in request.META):
49+
return http.HttpResponse()
50+
return None
51+
52+
def process_response(self, request, response):
53+
'''Add cors-headers to request if they can be derived correctly'''
54+
try:
55+
cors_allow_origin = _get_cors_allow_origin_header(request)
56+
except Application.NoSuitableOriginFoundError:
57+
pass
58+
else:
59+
response['Access-Control-Allow-Origin'] = cors_allow_origin
60+
response['Access-Control-Allow-Credentials'] = 'true'
61+
if request.method == 'OPTIONS':
62+
response['Access-Control-Allow-Headers'] = ', '.join(HEADERS)
63+
response['Access-Control-Allow-Methods'] = ', '.join(METHODS)
64+
return response
65+
66+
67+
def _get_cors_allow_origin_header(request):
68+
'''Fetch the oauth-application that is responsible for making the
69+
request and return a sutible cors-header, or None
70+
'''
71+
origin = request.META.get('HTTP_ORIGIN')
72+
if origin:
73+
try:
74+
app = Application.objects.filter(redirect_uris__contains=origin)[0]
75+
except IndexError:
76+
# No application for this origin found
77+
pass
78+
else:
79+
return app.get_cors_header(origin)
80+
raise Application.NoSuitableOriginFoundError

oauth2_provider/models.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,31 @@ def clean(self):
120120
error = _('Redirect_uris could not be empty with {0} grant_type')
121121
raise ValidationError(error.format(self.authorization_grant_type))
122122

123+
def get_cors_header(self, origin):
124+
'''Return a proper cors-header for this origin, in the context of this
125+
application.
126+
127+
:param origin: Origin-url from HTTP-request.
128+
:raises: Application.NoSuitableOriginFoundError
129+
'''
130+
parsed_origin = urlparse(origin)
131+
for allowed_uri in self.redirect_uris.split():
132+
parsed_allowed_uri = urlparse(allowed_uri)
133+
if (parsed_allowed_uri.scheme == parsed_origin.scheme and
134+
parsed_allowed_uri.netloc == parsed_origin.netloc and
135+
parsed_allowed_uri.port == parsed_origin.port):
136+
return origin
137+
raise Application.NoSuitableOriginFoundError
138+
123139
def get_absolute_url(self):
124140
return reverse('oauth2_provider:detail', args=[str(self.id)])
125141

126142
def __str__(self):
127143
return self.name or self.client_id
128144

145+
class NoSuitableOriginFoundError(Exception):
146+
pass
147+
129148

130149
class Application(AbstractApplication):
131150
class Meta(AbstractApplication.Meta):
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from datetime import timedelta
2+
3+
from django.test import TestCase, Client, override_settings
4+
from django.utils import timezone
5+
from django.conf.urls import patterns, url
6+
from django.http import HttpResponse
7+
from django.views.generic import View
8+
9+
from ..models import AccessToken, get_application_model
10+
from django.contrib.auth import get_user_model
11+
12+
13+
Application = get_application_model()
14+
UserModel = get_user_model()
15+
16+
17+
class MockView(View):
18+
def post(self, request):
19+
return HttpResponse()
20+
21+
urlpatterns = patterns(
22+
'',
23+
url(r'^cors-test/$', MockView.as_view()),
24+
)
25+
26+
27+
@override_settings(
28+
ROOT_URLCONF='oauth2_provider.tests.test_cors_middleware',
29+
AUTHENTICATION_BACKENDS=('oauth2_provider.backends.OAuth2Backend',),
30+
MIDDLEWARE_CLASSES=(
31+
'oauth2_provider.middleware.OAuth2TokenMiddleware',
32+
'oauth2_provider.middleware.CorsMiddleware',
33+
))
34+
class TestCORSMiddleware(TestCase):
35+
def setUp(self):
36+
self.user = UserModel.objects.create_user('test_user', '[email protected]')
37+
self.application = Application.objects.create(
38+
name='Test Application',
39+
redirect_uris='https://foo.bar',
40+
user=self.user,
41+
client_type=Application.CLIENT_CONFIDENTIAL,
42+
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
43+
)
44+
45+
self.access_token = AccessToken.objects.create(
46+
user=self.user,
47+
scope='read write',
48+
expires=timezone.now() + timedelta(seconds=300),
49+
token='secret-access-token-key',
50+
application=self.application
51+
)
52+
53+
auth_header = "Bearer {0}".format(self.access_token.token)
54+
self.client = Client(HTTP_AUTHORIZATION=auth_header)
55+
56+
def test_cors_successful(self):
57+
'''Ensure that we get cors-headers according to our oauth-app'''
58+
resp = self.client.post('/cors-test/', HTTP_ORIGIN='https://foo.bar')
59+
self.assertEqual(resp.status_code, 200)
60+
self.assertEqual(resp['Access-Control-Allow-Origin'], 'https://foo.bar')
61+
self.assertEqual(resp['Access-Control-Allow-Credentials'], 'true')
62+
63+
def test_cors_no_auth(self):
64+
'''Ensure that CORS-headers are sent non-authenticated requests'''
65+
client = Client()
66+
resp = client.post('/cors-test/', HTTP_ORIGIN='https://foo.bar')
67+
self.assertEqual(resp.status_code, 200)
68+
self.assertEqual(resp['Access-Control-Allow-Origin'], 'https://foo.bar')
69+
self.assertEqual(resp['Access-Control-Allow-Credentials'], 'true')
70+
71+
def test_cors_wrong_origin(self):
72+
'''Ensure that CORS-headers aren't sent to requests from wrong origin'''
73+
resp = self.client.post('/cors-test/', HTTP_ORIGIN='https://bar.foo')
74+
self.assertEqual(resp.status_code, 200)
75+
self.assertFalse(resp.has_header('Access-Control-Allow-Origin'))
76+
77+
def test_cors_200_preflight(self):
78+
'''Ensure that preflight always get 200 responses'''
79+
resp = self.client.options('/cors-test/',
80+
HTTP_ACCESS_CONTROL_REQUEST_METHOD='GET',
81+
HTTP_ORIGIN='https://foo.bar')
82+
self.assertEqual(resp.status_code, 200)
83+
self.assertEqual(resp['Access-Control-Allow-Origin'], 'https://foo.bar')
84+
self.assertTrue(resp.has_header('Access-Control-Allow-Headers'))
85+
self.assertTrue(resp.has_header('Access-Control-Allow-Methods'))

0 commit comments

Comments
 (0)