Skip to content

Commit cc4e470

Browse files
Added Allowed Origins application setting
1 parent bdc26e7 commit cc4e470

File tree

9 files changed

+134
-12
lines changed

9 files changed

+134
-12
lines changed

docs/tutorial/tutorial_01.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ point your browser to http://localhost:8000/o/applications/ and add an Applicati
9191
specifies one of the verified redirection uris. For this tutorial, paste verbatim the value
9292
`https://www.getpostman.com/oauth2/callback`
9393

94+
* `Allowed origins`: Web applications use Cross-Origin Resource Sharing (CORS) to request resources from origins other than their own.
95+
You can provide list of origins of web applications that will have access to the token endpoint of :term:`Authorization Server`.
96+
This setting controls only token endpoint and it is not related with Django CORS Headers settings.
97+
9498
* `Client type`: this value affects the security level at which some communications between the client application and
9599
the authorization server are performed. For this tutorial choose *Confidential*.
96100

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Generated by Django 4.1.5 on 2023-09-27 20:15
2+
3+
from django.db import migrations, models
4+
5+
6+
class Migration(migrations.Migration):
7+
8+
dependencies = [
9+
("oauth2_provider", "0009_add_hash_client_secret"),
10+
]
11+
12+
operations = [
13+
migrations.AddField(
14+
model_name="application",
15+
name="allowed_origins",
16+
field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"),
17+
),
18+
]

oauth2_provider/models.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from .scopes import get_scopes_backend
2121
from .settings import oauth2_settings
2222
from .utils import jwk_from_pem
23-
from .validators import RedirectURIValidator, WildcardSet
24-
23+
from .validators import RedirectURIValidator, WildcardSet, URIValidator
2524

2625
logger = logging.getLogger(__name__)
2726

@@ -132,7 +131,10 @@ class AbstractApplication(models.Model):
132131
created = models.DateTimeField(auto_now_add=True)
133132
updated = models.DateTimeField(auto_now=True)
134133
algorithm = models.CharField(max_length=5, choices=ALGORITHM_TYPES, default=NO_ALGORITHM, blank=True)
135-
134+
allowed_origins = models.TextField(
135+
blank=True,
136+
help_text=_("Allowed origins list to enable CORS, space separated"),
137+
)
136138
class Meta:
137139
abstract = True
138140

