diff --git a/dispatcherd/brokers/pg_notify.py b/dispatcherd/brokers/pg_notify.py
index 488576f4..233bddb9 100644
--- a/dispatcherd/brokers/pg_notify.py
+++ b/dispatcherd/brokers/pg_notify.py
@@ -37,6 +37,28 @@ def create_connection(**config) -> psycopg.Connection: # type: ignore[no-untype
     return connection
 
 
+class DispatcherdInvalidChannel(psycopg.errors.SyntaxError):
+    """Raised by validate_channel_name when a channel name can not be used."""
+
+
+def validate_channel_name(channel_name: str) -> None:
+    """Raise an exception if the channel name cannot be reliably used.
+
+    The notify logic uses the channel name as a parameter, which is good for
+    security reasons, but imposes a length constraint. The length is measured
+    on the UTF-8 encoding, so multi-byte characters count more than once.
+
+    Raises:
+        DispatcherdInvalidChannel: if the name is blank or encodes to more than 63 bytes.
+    """
+    if not channel_name:
+        raise DispatcherdInvalidChannel('Received blank channel name. PG notify channel name can not be blank.')
+
+    encoded_length = len(channel_name.encode('utf-8'))
+    if encoded_length > 63:
+        raise DispatcherdInvalidChannel(f'Channel name is too long bytes={encoded_length}')
+
+
 class Broker(BrokerProtocol):
     NOTIFY_QUERY_TEMPLATE = 'SELECT pg_notify(%s, %s);'
 
