Skip to content

Commit c4e3eef

Browse files
committed
Manual rollback of committed data
1 parent c58a52a commit c4e3eef

File tree

2 files changed

+158
-2
lines changed

2 files changed

+158
-2
lines changed

pytest_invenio/database_tools.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import logging
2+
import sys
3+
4+
from sqlalchemy import MetaData, and_, func, select
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def store_database_values(engine, conn):
    """Snapshot the key values of every row currently stored in the database.

    The schema is reflected from ``engine``. For each reflected table, the
    columns that are part of the primary key or that carry at least one
    foreign key are selected through ``conn``, and every row is stored as a
    tuple of those column values (in table column order).

    Tables that have neither primary key nor foreign key columns are left
    out of the result.

    :param engine: engine used to reflect the database schema.
    :param conn: connection used to run the SELECT statements.
    :returns: a dict mapping table name to a list of key-value tuples.
    :raises RuntimeError: if the rows of a table could not be fetched.
    """
    print("Storing database values", file=sys.stderr, flush=True)

    reflected = MetaData()
    reflected.reflect(engine)

    snapshot = {}
    for name, tbl in reflected.tables.items():
        # Identifying columns: primary key members plus any column that
        # references another table.
        key_columns = []
        for col in tbl.columns:
            if col.primary_key or col.foreign_keys:
                key_columns.append(col)

        if len(key_columns) == 0:
            # Nothing to identify rows by, so the table is not tracked.
            continue

        # Only the identifying columns are read, not the whole row.
        cursor = conn.execute(select(*key_columns))
        try:
            fetched = cursor.fetchall()
            snapshot[name] = list(map(tuple, fetched))
        except Exception as ex:
            raise RuntimeError(f"Could not fetch rows from table {name}") from ex

    print("Storing database values done", file=sys.stderr, flush=True)

    return snapshot
42+
43+
44+
def purge_database_values(engine, conn, stored_values):
    """Delete rows that are not in the stored values.

    The schema is reflected from ``engine``. For every reflected table that
    has primary key or foreign key columns, a DELETE statement is built that
    matches all rows whose key tuple is not in ``stored_values`` (tables
    missing from ``stored_values`` are emptied completely). The statements
    are then executed through ``conn`` and committed one by one.

    A delete can fail, for example when another table still references the
    rows. Failed tables are retried in further passes until every table is
    purged or a whole pass makes no progress.

    :param engine: engine used to reflect the database schema.
    :param conn: connection used to run and commit the DELETE statements.
    :param stored_values: dict mapping table name to a list of key-value
        tuples, in the format returned by ``store_database_values``.
    :raises RuntimeError: if a whole pass could not purge a single table.
        The last database error seen in that pass, if any, is attached as
        the cause.
    """
    print("Purging database values", file=sys.stderr, flush=True)
    metadata = MetaData()
    metadata.reflect(engine)

    # Build a list of (table_name, table, delete_condition, expected_count)
    to_be_deleted = []

    for table_name, table in metadata.tables.items():
        stored_rows = stored_values.get(table_name, [])

        # Get primary key columns and foreign key columns. This must select
        # the same columns, in the same order, as the snapshot did, because
        # the stored tuples are matched to columns by position.
        pk_columns = [
            column
            for column in table.columns
            if column.primary_key or len(column.foreign_keys) > 0
        ]

        if not pk_columns:
            logger.warning(f"Table {table_name} has no primary key. Skipping.")
            continue

        # De-duplicate the stored tuples so each one yields one condition.
        stored_pk_set = set(stored_rows)

        # One negated matcher per stored row:
        #   not (pk1 == val1 and pk2 == val2 and ...)
        # The clauses are unpacked explicitly instead of handing and_() a
        # bare generator.
        row_matcher_conditions = [
            ~and_(
                *(pk_col == pk_val for pk_col, pk_val in zip(pk_columns, stored_pk))
            )
            for stored_pk in stored_pk_set
        ]

        # No stored rows means no condition: delete everything.
        where_condition = (
            and_(*row_matcher_conditions) if row_matcher_conditions else None
        )

        # The table should end up with as many rows as were stored. Count
        # the stored rows, not the distinct tuples: a table without a
        # primary key can hold duplicate key tuples, and the de-duplicated
        # count would then never match the real row count.
        to_be_deleted.append((table_name, table, where_condition, len(stored_rows)))

    # Try to delete rows with retry mechanism for foreign key constraints
    while to_be_deleted:
        failed_deletions = []
        # Remember the last database error of this pass so it can be
        # attached to the RuntimeError below instead of being lost.
        last_error = None

        for entry in to_be_deleted:
            table_name, table, where_condition, expected_count = entry
            try:
                delete_stmt = table.delete()
                if where_condition is not None:
                    delete_stmt = delete_stmt.where(where_condition)
                conn.execute(delete_stmt)
                conn.commit()
                existing_count = conn.execute(
                    select(func.count()).select_from(table)
                ).scalar()
            except Exception as ex:
                # Rollback on failure and retry in next iteration
                conn.rollback()
                last_error = ex
                failed_deletions.append(entry)
                continue

            if existing_count != expected_count:
                logger.warning("Not all rows deleted as expected, will try again.")
                failed_deletions.append(entry)

        if len(failed_deletions) == len(to_be_deleted):
            # No table was purged in this pass, so retrying cannot help.
            table_names = [table_name for table_name, _, _, _ in failed_deletions]
            raise RuntimeError(
                f"Could not delete any rows in this iteration due to foreign key cycles in tables: {table_names}"
            ) from last_error

        # Update the list with failed deletions for next iteration
        to_be_deleted = failed_deletions

    print("Purging database values done", file=sys.stderr, flush=True)