@@ -172,6 +174,14 @@ def post_logout_redirect_uri_allowed(self, uri):
172174
"""
173175
return redirect_to_uri_allowed(uri, self.post_logout_redirect_uris.split())
174176

177+
def origin_allowed(self, origin):
178+
"""
179+
Checks if given origin is one of the items in :attr:`allowed_origins` string
180+
181+
:param origin: Origin to check
182+
"""
183+
return self.allowed_origins and is_origin_allowed(origin, self.allowed_origins.split())
184+
175185
def clean(self):
176186
from django.core.exceptions import ValidationError
177187

@@ -202,6 +212,13 @@ def clean(self):
202212
grant_type=self.authorization_grant_type
203213
)
204214
)
215+
allowed_origins = self.allowed_origins.strip().split()
216+
if allowed_origins:
217+
# oauthlib allows only https scheme for CORS
218+
validator = URIValidator({"https"})
219+
for uri in allowed_origins:
220+
validator(uri)
221+
205222
if self.algorithm == AbstractApplication.RS256_ALGORITHM:
206223
if not oauth2_settings.OIDC_RSA_PRIVATE_KEY:
207224
raise ValidationError(_("You must set OIDC_RSA_PRIVATE_KEY to use RSA algorithm"))
@@ -777,3 +794,20 @@ def redirect_to_uri_allowed(uri, allowed_uris):
777794
return True
778795

779796
return False
797+
798+
799+
def is_origin_allowed(origin, allowed_origins):
800+
"""
801+
Checks if a given origin uri is allowed based on the provided allowed_origins configuration.
802+
803+
:param origin: Origin URI to check
804+
:param allowed_origins: A list of Origin URIs that are allowed
805+
"""
806+
807+
parsed_origin = urlparse(origin)
808+
for allowed_origin in allowed_origins:
809+
parsed_allowed_origin = urlparse(allowed_origin)
810+
if (parsed_allowed_origin.scheme == parsed_origin.scheme
811+
and parsed_allowed_origin.netloc == parsed_origin.netloc):
812+
return True
813+
return False

oauth2_provider/oauth2_backends.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def extract_headers(self, request):
7575
del headers["wsgi.errors"]
7676
if "HTTP_AUTHORIZATION" in headers:
7777
headers["Authorization"] = headers["HTTP_AUTHORIZATION"]
78+
# Add Access-Control-Allow-Origin header to the token endpoint response for authentication code grant, if the origin is allowed by RequestValidator.is_origin_allowed.
79+
# https://github.com/oauthlib/oauthlib/pull/791
7880
if "HTTP_ORIGIN" in headers:
7981
headers["Origin"] = headers["HTTP_ORIGIN"]
8082
if request.is_secure():

oauth2_provider/oauth2_validators.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,3 +958,12 @@ def get_userinfo_claims(self, request):
958958

959959
def get_additional_claims(self, request):
960960
return {}
961+
962+
def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
963+
if request.client is None or not request.client.client_id:
964+
return False
965+
application = Application.objects.filter(client_id=request.client.client_id).first()
966+
if application:
967+
return application.origin_allowed(origin)
968+
else:
969+
return False

oauth2_provider/views/application.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def get_form_class(self):
3939
"authorization_grant_type",
4040
"redirect_uris",
4141
"post_logout_redirect_uris",
42+
"allowed_origins",
4243
"algorithm",
4344
),
4445
)
@@ -99,6 +100,7 @@ def get_form_class(self):
99100
"authorization_grant_type",
100101
"redirect_uris",
101102
"post_logout_redirect_uris",
103+
"allowed_origins",
102104
"algorithm",
103105
),
104106
)

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def application():
108108
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
109109
algorithm=Application.RS256_ALGORITHM,
110110
client_secret=CLEARTEXT_SECRET,
111+
allowed_origins="https://example.com",
111112
)
112113

113114

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Generated by Django 4.1.5 on 2023-09-27 22:25
2+
3+
from django.conf import settings
4+
from django.db import migrations, models
5+
import django.db.models.deletion
6+
7+
8+
class Migration(migrations.Migration):
9+
10+
dependencies = [
11+
migrations.swappable_dependency(settings.OAUTH2_PROVIDER_ID_TOKEN_MODEL),
12+
("tests", "0004_basetestapplication_hash_client_secret_and_more"),
13+
]
14+
15+
operations = [
16+
migrations.AddField(
17+
model_name="basetestapplication",
18+
name="allowed_origins",
19+
field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"),
20+
),
21+
migrations.AddField(
22+
model_name="sampleapplication",
23+
name="allowed_origins",
24+
field=models.TextField(blank=True, help_text="Allowed origins list to enable CORS, space separated"),
25+
),
26+
]

tests/test_cors.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from urllib.parse import parse_qs, urlparse
23

34
import pytest
@@ -6,18 +7,11 @@
67
from django.urls import reverse
78

89
from oauth2_provider.models import get_application_model
9-
from oauth2_provider.oauth2_validators import OAuth2Validator
1010

1111
from . import presets
1212
from .utils import get_basic_auth_header
1313

1414

15-
class CorsOAuth2Validator(OAuth2Validator):
16-
def is_origin_allowed(self, client_id, origin, request, *args, **kwargs):
17-
"""Enable CORS in OAuthLib"""
18-
return True
19-
20-
2115
Application = get_application_model()
2216
UserModel = get_user_model()
2317

@@ -50,10 +44,10 @@ def setUp(self):
5044
client_type=Application.CLIENT_CONFIDENTIAL,
5145
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
5246
client_secret=CLEARTEXT_SECRET,
47+
allowed_origins=CLIENT_URI,
5348
)
5449

5550
self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https"]
56-
self.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CorsOAuth2Validator
5751

5852
def tearDown(self):
5953
self.application.delete()
@@ -76,10 +70,42 @@ def test_cors_header(self):
7670
auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET)
7771
auth_headers["HTTP_ORIGIN"] = CLIENT_URI
7872
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
73+
74+
content = json.loads(response.content.decode("utf-8"))
75+
76+
self.assertEqual(response.status_code, 200)
77+
self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI)
78+
79+
token_request_data = {
80+
"grant_type": "refresh_token",
81+
"refresh_token": content["refresh_token"],
82+
"scope": content["scope"],
83+
}
84+
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
7985
self.assertEqual(response.status_code, 200)
8086
self.assertEqual(response["Access-Control-Allow-Origin"], CLIENT_URI)
8187

82-
def test_no_cors_header(self):
88+
def test_no_cors_header_origin_not_allowed(self):
89+
"""
90+
Test that /token endpoint does not have Access-Control-Allow-Origin
91+
when request origin is not in Application.allowed_origins
92+
"""
93+
authorization_code = self._get_authorization_code()
94+
95+
# exchange authorization code for a valid access token
96+
token_request_data = {
97+
"grant_type": "authorization_code",
98+
"code": authorization_code,
99+
"redirect_uri": CLIENT_URI,
100+
}
101+
102+
auth_headers = get_basic_auth_header(self.application.client_id, CLEARTEXT_SECRET)
103+
auth_headers["HTTP_ORIGIN"] = "another_example.org"
104+
response = self.client.post(reverse("oauth2_provider:token"), data=token_request_data, **auth_headers)
105+
self.assertEqual(response.status_code, 200)
106+
self.assertFalse(response.has_header("Access-Control-Allow-Origin"))
107+
108+
def test_no_cors_header_no_origin(self):
83109
"""
84110
Test that /token endpoint does not have Access-Control-Allow-Origin
85111
"""

0 commit comments

Comments
 (0)