diff --git a/docs/settings.rst b/docs/settings.rst index f31aff533..a7cac94a1 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -63,6 +63,17 @@ assigned ports. Note that you may override ``Application.get_allowed_schemes()`` to set this on a per-application basis. +ALLOWED_SCHEMES +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Default: ``["https"]`` + +A list of schemes that the ``allowed_origins`` field will be validated against. +Setting this to ``["https"]`` only in production is strongly recommended. +Adding ``"http"`` to the list is considered to be safe only for local development and testing. +Note that `OAUTHLIB_INSECURE_TRANSPORT `_ +environment variable should be also set to allow http origins. + APPLICATION_MODEL ~~~~~~~~~~~~~~~~~ diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index c37057e49..80d8f3487 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -20,7 +20,7 @@ from .scopes import get_scopes_backend from .settings import oauth2_settings from .utils import jwk_from_pem -from .validators import RedirectURIValidator, URIValidator, WildcardSet +from .validators import AllowedURIValidator, RedirectURIValidator, WildcardSet logger = logging.getLogger(__name__) @@ -218,7 +218,7 @@ def clean(self): allowed_origins = self.allowed_origins.strip().split() if allowed_origins: # oauthlib allows only https scheme for CORS - validator = URIValidator({"https"}) + validator = AllowedURIValidator(oauth2_settings.ALLOWED_SCHEMES, "Origin") for uri in allowed_origins: validator(uri) @@ -808,6 +808,10 @@ def is_origin_allowed(origin, allowed_origins): """ parsed_origin = urlparse(origin) + + if parsed_origin.scheme not in oauth2_settings.ALLOWED_SCHEMES: + return False + for allowed_origin in allowed_origins: parsed_allowed_origin = urlparse(allowed_origin) if ( diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index aa7de7351..c5af9ebae 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -68,6 +68,7 @@ "REFRESH_TOKEN_ADMIN_CLASS": "oauth2_provider.admin.RefreshTokenAdmin", "REQUEST_APPROVAL_PROMPT": "force", "ALLOWED_REDIRECT_URI_SCHEMES": ["http", "https"], + "ALLOWED_SCHEMES": ["https"], "OIDC_ENABLED": False, "OIDC_ISS_ENDPOINT": "", "OIDC_USERINFO_ENDPOINT": "", diff --git a/oauth2_provider/validators.py b/oauth2_provider/validators.py index 6c8fa3839..df3d9e753 100644 --- a/oauth2_provider/validators.py +++ b/oauth2_provider/validators.py @@ -31,6 +31,33 @@ def __call__(self, value): raise ValidationError("Redirect URIs must not contain fragments") +class AllowedURIValidator(URIValidator): + def __init__(self, schemes, name, allow_path=False, allow_query=False, allow_fragments=False): + """ + :param schemes: List of allowed schemes. E.g.: ["https"] + :param name: Name of the validated URI. It is required for validation message. E.g.: "Origin" + :param allow_path: If URI can contain path part + :param allow_query: If URI can contain query part + :param allow_fragments: If URI can contain fragments part + """ + super().__init__(schemes=schemes) + self.name = name + self.allow_path = allow_path + self.allow_query = allow_query + self.allow_fragments = allow_fragments + + def __call__(self, value): + super().__call__(value) + value = force_str(value) + scheme, netloc, path, query, fragment = urlsplit(value) + if query and not self.allow_query: + raise ValidationError("{} URIs must not contain query".format(self.name)) + if fragment and not self.allow_fragments: + raise ValidationError("{} URIs must not contain fragments".format(self.name)) + if path and not self.allow_path: + raise ValidationError("{} URIs must not contain path".format(self.name)) + + ## # WildcardSet is a special set that contains everything. # This is required in order to move validation of the scheme from diff --git a/tests/conftest.py b/tests/conftest.py index d620c3f59..eff48f7fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,6 +124,18 @@ def public_application(): ) +@pytest.fixture +def cors_application(): + return Application.objects.create( + name="Test CORS Application", + client_type=Application.CLIENT_CONFIDENTIAL, + authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE, + algorithm=Application.RS256_ALGORITHM, + client_secret=CLEARTEXT_SECRET, + allowed_origins="https://example.com http://example.com", + ) + + @pytest.fixture def logged_in_client(test_user): from django.test.client import Client diff --git a/tests/presets.py b/tests/presets.py index 1ac8d3279..4538c64eb 100644 --- a/tests/presets.py +++ b/tests/presets.py @@ -57,3 +57,11 @@ "READ_SCOPE": "read", "WRITE_SCOPE": "write", } + +ALLOWED_SCHEMES_DEFAULT = { + "ALLOWED_SCHEMES": ["https"], +} + +ALLOWED_SCHEMES_HTTP = { + "ALLOWED_SCHEMES": ["https", "http"], +} diff --git a/tests/test_models.py b/tests/test_models.py index 4de823b8d..8c62e2c99 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -594,3 +594,19 @@ def test_application_clean(oauth2_settings, application): assert "Enter a valid URL" in str(exc.value) application.allowed_origins = "https://example.com" application.clean() + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) +def test_application_origin_allowed_default_https(oauth2_settings, cors_application): + """Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https""" + assert cors_application.origin_allowed("https://example.com") + assert not cors_application.origin_allowed("http://example.com") + + +@pytest.mark.django_db +@pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP) +def test_application_origin_allowed_http(oauth2_settings, cors_application): + """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" + assert cors_application.origin_allowed("https://example.com") + assert cors_application.origin_allowed("http://example.com") diff --git a/tests/test_validators.py b/tests/test_validators.py index 0760e0290..6cbc23172 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -2,7 +2,7 @@ from django.core.validators import ValidationError from django.test import TestCase -from oauth2_provider.validators import RedirectURIValidator +from oauth2_provider.validators import AllowedURIValidator, RedirectURIValidator @pytest.mark.usefixtures("oauth2_settings") @@ -36,6 +36,11 @@ def test_validate_custom_uri_scheme(self): # Check ValidationError not thrown validator(uri) + validator = AllowedURIValidator(["my-scheme", "https", "git+ssh"], "Origin") + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + def test_validate_bad_uris(self): validator = RedirectURIValidator(allowed_schemes=["https"]) self.oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES = ["https", "good"] @@ -61,3 +66,67 @@ def test_validate_bad_uris(self): for uri in bad_uris: with self.assertRaises(ValidationError): validator(uri) + + def test_validate_good_origin_uris(self): + """ + Test AllowedURIValidator validates origin URIs if they match requirements + """ + validator = AllowedURIValidator( + ["https"], + "Origin", + allow_path=False, + allow_query=False, + allow_fragments=False, + ) + good_uris = [ + "https://example.com", + "https://example.com:8080", + "https://example", + "https://localhost", + "https://1.1.1.1", + "https://127.0.0.1", + "https://255.255.255.255", + ] + for uri in good_uris: + # Check ValidationError not thrown + validator(uri) + + def test_validate_bad_origin_uris(self): + """ + Test AllowedURIValidator rejects origin URIs if they do not match requirements + """ + validator = AllowedURIValidator( + ["https"], + "Origin", + allow_path=False, + allow_query=False, + allow_fragments=False, + ) + bad_uris = [ + "http:/example.com", + "HTTP://localhost", + "HTTP://example.com", + "HTTP://example.com.", + "http://example.com/#fragment", + "123://example.com", + "http://fe80::1", + "git+ssh://example.com", + "my-scheme://example.com", + "uri-without-a-scheme", + "https://example.com/#fragment", + "good://example.com/#fragment", + " ", + "", + # Bad IPv6 URL, urlparse behaves differently for these + 'https://[">', + # Origin uri should not contain path, query of fragment parts + # https://www.rfc-editor.org/rfc/rfc6454#section-7.1 + "https://example.com/", + "https://example.com/test", + "https://example.com/?q=test", + "https://example.com/#test", + ] + + for uri in bad_uris: + with self.assertRaises(ValidationError): + validator(uri)