@@ -105,6 +127,10 @@ def __init__(
         server_channels.append(self.self_check_channel)
         self.channels = server_channels
 
+        # Raise an early error if any of the channel names are invalid
+        for channel in self.channels:
+            validate_channel_name(channel)
+
         self.default_publish_channel = default_publish_channel
         self.self_check_status = BrokerSelfCheckStatus.IDLE
         self.last_self_check_message_time = time.monotonic()
@@ -122,14 +148,17 @@ def generate_self_check_channel_name(cls) -> str:
     def get_publish_channel(self, channel: Optional[str] = None) -> str:
         "Handle default for the publishing channel for calls to publish_message, shared sync and async"
         if channel is not None:
-            return channel
+            return_channel = channel
         elif self.default_publish_channel is not None:
-            return self.default_publish_channel
+            return_channel = self.default_publish_channel
         elif len(self.user_channels) == 1:
             # de-facto default channel, because there is only 1
-            return self.channels[0]
+            return_channel = self.channels[0]
+        else:
+            raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')
 
-        raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')
+        validate_channel_name(return_channel)
+        return return_channel
 
     def __str__(self) -> str:
         return 'pg_notify-broker'
@@ -159,7 +188,9 @@ def get_listen_query(self, channel: str) -> psycopg.sql.Composed:
         This uses the psycopg utilities which ensure correct escaping so SQL injection is not possible.
         Return value is a valid argument for cursor.execute()
         """
-        return psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel))
+        # Postgres does not allow parameters for identifiers, so this limits what channel we can accept
+        validate_channel_name(channel)
+        return psycopg.sql.SQL("LISTEN {}").format(psycopg.sql.Identifier(channel))
 
     def get_unlisten_query(self) -> psycopg.sql.SQL:
         """Stops listening on all channels for current session, see pg_notify docs"""
diff --git a/tests/integration/brokers/test_pg_notify.py b/tests/integration/brokers/test_pg_notify.py
index d31cb98a..19194459 100644
--- a/tests/integration/brokers/test_pg_notify.py
+++ b/tests/integration/brokers/test_pg_notify.py
@@ -3,6 +3,8 @@
 
 import pytest
 
+import psycopg
+
 from dispatcherd.brokers.pg_notify import Broker, acreate_connection, create_connection
 
 
@@ -106,3 +108,103 @@ async def test_async_connection_from_config_reuse(conn_config):
 
     assert conn is conn2
     assert conn is not await acreate_connection(**conn_config)
+
+
+VALID_CHANNEL_NAMES = [
+    'foobar',
+    'foobar🔥',
+    'foo-bar',
+    '-foo-bar',
+    'a' * 63,  # exactly at the limit of 63
+]
+
+
+BAD_CHANNEL_NAMES = [
+    'a' + '🔥' * 22,  # under 64 characters, but over the limit once encoded as UTF-8
+    'a' * 64,  # over the limit of 63
+    'a' * 120,
+    '',  # blank
+]
+
+
+class TestChannelSanitizationPostgresSanity:
+    """These do not test dispatcherd itself, but give a reference by testing psycopg and postgres
+
+    These tests validate that the valid and bad channel name lists are, in fact, valid and bad.
+    """
+
+    @pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
+    def test_psycopg_valid_sanity_check(self, channel_name, conn_config):
+        """Sanity check that postgres itself will accept valid names for listening"""
+        with psycopg.connect(**conn_config, autocommit=True) as conn:
+            conn.execute(psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel_name)))
+            conn.execute(Broker.NOTIFY_QUERY_TEMPLATE, (channel_name, 'foo'))
+
+    @pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
+    def test_psycopg_error_sanity_check(self, channel_name, conn_config):
+        """Sanity check that postgres itself will raise an error for the known invalid names"""
+        with psycopg.connect(**conn_config, autocommit=True) as conn:
+            with pytest.raises(psycopg.DatabaseError):
+                conn.execute(psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel_name)))
+                conn.execute(Broker.NOTIFY_QUERY_TEMPLATE, (channel_name, 'foo'))
+
+    @pytest.fixture
+    def can_receive_notification(self, conn_config):
+        """Return a callable telling whether a message sent on a channel is received back on it"""
+
+        def _rf(channel_name):
+            # The context manager closes the connection, so repeated calls do not leak sessions
+            with psycopg.connect(**conn_config, autocommit=True) as conn:
+                try:
+                    conn.execute(psycopg.sql.SQL("LISTEN {};").format(psycopg.sql.Identifier(channel_name)))
+                    conn.execute(Broker.NOTIFY_QUERY_TEMPLATE, (channel_name, 'this is a test message'))
+                except psycopg.Error:
+                    return False  # did not work
+                gen = conn.notifies(timeout=0.001)
+                try:
+                    for notify in gen:
+                        assert notify.payload == 'this is a test message'
+                        return True
+                    return False  # timed out without receiving anything
+                finally:
+                    gen.close()
+
+        return _rf
+
+    @pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
+    def test_can_receive_over_valid_channels(self, can_receive_notification, channel_name):
+        assert can_receive_notification(channel_name)
+
+    @pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
+    def test_can_not_receive_over_invalid_channels(self, can_receive_notification, channel_name):
+        assert not can_receive_notification(channel_name)
+
+
+class TestChannelSanitizationPostgres:
+    """These tests verify that we do early validation
+
+    Specifically, this means that dispatcherd will not let you listen to a channel you can not send to
+    and that you can not send to a channel you can not listen to
+    """
+
+    @pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
+    def test_valid_channel_publish(self, channel_name, conn_config):
+        broker = Broker(config=conn_config)
+        broker.publish_message(channel=channel_name, message='foobar')
+
+    @pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
+    def test_invalid_channel_publish(self, channel_name, conn_config):
+        broker = Broker(config=conn_config)
+        with pytest.raises(psycopg.DatabaseError):
+            broker.publish_message(channel=channel_name, message='foobar')
+
+    @pytest.mark.parametrize('channel_name', VALID_CHANNEL_NAMES)
+    def test_valid_channel_listen(self, channel_name, conn_config):
+        broker = Broker(config=conn_config, channels=[channel_name])
+        broker.process_notify(max_messages=0)
+
+    @pytest.mark.parametrize('channel_name', BAD_CHANNEL_NAMES)
+    def test_invalid_channel_listen(self, channel_name, conn_config):
+        with pytest.raises(psycopg.DatabaseError):
+            broker = Broker(config=conn_config, channels=[channel_name])
+            broker.process_notify(max_messages=0)