Skip to content

Commit f66ccb2

Browse files
committed
Manual rollback of committed data
1 parent c58a52a commit f66ccb2

File tree

2 files changed

+152
-2
lines changed

2 files changed

+152
-2
lines changed

pytest_invenio/database_tools.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import logging
2+
import sys
3+
4+
from sqlalchemy import MetaData, and_, func, select
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def store_database_values(engine, conn):
    """Snapshot the key values of every row in every reflected table.

    The schema is reflected from ``engine`` and every table is queried through
    ``conn``. A row is identified by the values of its primary key columns
    *and* its foreign key columns, so both kinds of columns are recorded.

    :param engine: SQLAlchemy engine used to reflect the database schema.
    :param conn: SQLAlchemy connection used to run the SELECT statements.
    :returns: ``dict[table_name, list[tuple]]`` where each tuple holds the
        values of the table's key columns, in table column order. Tables that
        have neither primary key nor foreign key columns are omitted.
    :raises RuntimeError: if the key values of a table cannot be queried or
        fetched; the original error is attached as the cause.
    """
    print("Storing database values", file=sys.stderr, flush=True)

    metadata = MetaData()
    metadata.reflect(engine)

    dump = {}
    for table_name, table in metadata.tables.items():
        # Columns that identify a row: primary key columns plus any column
        # that carries a foreign key.
        key_columns = [
            column
            for column in table.columns
            if column.primary_key or column.foreign_keys
        ]

        if not key_columns:
            # Nothing to identify rows by, so the table cannot be tracked.
            continue

        # Run the query inside the try block as well, so that a failing
        # SELECT is reported together with the name of the offending table.
        try:
            result = conn.execute(select(*key_columns))
            dump[table_name] = [tuple(row) for row in result.fetchall()]
        except Exception as ex:
            raise RuntimeError(f"Could not fetch rows from table {table_name}") from ex

    print("Storing database values done", file=sys.stderr, flush=True)

    return dump
42+
43+
44+
def purge_database_values(engine, conn, stored_values):
    """Delete every row that is not part of a previously stored snapshot.

    The schema is reflected from ``engine``. For each table, a DELETE is
    issued through ``conn`` that removes all rows whose key values (primary
    key and foreign key columns) do not match any tuple stored for that table.
    Tables missing from ``stored_values`` are emptied completely.

    Deletions are committed table by table. A table is retried in a later
    pass when its DELETE raises (for example because of a foreign key
    constraint) or when its row count afterwards differs from the number of
    distinct stored key tuples.

    :param engine: SQLAlchemy engine used to reflect the database schema.
    :param conn: SQLAlchemy connection used to run, commit and roll back the
        DELETE statements.
    :param stored_values: ``dict[table_name, list[tuple]]`` in the shape
        produced by ``store_database_values``.
    :raises RuntimeError: if a whole pass over the pending tables makes no
        progress; the last database error of that pass, if any, is attached
        as the cause.
    """
    print("Purging database values", file=sys.stderr, flush=True)
    metadata = MetaData()
    metadata.reflect(engine)

    # Entries are (table_name, table, where_condition, expected_row_count).
    # A where_condition of None means "delete every row of the table".
    to_be_deleted = []

    for table_name, table in metadata.tables.items():
        stored_rows = stored_values.get(table_name, [])

        # Columns that identify a row: primary key columns plus any column
        # that carries a foreign key.
        key_columns = [
            column
            for column in table.columns
            if column.primary_key or column.foreign_keys
        ]

        if not key_columns:
            logger.warning("Table %s has no primary key. Skipping.", table_name)
            continue

        # Distinct key tuples that must survive the purge.
        stored_keys = set(stored_rows)

        # Build the DELETE condition: a row is removed only if it differs
        # from every stored key tuple, i.e.
        #   NOT (k1 == v1 AND k2 == v2 ...) AND NOT (...) AND ...
        keep_out_conditions = [
            ~and_(*[col == val for col, val in zip(key_columns, stored_key)])
            for stored_key in stored_keys
        ]

        if keep_out_conditions:
            where_condition = and_(*keep_out_conditions)
        else:
            # Nothing was stored for this table: delete everything.
            where_condition = None

        to_be_deleted.append((table_name, table, where_condition, len(stored_keys)))

    # Try to delete rows with retry mechanism for foreign key constraints
    while to_be_deleted:
        print(
            "Purging iteration, tables to process:",
            [table_name for table_name, _, _, _ in to_be_deleted],
            file=sys.stderr,
            flush=True,
        )
        failed_deletions = []
        # Remember the most recent database error of this pass so that it can
        # be reported if the pass makes no progress at all.
        last_error = None

        for entry in to_be_deleted:
            table_name, table, where_condition, expected_count = entry
            try:
                delete_stmt = table.delete()
                if where_condition is not None:
                    delete_stmt = delete_stmt.where(where_condition)
                conn.execute(delete_stmt)
                conn.commit()
                existing_count = conn.execute(
                    select(func.count()).select_from(table)
                ).scalar()
            except Exception as ex:
                # Rollback on failure and retry in next iteration
                conn.rollback()
                last_error = ex
                logger.debug(
                    "Deleting from table %s failed, will retry.",
                    table_name,
                    exc_info=True,
                )
                failed_deletions.append(entry)
                continue

            if existing_count != expected_count:
                logger.warning("Not all rows deleted as expected, will try again.")
                failed_deletions.append(entry)

        if len(failed_deletions) == len(to_be_deleted):
            # No table was cleaned in this pass; retrying would loop forever.
            table_names = [table_name for table_name, _, _, _ in failed_deletions]
            raise RuntimeError(
                f"Could not delete any rows in this iteration due to foreign key cycles in tables: {table_names}"
            ) from last_error

        # Update the list with failed deletions for next iteration
        to_be_deleted = failed_deletions

    print("Purging database values done", file=sys.stderr, flush=True)

pytest_invenio/fixtures.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,11 @@ def db(database, db_session_options):
551551
"""
552552
from flask_sqlalchemy.session import Session as FlaskSQLAlchemySession
553553

554+
from pytest_invenio.database_tools import (
555+
purge_database_values,
556+
store_database_values,
557+
)
558+
554559
class PytestInvenioSession(FlaskSQLAlchemySession):
555560
def get_bind(self, mapper=None, clause=None, bind=None, **kwargs):
556561
if self.bind:
@@ -574,17 +579,31 @@ def rollback(self) -> None:
574579
)
575580
session = database._make_scoped_session(options=options)
576581

577-
session.begin_nested()
578-
579582
old_session = database.session
580583
database.session = session
581584

585+
# the session.rollback() does not always clean everything, if the test
586+
# used db.session.commit() and has not cleaned up after itself. We can not
587+
# use nested transactions because a lot of Invenio code would need to be updated
588+
# so that it is aware of the nested transaction concept. Instead, we store
589+
# the database values here and purge any new rows after the test.
590+
stored_values = store_database_values(database.engine, connection)
591+
582592
yield database
583593

594+
print("Rollback and original connection close", file=sys.stderr, flush=True)
584595
session.rollback()
585596
connection.close()
586597
database.session = old_session
587598

599+
print("Purging database changes", file=sys.stderr, flush=True)
600+
# use a brand new connection for the purge operation
601+
connection = database.engine.connect()
602+
connection.begin()
603+
purge_database_values(database.engine, connection, stored_values)
604+
connection.close()
605+
print("Purged and closed", file=sys.stderr, flush=True)
606+
588607

589608
@pytest.fixture(scope="function")
590609
def mailbox(base_app):

0 commit comments

Comments
 (0)