Skip to content

Commit 4c13679

Browse files
authored
fix: RedirectURIValidator Encapsulation (#1345)
1 parent 584627d commit 4c13679

File tree

6 files changed

+225
-54
lines changed

6 files changed

+225
-54
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2929
* #1322 Instructions in documentation on how to create a code challenge and code verifier
3030
* #1284 Allow to logout with no id_token_hint even if the browser session already expired
3131
* #1296 Added reverse function in migration 0006_alter_application_client_secret
32+
* #1336 Fix encapsulation for Redirect URI scheme validation
3233

3334
## [2.3.0] 2023-05-31
3435

oauth2_provider/models.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .scopes import get_scopes_backend
2121
from .settings import oauth2_settings
2222
from .utils import jwk_from_pem
23-
from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet
23+
from .validators import AllowedURIValidator
2424

2525

2626
logger = logging.getLogger(__name__)
@@ -202,12 +202,11 @@ def clean(self):
202202
allowed_schemes = set(s.lower() for s in self.get_allowed_schemes())
203203

204204
if redirect_uris:
205-
validator = RedirectURIValidator(WildcardSet())
205+
validator = AllowedURIValidator(
206+
allowed_schemes, name="redirect uri", allow_path=True, allow_query=True
207+
)
206208
for uri in redirect_uris:
207209
validator(uri)
208-
scheme = urlparse(uri).scheme
209-
if scheme not in allowed_schemes:
210-
raise ValidationError(_("Unauthorized redirect scheme: {scheme}").format(scheme=scheme))
211210

212211
elif self.authorization_grant_type in grant_types:
213212
raise ValidationError(
@@ -218,7 +217,7 @@ def clean(self):
218217
allowed_origins = self.allowed_origins.strip().split()
219218
if allowed_origins:
220219
# oauthlib allows only https scheme for CORS
221-
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin")
220+
validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "allowed origin")
222221
for uri in allowed_origins:
223222
validator(uri)
224223

oauth2_provider/oauth2_validators.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,6 @@ def authenticate_client_id(self, client_id, request, *args, **kwargs):
305305
proceed only if the client exists and is not of type "Confidential".
306306
"""
307307
if self._load_application(client_id, request) is not None:
308-
log.debug("Application %r has type %r" % (client_id, request.client.client_type))
309308
return request.client.client_type != AbstractApplication.CLIENT_CONFIDENTIAL
310309
return False
311310

oauth2_provider/validators.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import re
2+
import warnings
23
from urllib.parse import urlsplit
34

45
from django.core.exceptions import ValidationError
@@ -20,6 +21,7 @@ class URIValidator(URLValidator):
2021

2122
class RedirectURIValidator(URIValidator):
2223
def __init__(self, allowed_schemes, allow_fragments=False):
24+
warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning)
2325
super().__init__(schemes=allowed_schemes)
2426
self.allow_fragments = allow_fragments
2527

@@ -32,6 +34,8 @@ def __call__(self, value):
3234

3335

3436
class AllowedURIValidator(URIValidator):
37+
# TODO: find a way to get these associated with their form fields in place of passing name
38+
# TODO: submit PR to get `cause` included in the parent class ValidationError params`
3539
def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False):
3640
"""
3741
:param schemes: List of allowed schemes. E.g.: ["https"]
@@ -47,15 +51,45 @@ def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fra
4751
self.allow_fragments = allow_fragments
4852

4953
def __call__(self, value):
50-
super().__call__(value)
5154
value = force_str(value)
52-
scheme, netloc, path, query, fragment = urlsplit(value)
55+
try:
56+
scheme, netloc, path, query, fragment = urlsplit(value)
57+
except ValueError as e:
58+
raise ValidationError(
59+
"%(name)s URI validation error. %(cause)s: %(value)s",
60+
params={"name": self.name, "value": value, "cause": e},
61+
)
62+
63+
# send better validation errors
64+
if scheme not in self.schemes:
65+
raise ValidationError(
66+
"%(name)s URI Validation error. %(cause)s: %(value)s",
67+
params={"name": self.name, "value": value, "cause": "invalid_scheme"},
68+
)
69+
5370
if query and not self.allow_query:
54-
raise ValidationError("{} URIs must not contain query".format(self.name))
71+
raise ValidationError(
72+
"%(name)s URI validation error. %(cause)s: %(value)s",
73+
params={"name": self.name, "value": value, "cause": "query string not allowed"},
74+
)
5575
if fragment and not self.allow_fragments:
56-
raise ValidationError("{} URIs must not contain fragments".format(self.name))
76+
raise ValidationError(
77+
"%(name)s URI validation error. %(cause)s: %(value)s",
78+
params={"name": self.name, "value": value, "cause": "fragment not allowed"},
79+
)
5780
if path and not self.allow_path:
58-
raise ValidationError("{} URIs must not contain path".format(self.name))
81+
raise ValidationError(
82+
"%(name)s URI validation error. %(cause)s: %(value)s",
83+
params={"name": self.name, "value": value, "cause": "path not allowed"},
84+
)
85+
86+
try:
87+
super().__call__(value)
88+
except ValidationError as e:
89+
raise ValidationError(
90+
"%(name)s URI validation error. %(cause)s: %(value)s",
91+
params={"name": self.name, "value": value, "cause": e},
92+
)
5993

6094

6195
##
@@ -69,5 +103,9 @@ class WildcardSet(set):
69103
A set that always returns True on `in`.
70104
"""
71105

106+
def __init__(self, *args, **kwargs):
107+
warnings.warn("This class is deprecated and will be removed in version 2.5.0.", DeprecationWarning)
108+
super().__init__(*args, **kwargs)
109+
72110
def __contains__(self, item):
73111
return True

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def test_application_clean(oauth2_settings, application):
591591
application.allowed_origins = "http://example.com"
592592
with pytest.raises(ValidationError) as exc:
593593
application.clean()
594-
assert "Enter a valid URL" in str(exc.value)
594+
assert "allowed origin URI Validation error. invalid_scheme: http://example.com" in str(exc.value)
595595
application.allowed_origins = "https://example.com"
596596
application.clean()
597597

tests/test_validators.py

Lines changed: 175 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from django.core.validators import ValidationError
33
from django.test import TestCase
44

5-
from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator
5+
from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator, WildcardSet
66

77

88
@pytest.mark.usefixtures("oauth2_settings")
@@ -36,11 +36,6 @@ def test_validate_custom_uri_scheme(self):
3636
# Check ValidationError not thrown
3737
validator(uri)
3838

39-
validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin")
40-
for uri in good_uris:
41-
# Check ValidationError not thrown
42-
validator(uri)
43-
4439
def test_validate_bad_uris(self):
4540
validator = RedirectURIValidator(allowed_schemes=["https"])
4641
self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"]
@@ -67,47 +62,73 @@ def test_validate_bad_uris(self):
6762
with self.assertRaises(ValidationError):
6863
validator(uri)
6964

70-
def test_validate_good_origin_uris(self):
71-
"""
72-
Test AllowedURIValidator validates origin URIs if they match requirements
73-
"""
74-
validator = AllowedURIValidator(
75-
["https"],
76-
"Origin",
77-
allow_path=False,
78-
allow_query=False,
79-
allow_fragments=False,
80-
)
65+
def test_validate_wildcard_scheme__bad_uris(self):
66+
validator = RedirectURIValidator(allowed_schemes=WildcardSet())
67+
bad_uris = [
68+
"http:/example.com#fragment",
69+
"HTTP://localhost#fragment",
70+
"http://example.com/#fragment",
71+
"good://example.com/#fragment",
72+
" ",
73+
"",
74+
# Bad IPv6 URL, urlparse behaves differently for these
75+
'https://["><script>alert()</script>',
76+
]
77+
78+
for uri in bad_uris:
79+
with self.assertRaises(ValidationError, msg=uri):
80+
validator(uri)
81+
82+
def test_validate_wildcard_scheme_good_uris(self):
83+
validator = RedirectURIValidator(allowed_schemes=WildcardSet())
8184
good_uris = [
85+
"my-scheme://example.com",
86+
"my-scheme://example",
87+
"my-scheme://localhost",
8288
"https://example.com",
83-
"https://example.com:8080",
84-
"https://example",
85-
"https://localhost",
86-
"https://1.1.1.1",
87-
"https://127.0.0.1",
88-
"https://255.255.255.255",
89+
"HTTPS://example.com",
90+
"HTTPS://example.com.",
91+
"git+ssh://example.com",
92+
"ANY://localhost",
93+
"scheme://example.com",
94+
"at://example.com",
95+
"all://example.com",
8996
]
9097
for uri in good_uris:
9198
# Check ValidationError not thrown
9299
validator(uri)
93100

94-
def test_validate_bad_origin_uris(self):
95-
"""
96-
Test AllowedURIValidator rejects origin URIs if they do not match requirements
97-
"""
98-
validator = AllowedURIValidator(
99-
["https"],
100-
"Origin",
101-
allow_path=False,
102-
allow_query=False,
103-
allow_fragments=False,
104-
)
101+
102+
@pytest.mark.usefixtures("oauth2_settings")
103+
class TestAllowedURIValidator(TestCase):
104+
# TODO: verify the specifics of the ValidationErrors
105+
def test_valid_schemes(self):
106+
validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "test")
107+
good_uris = [
108+
"my-scheme://example.com",
109+
"my-scheme://example",
110+
"my-scheme://localhost",
111+
"https://example.com",
112+
"HTTPS://example.com",
113+
"git+ssh://example.com",
114+
]
115+
for uri in good_uris:
116+
# Check ValidationError not thrown
117+
validator(uri)
118+
119+
def test_invalid_schemes(self):
120+
validator = AllowedURIValidator(["https"], "test")
105121
bad_uris = [
106122
"http:/example.com",
107123
"HTTP://localhost",
108124
"HTTP://example.com",
125+
"https://-exa", # triggers an exception in the upstream validators
126+
"HTTP://example.com/path",
127+
"HTTP://example.com/path?query=string",
128+
"HTTP://example.com/path?query=string#fragmemt",
109129
"HTTP://example.com.",
110-
"http://example.com/#fragment",
130+
"http://example.com/path/#fragment",
131+
"http://example.com?query=string#fragment",
111132
"123://example.com",
112133
"http://fe80::1",
113134
"git+ssh://example.com",
@@ -119,12 +140,125 @@ def test_validate_bad_origin_uris(self):
119140
"",
120141
# Bad IPv6 URL, urlparse behaves differently for these
121142
'https://["><script>alert()</script>',
122-
# Origin uri should not contain path, query of fragment parts
123-
# https://www.rfc-editor.org/rfc/rfc6454#section-7.1
124-
"https://example.com/",
125-
"https://example.com/test",
126-
"https://example.com/?q=test",
127-
"https://example.com/#test",
143+
]
144+
145+
for uri in bad_uris:
146+
with self.assertRaises(ValidationError):
147+
validator(uri)
148+
149+
def test_allow_paths_valid_urls(self):
150+
validator = AllowedURIValidator(["https", "myapp"], "test", allow_path=True)
151+
good_uris = [
152+
"https://example.com",
153+
"https://example.com:8080",
154+
"https://example",
155+
"https://example.com/path",
156+
"https://example.com:8080/path",
157+
"https://example/path",
158+
"https://localhost/path",
159+
"myapp://host/path",
160+
]
161+
for uri in good_uris:
162+
# Check ValidationError not thrown
163+
validator(uri)
164+
165+
def test_allow_paths_invalid_urls(self):
166+
validator = AllowedURIValidator(["https", "myapp"], "test", allow_path=True)
167+
bad_uris = [
168+
"https://example.com?query=string",
169+
"https://example.com#fragment",
170+
"https://example.com/path?query=string",
171+
"https://example.com/path#fragment",
172+
"https://example.com/path?query=string#fragment",
173+
"myapp://example.com/path?query=string",
174+
"myapp://example.com/path#fragment",
175+
"myapp://example.com/path?query=string#fragment",
176+
"bad://example.com/path",
177+
]
178+
179+
for uri in bad_uris:
180+
with self.assertRaises(ValidationError):
181+
validator(uri)
182+
183+
def test_allow_query_valid_urls(self):
184+
validator = AllowedURIValidator(["https", "myapp"], "test", allow_query=True)
185+
good_uris = [
186+
"https://example.com",
187+
"https://example.com:8080",
188+
"https://example.com?query=string",
189+
"https://example",
190+
"myapp://example.com?query=string",
191+
"myapp://example?query=string",
192+
]
193+
for uri in good_uris:
194+
# Check ValidationError not thrown
195+
validator(uri)
196+
197+
def test_allow_query_invalid_urls(self):
198+
validator = AllowedURIValidator(["https", "myapp"], "test", allow_query=True)
199+
bad_uris = [
200+
"https://example.com/path",
201+
"https://example.com#fragment",
202+
"https://example.com/path?query=string",
203+
"https://example.com/path#fragment",
204+
"https://example.com/path?query=string#fragment",
205+
"https://example.com:8080/path",
206+
"https://example/path",
207+
"https://localhost/path",
208+
"myapp://example.com/path?query=string",
209+
"myapp://example.com/path#fragment",
210+
"myapp://example.com/path?query=string#fragment",
211+
"bad://example.com/path",
212+
]
213+
214+
for uri in bad_uris:
215+
with self.assertRaises(ValidationError):
216+
validator(uri)
217+
218+
def test_allow_fragment_valid_urls(self):
219+
validator = AllowedURIValidator(["https", "myapp"], "test", allow_fragments=True)
220+
good_uris = [
221+
"https://example.com",
222+
"https://example.com#fragment",
223+
"https://example.com:8080",
224+
"https://example.com:8080#fragment",
225+
"https://example",
226+
"https://example#fragment",
227+
"myapp://example",
228+
"myapp://example#fragment",
229+
"myapp://example.com",
230+
"myapp://example.com#fragment",
231+
]
232+
for uri in good_uris:
233+
# Check ValidationError not thrown
234+
validator(uri)
235+
236+
def test_allow_fragment_invalid_urls(self):
237+
validator = AllowedURIValidator(["https", "myapp"], "test", allow_fragments=True)
238+
bad_uris = [
239+
"https://example.com?query=string",
240+
"https://example.com?query=string#fragment",
241+
"https://example.com/path",
242+
"https://example.com/path?query=string",
243+
"https://example.com/path#fragment",
244+
"https://example.com/path?query=string#fragment",
245+
"https://example.com:8080/path",
246+
"https://example?query=string",
247+
"https://example?query=string#fragment",
248+
"https://example/path",
249+
"https://example/path?query=string",
250+
"https://example/path#fragment",
251+
"https://example/path?query=string#fragment",
252+
"myapp://example?query=string",
253+
"myapp://example?query=string#fragment",
254+
"myapp://example/path",
255+
"myapp://example/path?query=string",
256+
"myapp://example/path#fragment",
257+
"myapp://example.com/path?query=string",
258+
"myapp://example.com/path#fragment",
259+
"myapp://example.com/path?query=string#fragment",
260+
"myapp://example.com?query=string",
261+
"bad://example.com",
128262
]
129263

130264
for uri in bad_uris:

0 commit comments

Comments
 (0)