Skip to content

Commit 03d5638

Browse files
committed
Add SQL schema updater checker to test suite
1 parent 5326211 commit 03d5638

File tree

4 files changed

+2452
-3
lines changed

4 files changed

+2452
-3
lines changed

cms/db/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ class Dataset(Base):
395395
nullable=True)
396396
memory_limit: int | None = Column(
397397
BigInteger,
398-
CheckConstraint("memory_limit > 0"),
399-
CheckConstraint("MOD(memory_limit, 1048576) = 0"),
398+
CheckConstraint("memory_limit > 0", name='datasets_memory_limit_check'),
399+
CheckConstraint("MOD(memory_limit, 1048576) = 0", name='datasets_memory_limit_check1'),
400400
nullable=True)
401401

402402
# Name of the TaskType child class suited for the task.

cmstestsuite/check_schema_diff.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import argparse
2+
import subprocess
3+
import difflib
4+
import sys
5+
6+
"""
7+
Compare the DB schema obtained from upgrading an older version's database using
8+
an SQL updater, with the schema of a fresh install. These should be as close as
9+
possible, but there are a few quirks which mean it's not possible for the
10+
updater to be perfect: columns can't be reordered, and enum values can't be
11+
removed. We thus sort columns in CREATE TABLE statements, and have special
12+
handling of enums that allows extra values in the updated form.
13+
14+
To make the diff output nicer in cases of mismatches, we first pair up
15+
statements by the first line (which, for most statements, just contains the
16+
affected object's name) and then diff the paired up statements. (One exception
17+
to the first line thing is ALTER TABLE ADD CONSTRAINT, in which the constraint
18+
name is on the second line. So we move the constraint name up to the first
19+
line.)
20+
"""
21+
22+
23+
def split_schemma(schema: str):
    """Split an SQL schema dump into statements.

    Empty lines and lines starting with "--" are dropped. The remaining
    lines are accumulated until one ends with ";", which closes the
    current statement.

    schema: the full text of the dump.

    return: a list of statements, each one being the list of its lines.

    raise (AssertionError): if the dump ends in the middle of a
        statement, i.e. the last kept line does not end with ";".

    """
    result: list[list[str]] = []
    pending: list[str] = []
    for text in schema.splitlines():
        # Blank lines and SQL comments carry no schema information.
        if not text or text.startswith("--"):
            continue
        pending.append(text)
        if not text.endswith(";"):
            continue
        # A trailing semicolon terminates the statement being collected.
        result.append(pending)
        pending = []
    # Nothing may be left over after the final terminator.
    assert not pending
    return result
35+
36+
37+
def normalize_stmt(statement: list[str]):
    """Return a statement in a form suitable for pairing and comparing.

    Two kinds of statement are rewritten:

    - CREATE TABLE: the lines between the header and the closing ");"
      are sorted, so that column order does not matter. The last of
      those lines gets a trailing comma first, so that every line looks
      the same regardless of where it ends up after sorting.
    - ALTER TABLE whose second line starts with " ADD CONSTRAINT ": the
      constraint name is moved up to the first line, so that the first
      line identifies the constraint.

    Any other statement is returned as is.

    The input list is never modified, so the same statement can be
    normalized more than once. A CREATE TABLE without any column line is
    returned with its header intact.

    statement: the lines of a single statement.

    return: the lines of the normalized statement.

    raise (AssertionError): if a CREATE TABLE statement does not end
        with a ");" line, or if its last column line already ends with
        a comma.

    """
    header = statement[0]
    if header.startswith("CREATE TABLE "):
        assert statement[-1] == ");"
        # Slicing copies, so the edits below don't leak to the caller.
        columns = statement[1:-1]
        if columns:
            # Add the missing trailing comma on the last column.
            assert not columns[-1].endswith(",")
            columns[-1] += ","
        columns.sort()
        return [header, *columns, ");"]
    if (
        header.startswith("ALTER TABLE ")
        and len(statement) > 1
        and statement[1].startswith(" ADD CONSTRAINT ")
    ):
        # Move the constraint name to the first line.
        name, rest = statement[1].removeprefix(" ADD CONSTRAINT ").split(" ", 1)
        return [header + " ADD CONSTRAINT " + name, rest] + statement[2:]
    return statement
57+
58+
59+
def is_create_enum(line: str):
    """Tell whether a line is the first line of a CREATE TYPE ... AS ENUM.

    line: the first line of a statement.

    return: True if the line starts with "CREATE TYPE " and ends with
        " AS ENUM (", False otherwise.

    """
    if not line.startswith("CREATE TYPE "):
        return False
    return line.endswith(" AS ENUM (")
61+
62+
63+
def _index_by_first_line(schema: list[list[str]]):
    """Normalize the statements and map each one by its first line."""
    indexed: dict[str, list[str]] = {}
    for stmt in map(normalize_stmt, schema):
        # Pairing relies on first lines being unique within a schema.
        assert stmt[0] not in indexed
        indexed[stmt[0]] = stmt
    return indexed


