Skip to content

Commit c97dc37

Browse files
committed
fix: allow blacklisting db functions to prevent being dropped
1 parent 22afd84 commit c97dc37

File tree

1 file changed

+41
-7
lines changed

1 file changed

+41
-7
lines changed

sftkit/sftkit/database/_migrations.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ async def _drop_all_triggers(conn: Connection, schema: str):
6060
await conn.execute(drop_statements)
6161

6262

63-
async def _drop_all_functions(conn: Connection, schema: str):
63+
async def _drop_all_functions(
64+
conn: Connection, schema: str, function_blacklist: list[str], function_blacklist_prefix: str | None
65+
):
66+
blacklist = set(function_blacklist)
6467
funcs = await list_functions(conn, schema)
6568
drop_statements = []
6669
for func in funcs:
@@ -73,6 +76,11 @@ async def _drop_all_functions(conn: Connection, schema: str):
7376
else:
7477
raise RuntimeError(f'Unknown postgres function type "{func.prokind}"')
7578

79+
if func.proname in blacklist:
80+
continue
81+
if function_blacklist_prefix is not None and func.proname.startswith(function_blacklist_prefix):
82+
continue
83+
7684
drop_statements.append(f'drop {drop_type} "{func.proname}"({func.signature}) cascade;')
7785

7886
if len(drop_statements) == 0:
@@ -106,9 +114,13 @@ async def _drop_all_constraints(conn: Connection, schema: str):
106114
await conn.execute(drop_cmd)
107115

108116

109-
async def _drop_db_code(conn: Connection, schema: str):
117+
async def _drop_db_code(
118+
conn: Connection, schema: str, function_blacklist: list[str], function_blacklist_prefix: str | None
119+
):
110120
await _drop_all_triggers(conn, schema=schema)
111-
await _drop_all_functions(conn, schema=schema)
121+
await _drop_all_functions(
122+
conn, schema=schema, function_blacklist=function_blacklist, function_blacklist_prefix=function_blacklist_prefix
123+
)
112124
await _drop_all_views(conn, schema=schema)
113125
await _drop_all_constraints(conn, schema=schema)
114126

@@ -220,13 +232,29 @@ async def apply_db_code(conn: asyncpg.Connection, code_path: Path):
220232
await _run_postgres_code(conn, code, code_file)
221233

222234

223-
async def reload_db_code(conn: asyncpg.Connection, code_path: Path):
224-
await _drop_db_code(conn, schema="public")
235+
async def reload_db_code(
236+
conn: asyncpg.Connection,
237+
code_path: Path,
238+
function_blacklist: list[str] | None = None,
239+
function_blacklist_prefix: str | None = None,
240+
):
241+
_function_blacklist = function_blacklist or []
242+
await _drop_db_code(
243+
conn,
244+
schema="public",
245+
function_blacklist=_function_blacklist,
246+
function_blacklist_prefix=function_blacklist_prefix,
247+
)
225248
await apply_db_code(conn, code_path)
226249

227250

228251
async def apply_migrations(
229-
db_pool: asyncpg.Pool, migration_path: Path, code_path: Path, until_migration: str | None = None
252+
db_pool: asyncpg.Pool,
253+
migration_path: Path,
254+
code_path: Path,
255+
until_migration: str | None = None,
256+
function_blacklist: list[str] | None = None,
257+
function_blacklist_prefix: str | None = None,
230258
):
231259
migrations = SchemaMigration.migrations_from_dir(migration_path)
232260

@@ -236,7 +264,13 @@ async def apply_migrations(
236264

237265
curr_migration = await conn.fetchval(f"select version from {MIGRATION_TABLE} limit 1")
238266

239-
await _drop_db_code(conn=conn, schema="public")
267+
_function_blacklist = function_blacklist or []
268+
await _drop_db_code(
269+
conn=conn,
270+
schema="public",
271+
function_blacklist=_function_blacklist,
272+
function_blacklist_prefix=function_blacklist_prefix,
273+
)
240274
# TODO: perform a dry run to check all migrations before doing anything
241275

242276
found = curr_migration is None

0 commit comments

Comments
 (0)