diff --git a/channels/security/websocket.py b/channels/security/websocket.py index 15bc4fa96..0d14b4b9c 100644 --- a/channels/security/websocket.py +++ b/channels/security/websocket.py @@ -140,7 +140,7 @@ def AllowedHostsOriginValidator(application): """ allowed_hosts = settings.ALLOWED_HOSTS if settings.DEBUG and not allowed_hosts: - allowed_hosts = ["localhost", "127.0.0.1", "[::1]"] + allowed_hosts = [".localhost", "127.0.0.1", "[::1]"] return OriginValidator(application, allowed_hosts) diff --git a/tests/security/test_websocket.py b/tests/security/test_websocket.py index 1444ea829..93b1eab18 100644 --- a/tests/security/test_websocket.py +++ b/tests/security/test_websocket.py @@ -1,10 +1,42 @@ import pytest +from django.test import override_settings from channels.generic.websocket import AsyncWebsocketConsumer -from channels.security.websocket import OriginValidator +from channels.security.websocket import ( + AllowedHostsOriginValidator, + OriginValidator, +) from channels.testing import WebsocketCommunicator +@pytest.mark.django_db(transaction=True) +@pytest.mark.asyncio +async def test_allowed_hosts_origin_validator(): + """ + Tests that AllowedHostsOriginValidator correctly allows/denies connections. + """ + with override_settings( + DEBUG=True, + ALLOWED_HOSTS=[], + ): + # Make our test application + application = AllowedHostsOriginValidator(AsyncWebsocketConsumer()) + # Test a subdomain of localhost + communicator = WebsocketCommunicator( + application, "/", headers=[(b"origin", b"http://subdomain.localhost:8000")] + ) + connected, _ = await communicator.connect() + assert connected + await communicator.disconnect() + # Test a bad connection + communicator = WebsocketCommunicator( + application, "/", headers=[(b"origin", b"http://bad-domain.com")] + ) + connected, _ = await communicator.connect() + assert not connected + await communicator.disconnect() + + @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio async def test_origin_validator():