Skip to content

Commit f0c2e77

Browse files
committed
Convert SQL update checker to unit test
1 parent bed5503 commit f0c2e77

File tree

3 files changed

+63
-78
lines changed

3 files changed

+63
-78
lines changed
Lines changed: 62 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
import argparse
1+
import os
2+
import unittest
23
import subprocess
34
import difflib
4-
import sys
5+
6+
from cms.conf import config
7+
from cms.db.drop import drop_db
8+
from cms.db.init import init_db
9+
from cms.db.session import custom_psycopg2_connection
510

611
"""
712
Compare the DB schema obtained from upgrading an older version's database using
@@ -17,10 +22,16 @@
1722
to the first line thing is ALTER TABLE ADD CONSTRAINT, in which the constraint
1823
name is on the second line. So we move the constraint name up to the first
1924
line.)
20-
"""
2125
26+
To update the files after a new release:
2227
23-
def split_schemma(schema: str):
28+
cmsInitDB
29+
pg_dump --schema-only >schema_vX.Y.sql
30+
31+
and replace update_from_vX.Y.sql with a blank file.
32+
"""
33+
34+
def split_schema(schema: str) -> list[list[str]]:
2435
statements: list[list[str]] = []
2536
cur_statement: list[str] = []
2637
for line in schema.splitlines():
@@ -34,9 +45,10 @@ def split_schemma(schema: str):
3445
return statements
3546

3647

37-
def normalize_stmt(statement: list[str]):
48+
def normalize_stmt(statement: list[str]) -> list[str]:
3849
if statement[0].startswith("CREATE TABLE "):
3950
# normalize order of columns by sorting the arguments to CREATE TABLE.
51+
4052
assert statement[-1] == ");"
4153
# add missing trailing comma on the last column.
4254
assert not statement[-2].endswith(",")
@@ -56,12 +68,12 @@ def normalize_stmt(statement: list[str]):
5668
return statement
5769

5870

59-
def is_create_enum(line: str):
71+
def is_create_enum(line: str) -> bool:
6072
return line.startswith("CREATE TYPE ") and line.endswith(" AS ENUM (")
6173

6274

63-
def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str]]):
64-
ok = True
75+
def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str]]) -> str:
76+
errors: list[str] = []
6577

6678
updated_map: dict[str, list[str]] = {}
6779
for stmt in map(normalize_stmt, updated_schema):
@@ -75,8 +87,7 @@ def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str
7587

7688
for updated_stmt in updated_map.values():
7789
if updated_stmt[0] not in fresh_map:
78-
print("Updated schema contains extra statement:", *updated_stmt, sep="\n")
79-
ok = False
90+
errors += ["Updated schema contains extra statement:", *updated_stmt]
8091
else:
8192
fresh_stmt = fresh_map[updated_stmt[0]]
8293
if is_create_enum(updated_stmt[0]):
@@ -86,87 +97,64 @@ def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str
8697
}
8798
fresh_values = {x.removesuffix(",").strip() for x in fresh_stmt[1:-1]}
8899
if not fresh_values.issubset(updated_values):
89-
print("Updated schema is missing enum value(s):")
90-
print("Updated:\n " + "\n ".join(updated_stmt))
91-
print("Fresh:\n " + "\n ".join(fresh_stmt))
100+
errors += ["Updated schema is missing enum value(s):"]
101+
errors += ["Updated:"] + [" " + x for x in updated_stmt]
102+
errors += ["Fresh:"] + [" " + x for x in fresh_stmt]
92103
else:
93104
# Other statements must match exactly (in normalized form)
94105
if updated_stmt != fresh_stmt:
95-
ok = False
96106
differ = difflib.Differ()
97107
cmp = differ.compare(
98108
[x + "\n" for x in updated_stmt], [x + "\n" for x in fresh_stmt]
99109
)
100-
print("Statement differs between updated and fresh schema:")
101-
print("".join(cmp))
110+
errors += ["Statement differs between updated and fresh schema:"]
111+
errors += ["".join(cmp).strip()]
102112

103113
for fresh_stmt in fresh_map.values():
104114
if fresh_stmt[0] not in updated_map:
105-
print("Fresh schema contains extra statement:", *fresh_stmt, sep="\n")
106-
ok = False
115+
errors += ["Fresh schema contains extra statement:", *fresh_stmt]
107116
# if it exists, then it was already checked earlier
108117
# print('\n'.join(updated_map.keys()))
109-
return ok
110-
118+
return '\n'.join(errors)
111119

112-
def get_updated_schema(user, host, name, schema_sql, updater_sql):
113-
args = [f"--username={user}", f"--host={host}", name]
114-
psql_flags = ["--quiet", "--set=ON_ERROR_STOP=1"]
115-
subprocess.run(["dropdb", "--if-exists", *args], check=True)
116-
subprocess.run(["createdb", *args], check=True)
117-
subprocess.run(
118-
["psql", *args, *psql_flags, f"--file={schema_sql}"],
119-
check=True,
120-
stdout=subprocess.PIPE,
121-
)
122-
subprocess.run(
123-
["psql", *args, *psql_flags, f"--file={updater_sql}"],
124-
check=True,
125-
)
120+
def run_pg_dump() -> str:
121+
db_url = config.database.url
122+
db_url = db_url.replace("postgresql+psycopg2://", "postgresql://")
126123
result = subprocess.run(
127-
["pg_dump", "--schema-only", *args],
124+
["pg_dump", "--schema-only", "--dbname", db_url],
128125
check=True,
129126
text=True,
130127
stdout=subprocess.PIPE,
131128
)
132129
return result.stdout
133130

134-
135-
def get_fresh_schema(user, host, name):
136-
args = [f"--username={user}", f"--host={host}", name]
137-
subprocess.run(["dropdb", "--if-exists", *args], check=True)
138-
subprocess.run(["createdb", *args], check=True)
139-
subprocess.run(["cmsInitDB"], check=True)
140-
result = subprocess.run(
141-
["pg_dump", "--schema-only", *args],
142-
check=True,
143-
text=True,
144-
stdout=subprocess.PIPE,
145-
)
146-
return result.stdout
147-
148-
149-
def main():
150-
parser = argparse.ArgumentParser()
151-
parser.add_argument("--user", required=True)
152-
parser.add_argument("--host", required=True)
153-
parser.add_argument("--name", required=True)
154-
parser.add_argument("--schema_sql", required=True)
155-
parser.add_argument("--updater_sql", required=True)
156-
args = parser.parse_args()
157-
print("Checking schema updater...")
158-
updated_schema = split_schemma(
159-
get_updated_schema(
160-
args.user, args.host, args.name, args.schema_sql, args.updater_sql
161-
)
162-
)
163-
fresh_schema = split_schemma(get_fresh_schema(args.user, args.host, args.name))
164-
if compare_schemas(updated_schema, fresh_schema):
165-
print("All good, updater works")
166-
sys.exit(0)
167-
else:
168-
sys.exit(1)
169-
170-
171-
if __name__ == "__main__":
172-
main()
131+
def get_updated_schema(schema_file: str, updater_file: str) -> str:
132+
drop_db()
133+
schema_sql = open(schema_file).read()
134+
updater_sql = open(updater_file).read()
135+
# We need to do this in two separate connections, since the schema_sql sets
136+
# some connection properties which we don't want.
137+
for sql in [schema_sql, updater_sql]:
138+
conn = custom_psycopg2_connection()
139+
cursor = conn.cursor()
140+
cursor.execute(sql)
141+
conn.commit()
142+
conn.close()
143+
144+
return run_pg_dump()
145+
146+
def get_fresh_schema():
147+
drop_db()
148+
init_db()
149+
return run_pg_dump()
150+
151+
class TestSchemaDiff(unittest.TestCase):
152+
def test_schema_diff(self):
153+
dirname = os.path.dirname(__file__)
154+
schema_file = os.path.join(dirname, "schema_v1.5.sql")
155+
updater_file = os.path.join(dirname, "../../cmscontrib/updaters/update_from_1.5.sql")
156+
updated_schema = split_schema(get_updated_schema(schema_file, updater_file))
157+
fresh_schema = split_schema(get_fresh_schema())
158+
errors = compare_schemas(updated_schema, fresh_schema)
159+
self.longMessage = False
160+
self.assertTrue(errors == "", errors)
File renamed without changes.

docker/_cms-test-internal.sh

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,11 @@ cmsInitDB
1616
cmsRunFunctionalTests -v --coverage codecov/functionaltests.xml
1717
FUNC=$?
1818

19-
python cmstestsuite/check_schema_diff.py --user=postgres --host=testdb --name=cmsdbfortesting --schema_sql=cmstestsuite/schema_v1.5.sql --updater_sql=cmscontrib/updaters/update_from_1.5.sql
20-
SCHEMA=$?
21-
2219
# This check is needed because otherwise failing unit tests aren't reported in
2320
# the CI as long as the functional tests are passing. Ideally we should get rid
2421
# of `cmsRunFunctionalTests` and make those tests work with pytest so they can
2522
# be auto-discovered and run in a single command.
26-
if [ $UNIT -ne 0 ] || [ $FUNC -ne 0 ] || [ $SCHEMA -ne 0 ]
23+
if [ $UNIT -ne 0 ] || [ $FUNC -ne 0 ]
2724
then
2825
exit 1
2926
else

0 commit comments

Comments
 (0)