def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str]]):
    """Compare the schema produced by the updater with a fresh one.

    Statements of both schemas are normalized with normalize_stmt and
    paired up by their first line. Every discrepancy is printed to
    standard output:

    - a statement present in only one of the two schemas;
    - a CREATE TYPE ... AS ENUM whose updated form lacks a value that
      the fresh form has (extra values in the updated form are fine);
    - any other paired statement whose normalized lines differ, shown as
      a difflib.Differ comparison.

    updated_schema: statements (lists of lines) of the upgraded database.
    fresh_schema: statements (lists of lines) of the fresh install.

    return: True if no discrepancy was found, False otherwise.

    raise (AssertionError): if two statements of the same schema share
        the same (normalized) first line.

    """
    ok = True

    updated_map = _index_by_first_line(updated_schema)
    fresh_map = _index_by_first_line(fresh_schema)

    for first_line, updated_stmt in updated_map.items():
        if first_line not in fresh_map:
            print("Updated schema contains extra statement:", *updated_stmt, sep="\n")
            ok = False
            continue

        fresh_stmt = fresh_map[first_line]
        if is_create_enum(first_line):
            # For enums, updated's values must be a superset of fresh.
            updated_values = {
                x.removesuffix(",").strip() for x in updated_stmt[1:-1]
            }
            fresh_values = {x.removesuffix(",").strip() for x in fresh_stmt[1:-1]}
            if not fresh_values.issubset(updated_values):
                # A missing enum value is a real mismatch: it must fail
                # the check, not just be reported.
                ok = False
                print("Updated schema is missing enum value(s):")
                print("Updated:\n " + "\n ".join(updated_stmt))
                print("Fresh:\n " + "\n ".join(fresh_stmt))
        elif updated_stmt != fresh_stmt:
            # Other statements must match exactly (in normalized form).
            ok = False
            differ = difflib.Differ()
            cmp = differ.compare(
                [x + "\n" for x in updated_stmt], [x + "\n" for x in fresh_stmt]
            )
            print("Statement differs between updated and fresh schema:")
            print("".join(cmp))

    # Statements present in both schemas were already compared above, so
    # only those that exist in the fresh schema alone are left to report.
    for first_line, fresh_stmt in fresh_map.items():
        if first_line not in updated_map:
            print("Fresh schema contains extra statement:", *fresh_stmt, sep="\n")
            ok = False

    return ok
110+
111+
112+
def get_updated_schema(user, host, name, schema_sql, updater_sql):
    """Build a database from an old schema, upgrade it and dump its schema.

    The database is dropped (if it exists) and recreated, the SQL file
    with the old schema is loaded into it with psql, then the updater SQL
    file is run on top of it. psql is told to stop at the first error,
    and every command is run with check=True.

    user: value passed as --username to the PostgreSQL tools.
    host: value passed as --host to the PostgreSQL tools.
    name: name of the database to (re)create.
    schema_sql: path of the SQL file containing the old schema.
    updater_sql: path of the SQL file containing the updater.

    return: the output of "pg_dump --schema-only", as text.

    raise (subprocess.CalledProcessError): if any of the commands exits
        with a non-zero status.

    """
    target = [f"--username={user}", f"--host={host}", name]
    psql = ["psql", *target, "--quiet", "--set=ON_ERROR_STOP=1"]

    # Start from a clean, empty database.
    subprocess.run(["dropdb", "--if-exists", *target], check=True)
    subprocess.run(["createdb", *target], check=True)

    # Load the old schema; its standard output is captured and discarded.
    subprocess.run(
        psql + [f"--file={schema_sql}"], check=True, stdout=subprocess.PIPE
    )
    # Apply the updater; its standard output is not captured.
    subprocess.run(psql + [f"--file={updater_sql}"], check=True)

    dump = subprocess.run(
        ["pg_dump", "--schema-only", *target],
        check=True,
        text=True,
        stdout=subprocess.PIPE,
    )
    return dump.stdout
133+
134+
135+
def get_fresh_schema(user, host, name):
    """Recreate the database, run cmsInitDB and dump the resulting schema.

    Every command is run with check=True.

    user: value passed as --username to the PostgreSQL tools.
    host: value passed as --host to the PostgreSQL tools.
    name: name of the database to (re)create.

    return: the output of "pg_dump --schema-only", as text.

    raise (subprocess.CalledProcessError): if any of the commands exits
        with a non-zero status.

    """
    target = [f"--username={user}", f"--host={host}", name]
    setup_commands = (
        ["dropdb", "--if-exists", *target],
        ["createdb", *target],
        # NOTE(review): cmsInitDB is given no arguments, so it presumably
        # takes the database from the CMS configuration; confirm that it
        # points at the database named above.
        ["cmsInitDB"],
    )
    for command in setup_commands:
        subprocess.run(command, check=True)

    dump = subprocess.run(
        ["pg_dump", "--schema-only", *target],
        check=True,
        text=True,
        stdout=subprocess.PIPE,
    )
    return dump.stdout
147+
148+
149+
def main():
    """Parse the command line, build both schemas and compare them.

    All of --user, --host, --name, --schema_sql and --updater_sql are
    required. The updated schema is obtained first and the fresh one
    second; both use the same database name.

    Exits with status 0 if the schemas match, 1 otherwise.

    """
    parser = argparse.ArgumentParser()
    for flag in ("--user", "--host", "--name", "--schema_sql", "--updater_sql"):
        parser.add_argument(flag, required=True)
    opts = parser.parse_args()

    print("Checking schema updater...")
    updated_dump = get_updated_schema(
        opts.user, opts.host, opts.name, opts.schema_sql, opts.updater_sql
    )
    updated_schema = split_schemma(updated_dump)
    fresh_dump = get_fresh_schema(opts.user, opts.host, opts.name)
    fresh_schema = split_schemma(fresh_dump)

    if not compare_schemas(updated_schema, fresh_schema):
        sys.exit(1)
    print("All good, updater works")
    sys.exit(0)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)