Skip to content

Commit f951958

Browse files
committed
add test to check changelog does not contain BEGIN; or COMMIT;
1 parent 4a787ae commit f951958

File tree

4 files changed

+39
-1
lines changed

4 files changed

+39
-1
lines changed

pum/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ class PumSqlException(Exception):
2525

2626
pass
2727

28+
class PumInvalidChangelog(Exception):
29+
"""Exception raised for invalid changelog."""
30+
31+
pass
32+
2833

2934
class PumConfigError(PumException):
3035
"""Exception raised for errors in the PUM configuration."""

pum/utils/execute_sql.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from psycopg import Connection, Cursor
55
from psycopg.errors import SyntaxError
66

7-
from ..exceptions import PumSqlException
7+
from ..exceptions import PumSqlException, PumInvalidChangelog
88
import re
99

1010
logger = logging.getLogger(__name__)
@@ -42,6 +42,13 @@ def remove_sql_comments(sql):
4242
sql = re.sub(r'(?m)(^|;)\s*--.*?(\r\n|\r|\n)', r'\1', sql)
4343
return sql
4444
sql_content = remove_sql_comments(sql_content)
45+
46+
# Check for forbidden transaction statements
47+
forbidden_statements = ['BEGIN;', 'COMMIT;']
48+
for forbidden in forbidden_statements:
49+
if re.search(rf'\b{forbidden[:-1]}\b\s*;', sql_content, re.IGNORECASE):
50+
raise PumInvalidChangelog(f"SQL contains forbidden transaction statement: {forbidden}")
51+
4552
if parameters:
4653
for key, value in parameters.items():
4754
sql_content = sql_content.replace(f"{{{{ {key} }}}}", str(value))
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
BEGIN;
2+
3+
CREATE SCHEMA IF NOT EXISTS pum_test_data;
4+
5+
CREATE TABLE pum_test_data.some_table (
6+
id INT PRIMARY KEY,
7+
name VARCHAR(100) NOT NULL,
8+
created_date DATE DEFAULT CURRENT_DATE,
9+
is_active BOOLEAN DEFAULT TRUE,
10+
amount NUMERIC(10,2)
11+
);
12+
13+
COMMIT;

test/test_upgrader.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def test_install_complex_files_content(self):
114114
)
115115
upgrader.install()
116116
self.assertTrue(sm.exists(self.conn))
117+
117118
def test_install_multiple_changelogs(self):
118119
cfg = PumConfig()
119120
sm = SchemaMigrations(cfg)
@@ -139,6 +140,18 @@ def test_install_multiple_changelogs(self):
139140
],
140141
)
141142

143+
def test_invalid_changelog(self):
144+
cfg = PumConfig()
145+
sm = SchemaMigrations(cfg)
146+
self.assertFalse(sm.exists(self.conn))
147+
upgrader = Upgrader(
148+
pg_service=self.pg_service,
149+
config=cfg,
150+
dir=str(Path("test") / "data" / "invalid_changelog"),
151+
)
152+
with self.assertRaises(Exception) as context:
153+
upgrader.install()
154+
self.assertTrue("SQL contains forbidden transaction statement: BEGIN;" in str(context.exception))
142155

143156
if __name__ == "__main__":
144157
unittest.main()

0 commit comments

Comments
 (0)