Skip to content

Commit a61a45b

Browse files
committed
Manual rollback of committed data
1 parent c58a52a commit a61a45b

File tree

2 files changed

+210
-16
lines changed

2 files changed

+210
-16
lines changed

pytest_invenio/database_tools.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import logging
2+
3+
from sqlalchemy import MetaData, String, and_, func, select
4+
5+
logger = logging.getLogger(__name__)
6+
7+
8+
class InconsistentDatabaseError(Exception):
    """Raised when cleanup finds fewer rows in a table than were stored before the test."""
10+
11+
12+
def store_database_values(engine, conn):
    """Snapshot the key values of every row in every table of the database.

    The schema is reflected from ``engine``. For each reflected table the
    primary key columns and the foreign key columns are selected through
    ``conn`` and cast to strings at the database level.

    :param engine: SQLAlchemy engine used to reflect the table metadata.
    :param conn: open connection on which the SELECT statements are executed.
    :returns: ``dict[table_name, list[tuple]]`` where each tuple holds the
        string-cast primary/foreign key values of one row. Tables that have
        neither primary key nor foreign key columns are omitted.
    :raises RuntimeError: if the rows of a table can not be selected or
        fetched; the original error is attached as the cause.
    """
    metadata = MetaData()
    metadata.reflect(engine)

    dump = {}
    for table_name, table in metadata.tables.items():
        # Primary key columns plus any column that takes part in a foreign key.
        key_columns = [
            column
            for column in table.columns
            if column.primary_key or column.foreign_keys
        ]

        if not key_columns:
            # Nothing to identify rows by, so the table is not snapshotted.
            continue

        # Cast to string at database level so values of any column type can be
        # compared uniformly later on.
        key_columns_as_string = [func.cast(col, String) for col in key_columns]
        try:
            # Executing the query is the step most likely to fail, so it is
            # wrapped together with the fetch to report the table name.
            result = conn.execute(select(*key_columns_as_string))
            dump[table_name] = [tuple(row) for row in result.fetchall()]
        except Exception as ex:
            raise RuntimeError(f"Could not fetch rows from table {table_name}") from ex

    return dump
43+
44+
45+
def purge_database_values(engine, conn, stored_values):
    """Delete rows that are not in the stored values.

    The schema is reflected from ``engine``. For every table a DELETE is built
    that removes all rows whose string-cast primary/foreign key values do not
    match any tuple stored for that table. Each DELETE runs in its own
    transaction on ``conn``; a table whose deletion fails (for example because
    of a foreign key constraint) is retried in the next pass, after the other
    tables have been processed.

    :param engine: SQLAlchemy engine used to reflect the table metadata.
    :param conn: connection on which the DELETE and COUNT statements run.
    :param stored_values: ``dict[table_name, list[tuple]]`` snapshot in the
        format produced by ``store_database_values``. Tables missing from the
        dict are emptied completely.
    :raises InconsistentDatabaseError: if, after the deletion, a table holds
        fewer rows than were stored for it.
    :raises RuntimeError: if a whole pass over the remaining tables makes no
        progress; the last underlying error, if any, is attached as the cause.
    """
    metadata = MetaData()
    metadata.reflect(engine)

    # Work items: (table_name, table, where_condition, expected_count).
    # ``where_condition`` is None when every row of the table has to go.
    to_be_deleted = []

    for table_name, table in metadata.tables.items():
        # Primary key columns plus any column that takes part in a foreign key.
        key_columns = [
            column
            for column in table.columns
            if column.primary_key or column.foreign_keys
        ]

        if not key_columns:
            logger.warning("Table %s has no primary key. Skipping.", table_name)
            continue

        # De-duplicate the stored key tuples.
        stored_pk_set = set(stored_values.get(table_name, []))

        # One negated matcher per stored row:
        #   not (pk1 == val1 and pk2 == val2 and ...)
        # Columns are cast to string at database level for the comparison.
        row_matcher_conditions = [
            ~and_(
                *(
                    func.cast(pk_col, String) == pk_val
                    for pk_col, pk_val in zip(key_columns, stored_pk)
                )
            )
            for stored_pk in stored_pk_set
        ]

        # A row is deleted only if it matches none of the stored rows. With
        # no stored rows there is no condition and everything is deleted.
        where_condition = (
            and_(*row_matcher_conditions) if row_matcher_conditions else None
        )
        to_be_deleted.append(
            (table_name, table, where_condition, len(stored_pk_set))
        )

    # Try to delete rows with retry mechanism for foreign key constraints
    while to_be_deleted:
        failed_deletions = []
        # Last error seen in this pass, kept so the final failure can name it.
        last_error = None

        for item in to_be_deleted:
            table_name, table, where_condition, expected_count = item
            # Execute deletion in a transaction so that we can rollback on failure
            with conn.begin():
                try:
                    delete_stmt = table.delete()
                    if where_condition is not None:
                        delete_stmt = delete_stmt.where(where_condition)

                    conn.execute(delete_stmt)

                    existing_count = conn.execute(
                        select(func.count()).select_from(table)
                    ).scalar()
                    conn.commit()
                    if expected_count > existing_count:
                        # Rows stored before the test are gone; render the
                        # condition with inline values for the error message.
                        where_str = where_condition.compile(
                            dialect=conn.dialect,
                            compile_kwargs={"literal_binds": True},
                        )

                        raise InconsistentDatabaseError(
                            f"Expected to have {expected_count} rows in table {table_name} "
                            f"in test cleanup but only {existing_count} remain after the test. "
                            f"The test must have removed rows from module-level fixtures, "
                            f"thus making the database inconsistent for subsequent tests. "
                            f"The conditions for rows: {where_str}"
                        )
                    logger.debug(
                        "Deleted rows from table: %s, expected: %s, remaining: %s",
                        table_name,
                        expected_count,
                        existing_count,
                    )
                    if existing_count != expected_count:
                        logger.warning(
                            "Not all rows deleted as expected, will try again."
                        )
                        failed_deletions.append(item)
                except InconsistentDatabaseError:
                    # Reraise as the database is in an inconsistent state which can not be fixed
                    raise
                except Exception as ex:
                    # Rollback on failure and retry in next iteration
                    conn.rollback()
                    last_error = ex
                    logger.debug(
                        "Deleting rows from table %s failed, will retry: %s",
                        table_name,
                        ex,
                    )
                    failed_deletions.append(item)

        if len(failed_deletions) == len(to_be_deleted):
            # No table made progress in this pass, so retrying can not help.
            table_names = [table_name for table_name, _, _, _ in failed_deletions]
            raise RuntimeError(
                f"Could not delete the remaining rows due to foreign key cycles in tables: {table_names}"
            ) from last_error

        # Update the list with failed deletions for next iteration
        to_be_deleted = failed_deletions

