Skip to content

Commit e0c3121

Browse files
committed
starting parameters
1 parent 905cab7 commit e0c3121

File tree

8 files changed

+188
-57
lines changed

8 files changed

+188
-57
lines changed

.github/dependabot.yml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
version: 2
2+
updates:
3+
- package-ecosystem: pip
4+
directory: "/requirements"
5+
schedule:
6+
interval: monthly
7+
time: "04:00"
8+
timezone: Europe/Paris
9+
10+
- package-ecosystem: "github-actions"
11+
directory: "/"
12+
schedule:
13+
interval: "monthly"

pum/schema_migrations.py

Lines changed: 98 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
in the database.
88
"""
99

10+
import json
1011
import logging
1112
import re
1213

@@ -54,11 +55,11 @@ def exists(self, conn: Connection) -> bool:
5455
query = sql.SQL(
5556
"""
5657
SELECT EXISTS (
57-
SELECT 1
58-
FROM information_schema.tables
59-
WHERE table_name = {table} AND table_schema = {schema}
60-
)
61-
"""
58+
SELECT 1
59+
FROM information_schema.tables
60+
WHERE table_name = {table} AND table_schema = {schema}
61+
);
62+
"""
6263
).format(
6364
schema=sql.Literal(schema),
6465
table=sql.Literal(table),
@@ -67,33 +68,6 @@ def exists(self, conn: Connection) -> bool:
6768
cursor = execute_sql(conn, query)
6869
return cursor.fetchone()[0]
6970

70-
def installed_modules(self, conn: Connection) -> list[tuple[str, str]]:
71-
"""
72-
Returns the installed modules and their versions from the schema_migrations table.
73-
Args:
74-
conn (Connection): The database connection to fetch the installed modules and versions.
75-
Returns:
76-
list[tuple[str, str]]: A list of tuples containing the module name and version.
77-
"""
78-
query = sql.SQL(
79-
"""
80-
SELECT module, version
81-
FROM (
82-
SELECT module, version,
83-
ROW_NUMBER() OVER (PARTITION BY module ORDER BY date_installed DESC) AS rn
84-
FROM {schema_migrations_table}
85-
) t
86-
WHERE t.rn = 1;
87-
"""
88-
).format(
89-
schema_migrations_table=sql.Identifier(
90-
*self.config.schema_migrations_table.split(".")
91-
)
92-
)
93-
cursor = conn.cursor()
94-
cursor.execute(query)
95-
return cursor.fetchall()
96-
9771
def create(self, conn: Connection, commit: bool = True):
9872
"""
9973
Creates the schema_migrations information table
@@ -115,6 +89,7 @@ def create(self, conn: Connection, commit: bool = True):
11589
version character varying(50) NOT NULL,
11690
beta_testing boolean NOT NULL DEFAULT false,
11791
changelog_files text[],
92+
parameters jsonb,
11893
migration_table_version character varying(50) NOT NULL DEFAULT {version}
11994
);
12095
"""
@@ -147,6 +122,8 @@ def set_baseline(
147122
version: Version | str,
148123
beta_testing: bool = False,
149124
commit: bool = True,
125+
changelog_files: list[str] = None,
126+
parameters: dict = None,
150127
):
151128
"""
152129
Sets the baseline into the migration table
@@ -170,13 +147,17 @@ def set_baseline(
170147
query = sql.SQL(
171148
"""
172149
INSERT INTO {schema_migrations_table} (
173-
version,
174-
beta_testing,
175-
migration_table_version
150+
version,
151+
beta_testing,
152+
migration_table_version,
153+
changelog_files,
154+
parameters
176155
) VALUES (
177-
{version},
178-
{beta_testing},
179-
{migration_table_version}
156+
{version},
157+
{beta_testing},
158+
{migration_table_version},
159+
{changelog_files},
160+
{parameters}
180161
)
181162
"""
182163
).format(
@@ -186,10 +167,89 @@ def set_baseline(
186167
schema_migrations_table=sql.Identifier(
187168
*self.config.schema_migrations_table.split(".")
188169
),
170+
changelog_files=sql.Literal(changelog_files or []),
171+
parameters=sql.Literal(json.dumps(parameters or {})),
189172
)
190173
logger.info(
191174
f"Setting baseline version {version} in {self.config.schema_migrations_table}"
192175
)
193176
conn.execute(query)
194177
if commit:
195178
conn.commit()
179+
180+
def baseline(self, conn: Connection) -> str:
181+
"""
182+
Returns the baseline version from the migration table
183+
Args:
184+
conn: Connection
185+
The database connection to get the baseline version.
186+
Returns:
187+
str: The baseline version.
188+
"""
189+
query = sql.SQL(
190+
"""
191+
SELECT version
192+
FROM {schema_migrations_table}
193+
WHERE id = (
194+
SELECT id
195+
FROM {schema_migrations_table}
196+
ORDER BY date_installed DESC
197+
LIMIT 1
198+
)
199+
"""
200+
).format(
201+
schema_migrations_table=sql.Identifier(
202+
*self.config.schema_migrations_table.split(".")
203+
)
204+
)
205+
cursor = execute_sql(conn, query)
206+
return cursor.fetchone()[0]
207+
208+
def migration_details(self, conn: Connection, version: str = None) -> dict:
209+
"""
210+
Returns the migration details from the migration table
211+
Args:
212+
conn: Connection
213+
The database connection to get the migration details.
214+
version: str
215+
The version of the migration to get details for. If None, last migration is returned.
216+
Returns:
217+
dict: The migration details.
218+
"""
219+
query = None
220+
if version is None:
221+
query = sql.SQL(
222+
"""
223+
SELECT *
224+
FROM {schema_migrations_table}
225+
WHERE id = (
226+
SELECT id
227+
FROM {schema_migrations_table}
228+
ORDER BY date_installed DESC
229+
LIMIT 1
230+
)
231+
ORDER BY date_installed DESC
232+
"""
233+
).format(
234+
schema_migrations_table=sql.Identifier(
235+
*self.config.schema_migrations_table.split(".")
236+
),
237+
)
238+
else:
239+
query = sql.SQL(
240+
"""
241+
SELECT *
242+
FROM {schema_migrations_table}
243+
WHERE version = {version}
244+
"""
245+
).format(
246+
schema_migrations_table=sql.Identifier(
247+
*self.config.schema_migrations_table.split(".")
248+
),
249+
version=sql.Literal(version),
250+
)
251+
cursor = execute_sql(conn, query)
252+
row = cursor.fetchone()
253+
if row is None:
254+
return None
255+
return dict(zip([desc[0] for desc in cursor.description], row))

pum/upgrader.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,18 @@ def __init__(
7373
self.schema_migrations = SchemaMigrations(self.config)
7474
self.dir = dir
7575

76-
def install(self):
77-
"""Installs the given module"""
76+
def install(self, parameters: dict | None = None):
77+
"""
78+
Installs the given module
79+
This will create the schema_migrations table if it does not exist.
80+
This will also apply all the changelogs that are after the current version.
81+
The changelogs are applied in the order they are found in the directory.
82+
It will also set the baseline version to the current version of the module.
83+
84+
args:
85+
parameters: dict
86+
The parameters to pass to the SQL files
87+
"""
7888

7989
with psycopg.connect(f"service={self.pg_service}") as conn:
8090
if self.schema_migrations.exists(conn):
@@ -83,12 +93,17 @@ def install(self):
8393
)
8494
self.schema_migrations.create(conn, commit=False)
8595
for changelog in self.changelogs(after_current_version=False):
86-
self.__apply_changelog(conn, changelog, commit=False)
96+
changelog_files = self.__apply_changelog(
97+
conn, changelog, commit=False, parameters=parameters
98+
)
99+
changelog_files = [str(f) for f in changelog_files]
87100
self.schema_migrations.set_baseline(
88101
conn=conn,
89102
version=changelog.version,
90103
beta_testing=False,
91104
commit=False,
105+
changelog_files=changelog_files,
106+
parameters=parameters,
92107
)
93108
logger
94109

@@ -193,8 +208,12 @@ def changelog_files(self, changelog: str) -> list[Path]:
193208
return files
194209

195210
def __apply_changelog(
196-
self, conn: Connection, changelog: Changelog, commit: bool = True
197-
):
211+
self,
212+
conn: Connection,
213+
changelog: Changelog,
214+
parameters: dict | None = None,
215+
commit: bool = True,
216+
) -> list[Path]:
198217
"""
199218
Apply a changelog
200219
This will execute all the files in the changelog directory.
@@ -205,12 +224,19 @@ def __apply_changelog(
205224
The connection to the database
206225
changelog: Changelog
207226
The changelog to apply
227+
parameters: dict
228+
The parameters to pass to the SQL files
208229
commit: bool
209230
If true, the transaction is committed. The default is true.
231+
232+
Returns:
233+
list[Path]
234+
The list of changelogs that were executed
210235
"""
211236
files = self.changelog_files(changelog)
212237
for file in files:
213-
execute_sql(conn=conn, sql=file, commit=commit)
238+
execute_sql(conn=conn, sql=file, commit=commit, parameters=parameters)
239+
return files
214240

215241
def __run_delta_sql(self, delta):
216242
"""Execute the delta sql file on the database"""

pum/utils/execute_sql.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,29 +10,38 @@
1010

1111

1212
def execute_sql(
13-
conn: Connection, sql: str | Path, params: tuple = (), commit: bool = False
13+
conn: Connection,
14+
sql: str | Path,
15+
parameters: dict | None = None,
16+
commit: bool = False,
1417
) -> Cursor:
1518
"""
1619
Execute a SQL statement with optional parameters.
1720
1821
Args:
1922
conn (Connection): The database connection to execute the SQL statement.
2023
sql (str | Path): The SQL statement to execute or a path to a SQL file.
21-
params (tuple, optional): Parameters to bind to the SQL statement. Defaults to ().
24+
parameters (dict, optional): Parameters to bind to the SQL statement. Defaults to ().
2225
commit (bool, optional): Whether to commit the transaction. Defaults to False.
2326
Raises:
2427
RuntimeError: If the SQL execution fails.
2528
"""
2629
cursor = conn.cursor()
2730
try:
28-
sql_code = sql
2931
if isinstance(sql, Path):
3032
logger.debug(
3133
f"Executing SQL from file: {sql}",
3234
)
3335
with open(sql) as file:
34-
sql_code = file.read()
35-
cursor.execute(sql_code, params)
36+
sql_code = file.read().split(";")
37+
else:
38+
sql_code = [sql]
39+
40+
for statement in sql_code:
41+
if parameters:
42+
cursor.execute(statement, parameters)
43+
else:
44+
cursor.execute(statement)
3645
except SyntaxError as e:
3746
raise PumSqlException(
3847
f"SQL execution failed for the following code: {sql} {e}"

pyproject.toml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,3 @@ enabled = true
4343
readme = {file = ["README.md"], content-type = "text/markdown"}
4444
dependencies = {file = ["requirements/base.txt"]}
4545
optional-dependencies.dev = {file = ["requirements/development.txt"]}
46-
47-
[project.optional-dependencies]
48-
docs = [
49-
"mkdocstrings[python]~=0.25",
50-
"mkdocs-material>=9.5.17,<9.7.0",
51-
"fancyboxmd~=1.1"
52-
]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
changelogs_directory: my_delta_directory
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
CREATE SCHEMA IF NOT EXISTS pum_test_data;
2+
3+
CREATE EXTENSION IF NOT EXISTS postgis;
4+
5+
CREATE TABLE pum_test_data.some_table (
6+
id INT PRIMARY KEY,
7+
geom geometry(LineString, %(SRID)s)
8+
);

test/test_upgrader.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,28 @@ def test_install_simple(self):
4747
)
4848
upgrader.install()
4949
self.assertTrue(sm.exists(self.conn))
50-
self.assertEqual(sm.installed_modules(self.conn)[0][1], "1.2.3")
50+
self.assertEqual(sm.baseline(self.conn), "1.2.3")
51+
self.assertEqual(
52+
sm.migration_details(self.conn), sm.migration_details(self.conn, "1.2.3")
53+
)
54+
self.assertEqual(sm.migration_details(self.conn)["version"], "1.2.3")
55+
self.assertEqual(
56+
sm.migration_details(self.conn)["changelog_files"],
57+
["test/data/simple/changelogs/1.2.3/create_northwind.sql"],
58+
)
59+
60+
def test_parameters(self):
61+
cfg = PumConfig()
62+
sm = SchemaMigrations(cfg)
63+
self.assertFalse(sm.exists(self.conn))
64+
upgrader = Upgrader(
65+
pg_service=self.pg_service,
66+
config=cfg,
67+
dir="test/data/parameters",
68+
)
69+
upgrader.install(parameters={"SRID": 2056})
70+
self.assertTrue(sm.exists(self.conn))
71+
self.assertEqual(sm.migration_details(self.conn)["parameters"], "{}")
5172

5273
def test_install_custom_directory(self):
5374
cfg = PumConfig.from_yaml("test/data/custom_directory/.pum-config.yaml")

0 commit comments

Comments
 (0)