Skip to content

Commit ce6a35f

Browse files
akanstantsinaudopry
authored andcommitted
Fix CORS by passing 'Origin' header to OAuthLib
It is possible to control CORS by overriding is_origin_allowed method of RequestValidator class. OAuthLib allows origin if: - is_origin_allowed returns True for particular request - Request connection is secure - Request has 'Origin' header
1 parent 9aa27c7 commit ce6a35f

File tree

3 files changed

+120
-0
lines changed

3 files changed

+120
-0
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ Alejandro Mantecon Guillen
1616
Aleksander Vaskevich
1717
Alessandro De Angelis
1818
Alex Szabó
19+
Aliaksei Kanstantsinau
1920
Allisson Azevedo
2021
Andrea Greco
2122
Andrej Zbín

oauth2_provider/oauth2_backends.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def extract_headers(self, request):
7575
del headers["wsgi.errors"]
7676
if "HTTP_AUTHORIZATION" in headers:
7777
headers["Authorization"] = headers["HTTP_AUTHORIZATION"]
78+
if "HTTP_ORIGIN" in headers:
79+
headers["Origin"] = headers["HTTP_ORIGIN"]
7880
if request.is_secure():
7981
headers["X_DJANGO_OAUTH_TOOLKIT_SECURE"] = "1"
8082
elif "X_DJANGO_OAUTH_TOOLKIT_SECURE" in headers:

tests/test_cors.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from urllib.parse import parse_qs, urlparse
2+
3+
import pytest
4+
from django.contrib.auth import get_user_model
5+
from django.test import RequestFactory, TestCase
6+
from django.urls import reverse
7+
8+
from oauth2_provider.models import get_application_model
9+
from oauth2_provider.oauth2_validators import OAuth2Validator
10+
11+
from . import presets
12+
from .utils import get_basic_auth_header
13+
14+
15+
class CorsOAuth2Validator(OAuth2Validator):
16+
def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
17+
"""Enable CORS in OAuthLib"""
18+
return True
19+
20+
21+
Application = get_application_model()
22+
UserModel = get_user_model()
23+
24+
CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz"
25+
26+
# CORS is allowed for https only
27+
CLIENT_URI = "https://example.org"
28+
29+
30+
@pytest.mark.usefixtures("oauth2_settings")
31+
@pytest.mark.oauth2_settings(presets.DEFAULT_SCOPES_RW)
32+
class CorsTest(TestCase):
33+
"""
34+
Test that CORS headers can be managed by OAuthLib.
35+
The objective is: http request 'Origin' header should be passed to OAuthLib
36+
"""
37+
38+
def setUp(self):
39+
self.factory = RequestFactory()
40+
self.test_user = UserModel.objects.create_user("test_user", "[email protected]", "123456")
41+
self.dev_user = UserModel.objects.create_user("dev_user", "[email protected]", "123456")
42+
43+
self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"]
44+
self.oauth2_settings.PKCE_REQUIRED = False
45+
46+
self.application = Application.objects.create(
47+
name="Test Application",
48+
redirect_uris=(CLIENT_URI),
49+
user=self.dev_user,
50+
client_type=Application.CLIENT_CONFIDENTIAL,
51+
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
52+
client_secret=CLEARTEXT_SECRET,
53+
)
54+
55+
self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"]
56+
self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator
57+
58+
def tearDown(self):
59+
self.application.delete()
60+
self.test_user.delete()
61+
self.dev_user.delete()
62+
63+
def test_cors_header(self):
64+
"""
65+
Test that /token endpoint has Access-Control-Allow-Origin
66+
"""
67+
authorization_code = self._get_authorization_code()
68+
69+
# exchange authorization code for a valid access token
70+
token_request_data = {
71+
"grant_type": "authorization_code",
72+
"code": authorization_code,
73+
"redirect_uri": CLIENT_URI,
74+
}
75+
76+
auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET)
77+
auth_headers["origin"] = CLIENT_URI
78+
79+
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
80+
self.assertEqual(response.status_code, 200)
81+
self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI)
82+
83+
def test_no_cors_header(self):
84+
"""
85+
Test that /token endpoint does not have Access-Control-Allow-Origin
86+
"""
87+
authorization_code = self._get_authorization_code()
88+
89+
# exchange authorization code for a valid access token
90+
token_request_data = {
91+
"grant_type": "authorization_code",
92+
"code": authorization_code,
93+
"redirect_uri": CLIENT_URI,
94+
}
95+
96+
auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET)
97+
98+
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
99+
self.assertEqual(response.status_code, 200)
100+
# No CORS headers, because request did not have Origin
101+
self.assertFalse(response.has_header("Access-Control-Allow-Origin"))
102+
103+
def _get_authorization_code(self):
104+
self.client.login(username="test_user", password="123456")
105+
106+
# retrieve a valid authorization code
107+
authcode_data = {
108+
"client_id": self.application.client_id,
109+
"state": "random_state_string",
110+
"scope": "read write",
111+
"redirect_uri": "https://example.org",
112+
"response_type": "code",
113+
"allow": True,
114+
}
115+
response = self.client.post(reverse("oauth2_provider:authorize"), data=authcode_data)
116+
query_dict = parse_qs(urlparse(response["Location"]).query)
117+
return query_dict["code"].pop()

0 commit comments

Comments
 (0)