pytest_invenio/fixtures.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,11 @@ def db(database, db_session_options):
551551
"""
552552
from flask_sqlalchemy.session import Session as FlaskSQLAlchemySession
553553

554+
from pytest_invenio.database_tools import (
555+
purge_database_values,
556+
store_database_values,
557+
)
558+
554559
class PytestInvenioSession(FlaskSQLAlchemySession):
555560
def get_bind(self, mapper=None, clause=None, bind=None, **kwargs):
556561
if self.bind:
@@ -563,6 +568,22 @@ def rollback(self) -> None:
563568
else:
564569
self._transaction.rollback(_to_root=False)
565570

571+
# the session.rollback() does not always clean everything, if the test
572+
# used db.session.commit() and has not cleaned up after itself. We cannot
573+
# use nested transactions because a lot of Invenio code would need to be updated
574+
# so that it is aware of the nested transaction concept. Instead, we store
575+
# the database values here and purge any new rows after the test.
576+
#
577+
# We do it in explicit connection to avoid issues in tests that drop all tables
578+
# (causes deadlock in alembic tests of invenio-pages on github actions, not
579+
# reproducible locally).
580+
print("MS PATCH Storing database values", file=sys.stderr, flush=True)
581+
connection = database.engine.connect()
582+
connection.begin()
583+
stored_values = store_database_values(database.engine, connection)
584+
connection.close()
585+
print("MS PATCH Stored", file=sys.stderr, flush=True)
586+
566587
connection = database.engine.connect()
567588
connection.begin()
568589

@@ -572,19 +593,29 @@ def rollback(self) -> None:
572593
**db_session_options,
573594
class_=PytestInvenioSession,
574595
)
575-
session = database._make_scoped_session(options=options)
576596

577-
session.begin_nested()
597+
session = database._make_scoped_session(options=options)
578598

579599
old_session = database.session
580600
database.session = session
581601

582602
yield database
583603

604+
print(
605+
"MS PATCH Rollback and original connection close", file=sys.stderr, flush=True
606+
)
584607
session.rollback()
585608
connection.close()
586609
database.session = old_session
587610

611+
print("MS PATCH Purging database changes", file=sys.stderr, flush=True)
612+
# use a brand new connection for the purge operation
613+
connection = database.engine.connect()
614+
connection.begin()
615+
purge_database_values(database.engine, connection, stored_values)
616+
connection.close()
617+
print("MS PATCH Purged and closed", file=sys.stderr, flush=True)
618+
588619

589620
@pytest.fixture(scope="function")
590621
def mailbox(base_app):

0 commit comments

Comments
 (0)