Skip to content

Commit ec61ec2

Browse files
akanstantsinaudopry
authored andcommitted
Add more tests for origin validators
1 parent 45ea962 commit ec61ec2

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

oauth2_provider/validators.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def __call__(self, value):
3434
class AllowedURIValidator(URIValidator):
3535
def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False):
3636
"""
37-
:params schemes: List of allowed schemes. E.g.: ["https"]
38-
:params name: Name of the validater URI required for validation message. E.g.: "Origin"
39-
:params allow_path: If URI can contain path part
40-
:params allow_query: If URI can contain query part
41-
:params allow_fragments: If URI can contain fragments part
37+
:param schemes: List of allowed schemes. E.g.: ["https"]
38+
:param name: Name of the validated URI. It is required for validation message. E.g.: "Origin"
39+
:param allow_path: If URI can contain path part
40+
:param allow_query: If URI can contain query part
41+
:param allow_fragments: If URI can contain fragments part
4242
"""
4343
super().__init__(schemes=schemes)
4444
self.name = name
@@ -50,12 +50,12 @@ def __call__(self, value):
5050
super().__call__(value)
5151
value = force_str(value)
5252
scheme, netloc, path, query, fragment = urlsplit(value)
53-
if path and not self.allow_path:
54-
raise ValidationError("{} URIs must not contain path".format(self.name))
5553
if query and not self.allow_query:
5654
raise ValidationError("{} URIs must not contain query".format(self.name))
5755
if fragment and not self.allow_fragments:
5856
raise ValidationError("{} URIs must not contain fragments".format(self.name))
57+
if path and not self.allow_path:
58+
raise ValidationError("{} URIs must not contain path".format(self.name))
5959

6060

6161
##

tests/conftest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,18 @@ def public_application():
124124
)
125125

126126

127+
@pytest.fixture
128+
def cors_application():
129+
return Application.objects.create(
130+
name="Test CORS Application",
131+
client_type=Application.CLIENT_CONFIDENTIAL,
132+
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
133+
algorithm=Application.RS256_ALGORITHM,
134+
client_secret=CLEARTEXT_SECRET,
135+
allowed_origins="https://example.com http://example.com",
136+
)
137+
138+
127139
@pytest.fixture
128140
def logged_in_client(test_user):
129141
from django.test.client import Client

tests/presets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,11 @@
5757
"READ_SCOPE": "read",
5858
"WRITE_SCOPE": "write",
5959
}
60+
61+
ALLOWED_SCHEMES_DEFAULT = {
62+
"ALLOWED_SCHEMES": ["https"],
63+
}
64+
65+
ALLOWED_SCHEMES_HTTP = {
66+
"ALLOWED_SCHEMES": ["https", "http"],
67+
}

tests/test_models.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,3 +594,19 @@ def test_application_clean(oauth2_settings, application):
594594
assert "Enter a valid URL" in str(exc.value)
595595
application.allowed_origins = "https://example.com"
596596
application.clean()
597+
598+
599+
@pytest.mark.django_db
600+
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT)
601+
def test_application_origin_allowed_default_https(oauth2_settings, cors_application):
602+
"""Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https"""
603+
assert cors_application.origin_allowed("https://example.com")
604+
assert not cors_application.origin_allowed("http://example.com")
605+
606+
607+
@pytest.mark.django_db
608+
@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP)
609+
def test_application_origin_allowed_http(oauth2_settings, cors_application):
610+
"""Test that http schemes are allowed because http was added to ALLOWED_SCHEMES"""
611+
assert cors_application.origin_allowed("https://example.com")
612+
assert cors_application.origin_allowed("http://example.com")

0 commit comments

Comments
 (0)