Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion channels/security/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def AllowedHostsOriginValidator(application):
"""
allowed_hosts = settings.ALLOWED_HOSTS
if settings.DEBUG and not allowed_hosts:
allowed_hosts = ["localhost", "127.0.0.1", "[::1]"]
allowed_hosts = [".localhost", "127.0.0.1", "[::1]"]
return OriginValidator(application, allowed_hosts)


Expand Down
34 changes: 33 additions & 1 deletion tests/security/test_websocket.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,42 @@
import pytest
from django.test import override_settings

from channels.generic.websocket import AsyncWebsocketConsumer
from channels.security.websocket import OriginValidator
from channels.security.websocket import (
AllowedHostsOriginValidator,
OriginValidator,
)
from channels.testing import WebsocketCommunicator


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_allowed_hosts_origin_validator():
"""
Tests that AllowedHostsOriginValidator correctly allows/denies connections.
"""
with override_settings(
DEBUG=True,
ALLOWED_HOSTS=[],
):
# Make our test application
application = AllowedHostsOriginValidator(AsyncWebsocketConsumer())
# Test a subdomain of localhost
communicator = WebsocketCommunicator(
application, "/", headers=[(b"origin", b"http://subdomain.localhost:8000")]
)
connected, _ = await communicator.connect()
assert connected
await communicator.disconnect()
# Test a bad connection
communicator = WebsocketCommunicator(
application, "/", headers=[(b"origin", b"http://bad-domain.com")]
)
connected, _ = await communicator.connect()
assert not connected
await communicator.disconnect()


@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
async def test_origin_validator():
Expand Down