pytest_invenio/fixtures.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -516,8 +516,23 @@ def _compile_unlogged(element, compiler, **kwargs):
516516

517517
db_.create_all()
518518

519+
# The test should use 1 connection (the same as the one here). If there is no
520+
# open connection at this point, assume 1 connection is used.
521+
connections_at_the_beginning = db_.engine.pool.checkedout() or 1
522+
519523
yield db_
520524

525+
# get the number of checked out connections at the end of the tests
526+
# if it is higher than at the beginning, some connections were not properly closed
527+
# i.e. we have potential connection leaks either in tests or in tested code
528+
connections_at_the_end = db_.engine.pool.checkedout() or 1
529+
if connections_at_the_beginning < connections_at_the_end:
530+
raise RuntimeError(
531+
"Database connections were not properly closed. "
532+
f"Connections at the beginning: {connections_at_the_beginning}, "
533+
f"at the end: {connections_at_the_end}."
534+
)
535+
521536
db_.session.remove()
522537
db_.drop_all()
523538

@@ -551,6 +566,11 @@ def db(database, db_session_options):
551566
"""
552567
from flask_sqlalchemy.session import Session as FlaskSQLAlchemySession
553568

569+
from pytest_invenio.database_tools import (
570+
purge_database_values,
571+
store_database_values,
572+
)
573+
554574
class PytestInvenioSession(FlaskSQLAlchemySession):
555575
def get_bind(self, mapper=None, clause=None, bind=None, **kwargs):
556576
if self.bind:
@@ -563,27 +583,45 @@ def rollback(self) -> None:
563583
else:
564584
self._transaction.rollback(_to_root=False)
565585

566-
connection = database.engine.connect()
567-
connection.begin()
586+
# the session.rollback() does not always clean everything, if the test
587+
# used db.session.commit() and has not cleaned up after itself. We can not
588+
# use nested transactions because a lot of Invenio code would need to be updated
589+
# so that it is aware of the nested transaction concept. Instead, we store
590+
# the database values here and purge any new rows after the test.
591+
#
592+
# We do it in explicit connection to avoid issues in tests that drop all tables
593+
# (causes deadlock in alembic tests of invenio-pages on github actions, not
594+
# reproducible locally).
595+
with database.engine.connect() as connection:
596+
with connection.begin():
597+
stored_values = store_database_values(database.engine, connection)
598+
599+
with database.engine.connect() as connection:
600+
with connection.begin():
601+
602+
options = dict(
603+
bind=connection,
604+
binds={},
605+
**db_session_options,
606+
class_=PytestInvenioSession,
607+
)
568608

569-
options = dict(
570-
bind=connection,
571-
binds={},
572-
**db_session_options,
573-
class_=PytestInvenioSession,
574-
)
575-
session = database._make_scoped_session(options=options)
609+
session = database._make_scoped_session(options=options)
610+
611+
old_session = database.session
612+
database.session = session
576613

577-
session.begin_nested()
614+
yield database
578615

579-
old_session = database.session
580-
database.session = session
616+
session.rollback()
617+
database.session = old_session
581618

582-
yield database
619+
# use a brand new connection for the purge operation
620+
with database.engine.connect() as connection:
621+
purge_database_values(database.engine, connection, stored_values)
583622

584-
session.rollback()
585-
connection.close()
586-
database.session = old_session
623+
# expire all as there might be some stale data in the original database session
624+
database.session.expire_all()
587625

588626

589627
@pytest.fixture(scope="function")

0 commit comments

Comments
 (0)