Skip to content
48 changes: 21 additions & 27 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,38 +144,32 @@ def match_type_and_length(self, name):
invalid_name_error = (
"{} name must be a valid unicode string "
+ "with length < {} ".format(MAX_NAME_LENGTH)
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods, "
+ "not {}"
+ "containing only ASCII alphanumerics, hyphens, underscores, or periods."
)

def valid_channel_name(self, name, receive=False):
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
if "!" in name and not name.endswith("!") and receive:
raise TypeError(
"Specific channel names in receive() must end at the !"
)
return True
raise TypeError(self.invalid_name_error.format("Channel", name))
def require_valid_channel_name(self, name, receive=False):
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Channel"))
if not bool(self.channel_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Channel"))
if "!" in name and not name.endswith("!") and receive:
raise TypeError("Specific channel names in receive() must end at the !")
return True

def require_valid_group_name(self, name):
if len(name) >= self.MAX_NAME_LENGTH:
raise TypeError(
f"Group name must be less than {self.MAX_NAME_LENGTH} characters."
)
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(self.invalid_name_error.format("Group", name))
if not self.match_type_and_length(name):
raise TypeError(self.invalid_name_error.format("Group"))
if not bool(self.group_name_regex.match(name)):
raise TypeError(self.invalid_name_error.format("Group"))
return True

def valid_channel_names(self, names, receive=False):
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"

assert all(
self.valid_channel_name(channel, receive=receive) for channel in names
all(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The all doesn't have semantic value, can you refactor this as a normal for loop?

self.require_valid_channel_name(channel, receive=receive)
for channel in names
)
return True

Expand Down Expand Up @@ -247,7 +241,7 @@ async def send(self, channel, message):
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_channel_name(channel)
# If it's a process-local channel, strip off local part and stick full
# name in message
assert "__asgi_channel__" not in message
Expand All @@ -267,7 +261,7 @@ async def receive(self, channel):
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self.require_valid_channel_name(channel)
self._clean_expired()

queue = self.channels.setdefault(
Expand Down Expand Up @@ -346,14 +340,14 @@ async def group_add(self, group, channel):
"""
# Check the inputs
self.require_valid_group_name(group)
self.valid_channel_name(channel), "Channel name not valid"
self.require_valid_channel_name(channel)
# Add to group dict
self.groups.setdefault(group, {})
self.groups[group][channel] = time.time()

async def group_discard(self, group, channel):
# Both should be text and valid
assert self.valid_channel_name(channel), "Invalid channel name"
self.require_valid_channel_name(channel)
self.require_valid_group_name(group)
# Remove from group set
group_channels = self.groups.get(group, None)
Expand Down
33 changes: 24 additions & 9 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async def test_send_receive():
@pytest.mark.parametrize(
"method",
[
BaseChannelLayer().valid_channel_name,
BaseChannelLayer().require_valid_channel_name,
BaseChannelLayer().require_valid_group_name,
],
)
Expand All @@ -90,21 +90,36 @@ def test_channel_and_group_name_validation(method, channel_name, expected_valid)


@pytest.mark.parametrize(
"name, expected_error_message",
"name",
[
(
"a" * 101,
f"Group name must be less than {BaseChannelLayer.MAX_NAME_LENGTH} "
"characters.",
), # Group name too long
"a" * 101, # Group name too long
],
)
def test_group_name_length_error_message(name, expected_error_message):
def test_group_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit.
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Group")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_group_name(name)


@pytest.mark.parametrize(
"name",
[
"a" * 101, # Channel name too long
],
)
def test_channel_name_length_error_message(name):
"""
Ensure the correct error message is raised when group names
exceed the character limit or contain invalid characters.
"""
layer = BaseChannelLayer()
expected_error_message = layer.invalid_name_error.format("Channel")

with pytest.raises(TypeError, match=expected_error_message):
layer.require_valid_channel_name(name)