1
- import argparse
1
+ import os
2
+ import unittest
2
3
import subprocess
3
4
import difflib
4
- import sys
5
+
6
+ from cms .conf import config
7
+ from cms .db .drop import drop_db
8
+ from cms .db .init import init_db
9
+ from cms .db .session import custom_psycopg2_connection
5
10
6
11
"""
7
12
Compare the DB schema obtained from upgrading an older version's database using
17
22
to the first line thing is ALTER TABLE ADD CONSTRAINT, in which the constraint
18
23
name is on the second line. So we move the constraint name up to the first
19
24
line.)
20
- """
21
25
26
+ To update the files after a new release:
22
27
23
- def split_schemma (schema : str ):
28
+ cmsInitDB
29
+ pg_dump --schema-only >schema_vX.Y.sql
30
+
31
+ and replace update_from_vX.Y.sql with a blank file.
32
+ """
33
+
34
+ def split_schema (schema : str ) -> list [list [str ]]:
24
35
statements : list [list [str ]] = []
25
36
cur_statement : list [str ] = []
26
37
for line in schema .splitlines ():
@@ -34,9 +45,10 @@ def split_schemma(schema: str):
34
45
return statements
35
46
36
47
37
- def normalize_stmt (statement : list [str ]):
48
+ def normalize_stmt (statement : list [str ]) -> list [ str ] :
38
49
if statement [0 ].startswith ("CREATE TABLE " ):
39
50
# normalize order of columns by sorting the arguments to CREATE TABLE.
51
+
40
52
assert statement [- 1 ] == ");"
41
53
# add missing trailing comma on the last column.
42
54
assert not statement [- 2 ].endswith ("," )
@@ -56,12 +68,12 @@ def normalize_stmt(statement: list[str]):
56
68
return statement
57
69
58
70
59
- def is_create_enum (line : str ):
71
+ def is_create_enum (line : str ) -> bool :
60
72
return line .startswith ("CREATE TYPE " ) and line .endswith (" AS ENUM (" )
61
73
62
74
63
- def compare_schemas (updated_schema : list [list [str ]], fresh_schema : list [list [str ]]):
64
- ok = True
75
+ def compare_schemas (updated_schema : list [list [str ]], fresh_schema : list [list [str ]]) -> str :
76
+ errors : list [ str ] = []
65
77
66
78
updated_map : dict [str , list [str ]] = {}
67
79
for stmt in map (normalize_stmt , updated_schema ):
@@ -75,8 +87,7 @@ def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str
75
87
76
88
for updated_stmt in updated_map .values ():
77
89
if updated_stmt [0 ] not in fresh_map :
78
- print ("Updated schema contains extra statement:" , * updated_stmt , sep = "\n " )
79
- ok = False
90
+ errors += ["Updated schema contains extra statement:" , * updated_stmt ]
80
91
else :
81
92
fresh_stmt = fresh_map [updated_stmt [0 ]]
82
93
if is_create_enum (updated_stmt [0 ]):
@@ -86,87 +97,64 @@ def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str
86
97
}
87
98
fresh_values = {x .removesuffix ("," ).strip () for x in fresh_stmt [1 :- 1 ]}
88
99
if not fresh_values .issubset (updated_values ):
89
- print ( "Updated schema is missing enum value(s):" )
90
- print ( "Updated:\n " + " \n " . join ( updated_stmt ))
91
- print ( "Fresh:\n " + " \n " . join ( fresh_stmt ))
100
+ errors += [ "Updated schema is missing enum value(s):" ]
101
+ errors += [ "Updated:" ] + [ " " + x for x in updated_stmt ]
102
+ errors += [ "Fresh:" ] + [ " " + x for x in fresh_stmt ]
92
103
else :
93
104
# Other statements must match exactly (in normalized form)
94
105
if updated_stmt != fresh_stmt :
95
- ok = False
96
106
differ = difflib .Differ ()
97
107
cmp = differ .compare (
98
108
[x + "\n " for x in updated_stmt ], [x + "\n " for x in fresh_stmt ]
99
109
)
100
- print ( "Statement differs between updated and fresh schema:" )
101
- print ( "" .join (cmp ))
110
+ errors += [ "Statement differs between updated and fresh schema:" ]
111
+ errors += [ "" .join (cmp ). strip ()]
102
112
103
113
for fresh_stmt in fresh_map .values ():
104
114
if fresh_stmt [0 ] not in updated_map :
105
- print ("Fresh schema contains extra statement:" , * fresh_stmt , sep = "\n " )
106
- ok = False
115
+ errors += ["Fresh schema contains extra statement:" , * fresh_stmt ]
107
116
# if it exists, then it was already checked earlier
108
117
# print('\n'.join(updated_map.keys()))
109
- return ok
110
-
118
+ return '\n ' .join (errors )
111
119
112
- def get_updated_schema (user , host , name , schema_sql , updater_sql ):
113
- args = [f"--username={ user } " , f"--host={ host } " , name ]
114
- psql_flags = ["--quiet" , "--set=ON_ERROR_STOP=1" ]
115
- subprocess .run (["dropdb" , "--if-exists" , * args ], check = True )
116
- subprocess .run (["createdb" , * args ], check = True )
117
- subprocess .run (
118
- ["psql" , * args , * psql_flags , f"--file={ schema_sql } " ],
119
- check = True ,
120
- stdout = subprocess .PIPE ,
121
- )
122
- subprocess .run (
123
- ["psql" , * args , * psql_flags , f"--file={ updater_sql } " ],
124
- check = True ,
125
- )
120
+ def run_pg_dump () -> str :
121
+ db_url = config .database .url
122
+ db_url = db_url .replace ("postgresql+psycopg2://" , "postgresql://" )
126
123
result = subprocess .run (
127
- ["pg_dump" , "--schema-only" , * args ],
124
+ ["pg_dump" , "--schema-only" , "--dbname" , db_url ],
128
125
check = True ,
129
126
text = True ,
130
127
stdout = subprocess .PIPE ,
131
128
)
132
129
return result .stdout
133
130
134
-
135
- def get_fresh_schema (user , host , name ):
136
- args = [f"--username={ user } " , f"--host={ host } " , name ]
137
- subprocess .run (["dropdb" , "--if-exists" , * args ], check = True )
138
- subprocess .run (["createdb" , * args ], check = True )
139
- subprocess .run (["cmsInitDB" ], check = True )
140
- result = subprocess .run (
141
- ["pg_dump" , "--schema-only" , * args ],
142
- check = True ,
143
- text = True ,
144
- stdout = subprocess .PIPE ,
145
- )
146
- return result .stdout
147
-
148
-
149
- def main ():
150
- parser = argparse .ArgumentParser ()
151
- parser .add_argument ("--user" , required = True )
152
- parser .add_argument ("--host" , required = True )
153
- parser .add_argument ("--name" , required = True )
154
- parser .add_argument ("--schema_sql" , required = True )
155
- parser .add_argument ("--updater_sql" , required = True )
156
- args = parser .parse_args ()
157
- print ("Checking schema updater..." )
158
- updated_schema = split_schemma (
159
- get_updated_schema (
160
- args .user , args .host , args .name , args .schema_sql , args .updater_sql
161
- )
162
- )
163
- fresh_schema = split_schemma (get_fresh_schema (args .user , args .host , args .name ))
164
- if compare_schemas (updated_schema , fresh_schema ):
165
- print ("All good, updater works" )
166
- sys .exit (0 )
167
- else :
168
- sys .exit (1 )
169
-
170
-
171
- if __name__ == "__main__" :
172
- main ()
131
+ def get_updated_schema (schema_file : str , updater_file : str ) -> str :
132
+ drop_db ()
133
+ schema_sql = open (schema_file ).read ()
134
+ updater_sql = open (updater_file ).read ()
135
+ # We need to do this in two separate connections, since the schema_sql sets
136
+ # some connection properties which we don't want.
137
+ for sql in [schema_sql , updater_sql ]:
138
+ conn = custom_psycopg2_connection ()
139
+ cursor = conn .cursor ()
140
+ cursor .execute (sql )
141
+ conn .commit ()
142
+ conn .close ()
143
+
144
+ return run_pg_dump ()
145
+
146
+ def get_fresh_schema ():
147
+ drop_db ()
148
+ init_db ()
149
+ return run_pg_dump ()
150
+
151
+ class TestSchemaDiff (unittest .TestCase ):
152
+ def test_schema_diff (self ):
153
+ dirname = os .path .dirname (__file__ )
154
+ schema_file = os .path .join (dirname , "schema_v1.5.sql" )
155
+ updater_file = os .path .join (dirname , "../../cmscontrib/updaters/update_from_1.5.sql" )
156
+ updated_schema = split_schema (get_updated_schema (schema_file , updater_file ))
157
+ fresh_schema = split_schema (get_fresh_schema ())
158
+ errors = compare_schemas (updated_schema , fresh_schema )
159
+ self .longMessage = False
160
+ self .assertTrue (errors == "" , errors )
0 commit comments