Skip to content

Commit 40efb47

Browse files
chore: add helper to sanitize pg identifiers (#1295)
For dispatcherd we need to sanitize the channel name as we already do for event streams. There is no need to fill the code with `.replace()` calls. Instead create a helper that can sanitize the name in a more reusable and reliable way. --------- Signed-off-by: Alex <[email protected]>
1 parent 5be5d6b commit 40efb47

File tree

3 files changed

+81
-4
lines changed

3 files changed

+81
-4
lines changed

src/aap_eda/core/models/event_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from django.db import models
1818

19+
from aap_eda.utils import sanitize_postgres_identifier
20+
1921
from .base import BaseOrgModel, PrimordialModel, UniqueNamedModel
2022

2123
__all__ = "EventStream"
@@ -92,9 +94,7 @@ class EventStream(BaseOrgModel, UniqueNamedModel, PrimordialModel):
9294

9395
def _get_channel_name(self) -> str:
9496
"""Generate the channel name based on the UUID and prefix."""
95-
return (
96-
f"{EDA_EVENT_STREAM_CHANNEL_PREFIX}"
97-
f"{str(self.uuid).replace('-','_')}"
98-
)
97+
channel_name = f"{EDA_EVENT_STREAM_CHANNEL_PREFIX}{str(self.uuid)}"
98+
return sanitize_postgres_identifier(channel_name)
9999

100100
channel_name = property(_get_channel_name)

src/aap_eda/utils/__init__.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414
import importlib.metadata
1515
import logging
16+
import re
17+
from functools import cache
1618

1719
logger = logging.getLogger(__name__)
1820

@@ -32,3 +34,39 @@ def get_package_version(package_name: str) -> str:
3234
package_name,
3335
)
3436
return "unknown"
37+
38+
39+
@cache
40+
def sanitize_postgres_identifier(identifier: str) -> str:
41+
"""
42+
Sanitize an input string to conform to PostgreSQL identifier rules.
43+
44+
Initially intended to be used for pg_notify channel names.
45+
"""
46+
max_identifier_length = 63
47+
if not identifier:
48+
raise ValueError("Identifier cannot be empty.")
49+
50+
if len(identifier) > max_identifier_length:
51+
raise ValueError(
52+
f"Identifier exceeds {max_identifier_length} characters."
53+
)
54+
55+
sanitized = identifier
56+
57+
# Ensure it starts with a valid character
58+
if not re.match(r"[A-Za-z_]", identifier[0]):
59+
if len(identifier) == max_identifier_length:
60+
raise ValueError(
61+
f"Identifier has invalid first character and "
62+
f"{max_identifier_length} characters. "
63+
"It can not be sanitized to a valid "
64+
"PostgreSQL identifier below "
65+
f"{max_identifier_length} characters."
66+
)
67+
sanitized = f"_{identifier}"
68+
69+
# Replace invalid characters with underscores
70+
sanitized = re.sub(r"\W", "_", sanitized)
71+
72+
return sanitized

tests/unit/test_utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from aap_eda.utils import (
2626
get_package_version,
2727
logger as utils_logger,
28+
sanitize_postgres_identifier,
2829
str_to_bool,
2930
)
3031
from aap_eda.utils.openapi import generate_query_params
@@ -174,3 +175,41 @@ def test_generate_query_params(serializer, expected_params):
174175

175176
for param, expected in zip(query_params, expected_params):
176177
assert param.__dict__ == expected.__dict__
178+
179+
180+
@pytest.mark.parametrize(
181+
"input_str,expected_output",
182+
[
183+
("valid_name", "valid_name"),
184+
("Valid123", "Valid123"),
185+
("_underscore", "_underscore"),
186+
("hello-world", "hello_world"),
187+
("bad@name!", "bad_name_"),
188+
("some space", "some_space"),
189+
("123name", "_123name"),
190+
("9abc", "_9abc"),
191+
("@@@", "____"),
192+
("123", "_123"),
193+
("a" * 63, "a" * 63),
194+
(("1" + "a" * 61), ("_1" + "a" * 61)),
195+
],
196+
)
197+
def test_sanitize_postgres_identifier_valid_cases(input_str, expected_output):
198+
assert sanitize_postgres_identifier(input_str) == expected_output
199+
200+
201+
def test_empty_identifier_raises():
202+
with pytest.raises(ValueError, match="Identifier cannot be empty."):
203+
sanitize_postgres_identifier("")
204+
205+
206+
def test_identifier_exceeding_length_limit_raises():
207+
too_long = "a" * 64
208+
with pytest.raises(ValueError, match="exceeds 63 characters"):
209+
sanitize_postgres_identifier(too_long)
210+
211+
212+
def test_identifier_with_invalid_first_character_raises():
213+
too_long = "1" + "a" * 62
214+
with pytest.raises(ValueError, match="invalid first character"):
215+
sanitize_postgres_identifier(too_long)

0 commit comments

Comments
 (0)