66
77import asyncpg
88
9+ from sftkit .database import Connection
10+ from sftkit .database .introspection import list_constraints , list_functions , list_triggers , list_views
11+
912logger = logging .getLogger (__name__ )
1013
1114MIGRATION_VERSION_RE = re .compile (r"^-- migration: (?P<version>\w+)$" )
1215MIGRATION_REQURES_RE = re .compile (r"^-- requires: (?P<version>\w+)$" )
1316MIGRATION_TABLE = "schema_revision"
1417
1518
16- async def _run_postgres_code (conn : asyncpg . Connection , code : str , file_name : Path ):
19+ async def _run_postgres_code (conn : Connection , code : str , file_name : Path ):
1720 if all (line .startswith ("--" ) for line in code .splitlines ()):
1821 return
1922 try :
@@ -32,33 +35,23 @@ async def _run_postgres_code(conn: asyncpg.Connection, code: str, file_name: Pat
3235 raise ValueError (f"Syntax or Access error when executing SQL code ({ file_name !s} ): { message !r} " ) from exc
3336
3437
35- async def _drop_all_views (conn : asyncpg . Connection , schema : str ):
38+ async def _drop_all_views (conn : Connection , schema : str ):
3639 # TODO: we might have to find out the dependency order of the views if drop cascade does not work
37- result = await conn .fetch (
38- "select table_name from information_schema.views where table_schema = $1 and table_name !~ '^pg_';" ,
39- schema ,
40- )
41- views = [row ["table_name" ] for row in result ]
40+ views = await list_views (conn , schema )
4241 if len (views ) == 0 :
4342 return
4443
4544 # we use drop if exists here as the cascade dropping might lead the view to being already dropped
4645 # due to being a dependency of another view
47- drop_statements = "\n " .join ([f"drop view if exists { view } cascade;" for view in views ])
46+ drop_statements = "\n " .join ([f"drop view if exists { view . table_name } cascade;" for view in views ])
4847 await conn .execute (drop_statements )
4948
5049
51- async def _drop_all_triggers (conn : asyncpg .Connection , schema : str ):
52- result = await conn .fetch (
53- "select distinct on (trigger_name, event_object_table) trigger_name, event_object_table "
54- "from information_schema.triggers where trigger_schema = $1" ,
55- schema ,
56- )
50+ async def _drop_all_triggers (conn : Connection , schema : str ):
51+ triggers = await list_triggers (conn , schema )
5752 statements = []
58- for row in result :
59- trigger_name = row ["trigger_name" ]
60- table = row ["event_object_table" ]
61- statements .append (f"drop trigger { trigger_name } on { table } ;" )
53+ for trigger in triggers :
54+ statements .append (f'drop trigger "{ trigger .trigger_name } " on "{ trigger .event_object_table } ";' )
6255
6356 if len (statements ) == 0 :
6457 return
@@ -67,27 +60,20 @@ async def _drop_all_triggers(conn: asyncpg.Connection, schema: str):
6760 await conn .execute (drop_statements )
6861
6962
70- async def _drop_all_functions (conn : asyncpg .Connection , schema : str ):
71- result = await conn .fetch (
72- "select proname, pg_get_function_identity_arguments(oid) as signature, prokind from pg_proc "
73- "where pronamespace = $1::regnamespace;" ,
74- schema ,
75- )
63+ async def _drop_all_functions (conn : Connection , schema : str ):
64+ funcs = await list_functions (conn , schema )
7665 drop_statements = []
77- for row in result :
78- kind = row ["prokind" ].decode ("utf-8" )
79- name = row ["proname" ]
80- signature = row ["signature" ]
81- if kind in ("f" , "w" ):
66+ for func in funcs :
67+ if func .prokind in ("f" , "w" ):
8268 drop_type = "function"
83- elif kind == "a" :
69+ elif func . prokind == "a" :
8470 drop_type = "aggregate"
85- elif kind == "p" :
71+ elif func . prokind == "p" :
8672 drop_type = "procedure"
8773 else :
88- raise RuntimeError (f'Unknown postgres function type "{ kind } "' )
74+ raise RuntimeError (f'Unknown postgres function type "{ func . prokind } "' )
8975
90- drop_statements .append (f" drop { drop_type } { name } ( { signature } ) cascade;" )
76+ drop_statements .append (f' drop { drop_type } " { func . proname } "( { func . signature } ) cascade;' )
9177
9278 if len (drop_statements ) == 0 :
9379 return
@@ -96,37 +82,31 @@ async def _drop_all_functions(conn: asyncpg.Connection, schema: str):
9682 await conn .execute (drop_code )
9783
9884
99- async def _drop_all_constraints (conn : asyncpg . Connection , schema : str ):
85+ async def _drop_all_constraints (conn : Connection , schema : str ):
10086 """drop all constraints in the given schema which are not unique, primary or foreign key constraints"""
101- result = await conn .fetch (
102- "select con.conname as constraint_name, rel.relname as table_name, con.contype as constraint_type "
103- "from pg_catalog.pg_constraint con "
104- " join pg_catalog.pg_namespace nsp on nsp.oid = con.connamespace "
105- " left join pg_catalog.pg_class rel on rel.oid = con.conrelid "
106- "where nsp.nspname = $1 and con.conname !~ '^pg_' "
107- " and con.contype != 'p' and con.contype != 'f' and con.contype != 'u';" ,
108- schema ,
109- )
110- constraints = []
111- for row in result :
112- constraint_name = row ["constraint_name" ]
113- constraint_type = row ["constraint_type" ].decode ("utf-8" )
114- table_name = row ["table_name" ]
87+ constraints = await list_constraints (conn , schema )
88+ drop_statements = []
89+ for constraint in constraints :
90+ constraint_name = constraint .conname
91+ constraint_type = constraint .contype
92+ table_name = constraint .relname
93+ if constraint_type in ("p" , "f" , "u" ):
94+ continue
11595 if constraint_type == "c" :
116- constraints .append (f" alter table { table_name } drop constraint { constraint_name } ;" )
96+ drop_statements .append (f' alter table " { table_name } " drop constraint " { constraint_name } ";' )
11797 elif constraint_type == "t" :
118- constraints .append (f"drop constraint trigger { constraint_name } ;" )
98+ drop_statements .append (f"drop constraint trigger { constraint_name } ;" )
11999 else :
120100 raise RuntimeError (f'Unknown constraint type "{ constraint_type } " for constraint "{ constraint_name } "' )
121101
122- if len (constraints ) == 0 :
102+ if len (drop_statements ) == 0 :
123103 return
124104
125- drop_statements = "\n " .join (constraints )
126- await conn .execute (drop_statements )
105+ drop_cmd = "\n " .join (drop_statements )
106+ await conn .execute (drop_cmd )
127107
128108
129- async def _drop_db_code (conn : asyncpg . Connection , schema : str ):
109+ async def _drop_db_code (conn : Connection , schema : str ):
130110 await _drop_all_triggers (conn , schema = schema )
131111 await _drop_all_functions (conn , schema = schema )
132112 await _drop_all_views (conn , schema = schema )
0 commit comments