Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions dispatcherd/brokers/pg_notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ def create_connection(**config) -> psycopg.Connection: # type: ignore[no-untype
return connection


class DispatcherdInvalidChannel(psycopg.errors.SyntaxError):
pass


def validate_channel_name(channel_name: str) -> None:
"""Raise an exception if the channel name can not be reliably used.

This might happen due to a number of reasons.
The notify logic uses the channel name as a parameter, which is good for
security reasons, but imposes a character length constraint.
"""
if len(channel_name.encode('utf-8')) > 63:
raise DispatcherdInvalidChannel(f'Channel name is too long chars={len(channel_name.encode("utf-8"))}')

if not channel_name:
raise DispatcherdInvalidChannel(f'Received blank channel name {channel_name}. PG notify channel name can not be blank.')


class Broker(BrokerProtocol):
NOTIFY_QUERY_TEMPLATE = 'SELECT pg_notify(%s, %s);'

Expand Down Expand Up @@ -105,6 +123,10 @@ def __init__(
server_channels.append(self.self_check_channel)
self.channels = server_channels

# Raise an early error if any of the channel names are invalid
for channel in self.channels:
validate_channel_name(channel)

self.default_publish_channel = default_publish_channel
self.self_check_status = BrokerSelfCheckStatus.IDLE
self.last_self_check_message_time = time.monotonic()
Expand All @@ -122,14 +144,17 @@ def generate_self_check_channel_name(cls) -> str:
def get_publish_channel(self, channel: Optional[str] = None) -> str:
"Handle default for the publishing channel for calls to publish_message, shared sync and async"
if channel is not None:
return channel
return_channel = channel
elif self.default_publish_channel is not None:
return self.default_publish_channel
return_channel = self.default_publish_channel
elif len(self.user_channels) == 1:
# de-facto default channel, because there is only 1
return self.channels[0]
return_channel = self.channels[0]
else:
raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')

raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')
validate_channel_name(return_channel)
return return_channel

def __str__(self) -> str:
return 'pg_notify-broker'
Expand Down Expand Up @@ -159,7 +184,9 @@ def get_listen_query(self, channel: str) -> psycopg.sql.Composed:
This uses the psycopg utilities which ensure correct escaping so SQL injection is not possible.
Return value is a valid argument for cursor.execute()
"""
return psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel))
# Postgres does not allow parameters for identifiers, so this limits what channel we can accept
validate_channel_name(channel)
return psycopg.sql.SQL("LISTEN {}").format(psycopg.sql.Identifier(channel))

def get_unlisten_query(self) -> psycopg.sql.SQL:
"""Stops listening on all channels for current session, see pg_notify docs"""
Expand Down
99 changes: 99 additions & 0 deletions tests/integration/brokers/test_pg_notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pytest

import psycopg

from dispatcherd.brokers.pg_notify import Broker, acreate_connection, create_connection


Expand Down Expand Up @@ -106,3 +108,100 @@ async def test_async_connection_from_config_reuse(conn_config):
assert conn is conn2

assert conn is not await acreate_connection(**conn_config)



VALID_CHANNEL_NAMES = [
'foobar',
'foobar🔥',
'foo-bar',
'-foo-bar',
'a' * 63 # just under the limit
]


BAD_CHANNEL_NAMES = [
'a' + '🔥' * 22, # under 64 but expanded unicode is over
'a' * 64, # over the limit of 63
'a' * 120,
''
]


class TestChannelSanitizationPostgresSanity:
"""These do not test dispatcherd itself, but give a reference by testing psycopg and postgres

These tests validate that the valid and bad channel name lists are, in fact, bad and valid.
"""
@pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
def test_psycopg_valid_sanity_check(self, channel_name, conn_config):
"""Sanity check that postgres itself will accept valid names for listening"""
conn = psycopg.connect(**conn_config, autocommit=True)
conn.execute(psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel_name)))
conn.execute(Broker.NOTIFY_QUERY_TEMPLATE, (channel_name, 'foo'))

@pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
def test_psycopg_error_sanity_check(self, channel_name, conn_config):
"""Sanity check that postgres itself will raise an error for the known invalid names"""
conn = psycopg.connect(**conn_config, autocommit=True)
with pytest.raises(psycopg.DatabaseError):
conn.execute(psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel_name)))
conn.execute(Broker.NOTIFY_QUERY_TEMPLATE, (channel_name, 'foo'))

@pytest.fixture
def can_receive_notification(self, conn_config):
def _rf(channel_name):
conn = psycopg.connect(**conn_config, autocommit=True)
try:
conn.execute(psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel_name)))
conn.execute(Broker.NOTIFY_QUERY_TEMPLATE, (channel_name, 'this is a test message'))
except Exception:
return False # did not work
gen = conn.notifies(timeout=0.001)
try:
for notify in gen:
assert notify.payload == 'this is a test message'
gen.close()
return True
else:
return False
finally:
gen.close()
return _rf

@pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
def test_can_receive_over_valid_channels(self, can_receive_notification, channel_name):
assert can_receive_notification(channel_name)

@pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
def test_can_not_receive_over_invalid_channels(self, can_receive_notification, channel_name):
assert not can_receive_notification(channel_name)


class TestChannelSanitizationPostgres:
"""These tests verify that we do early validation

Specifically, this means that dispatcherd will not let you listen to a channel you can not send to
and that you can not send to a channel you can not listen to"""

@pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
def test_valid_channel_publish(self, channel_name, conn_config):
broker = Broker(config=conn_config)
broker.publish_message(channel=channel_name, message='foobar')

@pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
def test_invalid_channel_publish(self, channel_name, conn_config):
broker = Broker(config=conn_config)
with pytest.raises(psycopg.DatabaseError):
broker.publish_message(channel=channel_name, message='foobar')

@pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
def test_valid_channel_listen(self, channel_name, conn_config):
broker = Broker(config=conn_config, channels=[channel_name])
broker.process_notify(max_messages=0)

@pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
def test_invalid_channel_listen(self, channel_name, conn_config):
with pytest.raises(psycopg.DatabaseError):
broker = Broker(config=conn_config, channels=[channel_name])
broker.process_notify(max_messages=0)