Skip to content

Commit 5599c80

Browse files
authored
Add SQL schema updater checker to test suite (#1443)
1 parent 5326211 commit 5599c80

File tree

3 files changed

+2444
-2
lines changed

3 files changed

+2444
-2
lines changed

cms/db/task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ class Dataset(Base):
395395
nullable=True)
396396
memory_limit: int | None = Column(
397397
BigInteger,
398-
CheckConstraint("memory_limit > 0"),
399-
CheckConstraint("MOD(memory_limit, 1048576) = 0"),
398+
CheckConstraint("memory_limit > 0", name='datasets_memory_limit_check'),
399+
CheckConstraint("MOD(memory_limit, 1048576) = 0", name='datasets_memory_limit_check1'),
400400
nullable=True)
401401

402402
# Name of the TaskType child class suited for the task.
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import os
2+
import unittest
3+
import subprocess
4+
import difflib
5+
6+
from cms.conf import config
7+
from cms.db.drop import drop_db
8+
from cms.db.init import init_db
9+
from cms.db.session import custom_psycopg2_connection
10+
11+
"""
12+
Compare the DB schema obtained from upgrading an older version's database using
13+
an SQL updater, with the schema of a fresh install. These should be as close as
14+
possible, but there are a few quirks which means it's not possible for the
15+
updater to be perfect: columns can't be reordered, and enum values can't be
16+
removed. We thus sort columns in CREATE TABLE statements, and have special
17+
handing of enums that allows extra values in the updated form.
18+
19+
To make the diff output nicer in cases of mismatches, we first pair up
20+
statements by the first line (which, for most statements, just contains the
21+
affected object's name) and then diff the paired up statements. (One exception
22+
to the first line thing is ALTER TABLE ADD CONSTRAINT, in which the constraint
23+
name is on the second line. So we move the constraint name up to the first
24+
line.)
25+
26+
To update the files after a new release:
27+
28+
cmsInitDB
29+
pg_dump --schema-only >schema_vX.Y.sql
30+
31+
and replace update_from_vX.Y.sql with a blank file.
32+
"""
33+
34+
def split_schema(schema: str) -> list[list[str]]:
35+
statements: list[list[str]] = []
36+
cur_statement: list[str] = []
37+
for line in schema.splitlines():
38+
if line == "" or line.startswith("--"):
39+
continue
40+
cur_statement.append(line)
41+
if line.endswith(";"):
42+
statements.append(cur_statement)
43+
cur_statement = []
44+
assert cur_statement == []
45+
return statements
46+
47+
48+
def normalize_stmt(statement: list[str]) -> list[str]:
49+
if statement[0].startswith("CREATE TABLE "):
50+
# normalize order of columns by sorting the arguments to CREATE TABLE.
51+
52+
assert statement[-1] == ");"
53+
# add missing trailing comma on the last column.
54+
assert not statement[-2].endswith(",")
55+
statement[-2] += ","
56+
columns = statement[1:-1]
57+
columns.sort()
58+
return [statement[0]] + columns + [");"]
59+
elif (
60+
statement[0].startswith("ALTER TABLE ")
61+
and len(statement) > 1
62+
and statement[1].startswith(" ADD CONSTRAINT ")
63+
):
64+
# move the constraint name to the first line.
65+
name, rest = statement[1].removeprefix(" ADD CONSTRAINT ").split(" ", 1)
66+
return [statement[0] + " ADD CONSTRAINT " + name, rest] + statement[2:]
67+
else:
68+
return statement
69+
70+
71+
def is_create_enum(line: str) -> bool:
72+
return line.startswith("CREATE TYPE ") and line.endswith(" AS ENUM (")
73+
74+
75+
def compare_schemas(updated_schema: list[list[str]], fresh_schema: list[list[str]]) -> str:
76+
errors: list[str] = []
77+
78+
updated_map: dict[str, list[str]] = {}
79+
for stmt in map(normalize_stmt, updated_schema):
80+
assert stmt[0] not in updated_map
81+
updated_map[stmt[0]] = stmt
82+
83+
fresh_map: dict[str, list[str]] = {}
84+
for stmt in map(normalize_stmt, fresh_schema):
85+
assert stmt[0] not in fresh_map
86+
fresh_map[stmt[0]] = stmt
87+
88+
for updated_stmt in updated_map.values():
89+
if updated_stmt[0] not in fresh_map:
90+
errors += ["Updated schema contains extra statement:", *updated_stmt]
91+
else:
92+
fresh_stmt = fresh_map[updated_stmt[0]]
93+
if is_create_enum(updated_stmt[0]):
94+
# for enums, updated's values must be a superset of fresh.
95+
updated_values = {
96+
x.removesuffix(",").strip() for x in updated_stmt[1:-1]
97+
}
98+
fresh_values = {x.removesuffix(",").strip() for x in fresh_stmt[1:-1]}
99+
if not fresh_values.issubset(updated_values):
100+
errors += ["Updated schema is missing enum value(s):"]
101+
errors += ["Updated:"] + [" " + x for x in updated_stmt]
102+
errors += ["Fresh:"] + [" " + x for x in fresh_stmt]
103+
else:
104+
# Other statements must match exactly (in normalized form)
105+
if updated_stmt != fresh_stmt:
106+
differ = difflib.Differ()
107+
cmp = differ.compare(
108+
[x + "\n" for x in updated_stmt], [x + "\n" for x in fresh_stmt]
109+
)
110+
errors += ["Statement differs between updated and fresh schema:"]
111+
errors += ["".join(cmp).strip()]
112+
113+
for fresh_stmt in fresh_map.values():
114+
if fresh_stmt[0] not in updated_map:
115+
errors += ["Fresh schema contains extra statement:", *fresh_stmt]
116+
# if it exists, then it was already checked earlier
117+
# print('\n'.join(updated_map.keys()))
118+
return '\n'.join(errors)
119+
120+
def run_pg_dump() -> str:
121+
db_url = config.database.url
122+
db_url = db_url.replace("postgresql+psycopg2://", "postgresql://")
123+
result = subprocess.run(
124+
["pg_dump", "--schema-only", "--dbname", db_url],
125+
check=True,
126+
text=True,
127+
stdout=subprocess.PIPE,
128+
)
129+
return result.stdout
130+
131+
def get_updated_schema(schema_file: str, updater_file: str) -> str:
132+
drop_db()
133+
schema_sql = open(schema_file).read()
134+
# The schema sets the owner of every object explicitly. We actually want
135+
# these objects to be owned by whichever user CMS uses, so we skip the
136+
# OWNER TO commands and let the owners be defaulted to the current user.
137+
schema_sql = '\n'.join(
138+
line
139+
for line in schema_sql.splitlines()
140+
if not (line.startswith('ALTER ') and ' OWNER TO ' in line)
141+
)
142+
updater_sql = open(updater_file).read()
143+
# We need to do this in two separate connections, since the schema_sql sets
144+
# some connection properties which we don't want.
145+
for sql in [schema_sql, updater_sql]:
146+
conn = custom_psycopg2_connection()
147+
cursor = conn.cursor()
148+
cursor.execute(sql)
149+
conn.commit()
150+
conn.close()
151+
152+
return run_pg_dump()
153+
154+
def get_fresh_schema():
155+
drop_db()
156+
init_db()
157+
return run_pg_dump()
158+
159+
class TestSchemaDiff(unittest.TestCase):
160+
def test_schema_diff(self):
161+
dirname = os.path.dirname(__file__)
162+
schema_file = os.path.join(dirname, "schema_v1.5.sql")
163+
updater_file = os.path.join(dirname, "../../cmscontrib/updaters/update_from_1.5.sql")
164+
updated_schema = split_schema(get_updated_schema(schema_file, updater_file))
165+
fresh_schema = split_schema(get_fresh_schema())
166+
errors = compare_schemas(updated_schema, fresh_schema)
167+
self.longMessage = False
168+
self.assertTrue(errors == "", errors)

0 commit comments

Comments
 (0)