Skip to content

Commit 474799c

Browse files
committed
Manual rollback of committed data
1 parent c58a52a commit 474799c

File tree

2 files changed

+222
-16
lines changed

2 files changed

+222
-16
lines changed

pytest_invenio/database_tools.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# This file is part of pytest-invenio.
4+
# Copyright (C) 2025 CESNET i.l.e.
5+
#
6+
# pytest-invenio is free software; you can redistribute it and/or modify it
7+
# under the terms of the MIT License; see LICENSE file for more details.
8+
"""Database tools for test cleanup."""
9+
import logging
10+
11+
from sqlalchemy import MetaData, String, and_, func, select
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class InconsistentDatabaseError(Exception):
17+
"""Raised when the database is found to be inconsistent after test cleanup."""
18+
19+
pass
20+
21+
22+
def store_database_values(engine, conn):
23+
"""Introspect the session, get all the tables and store their primary key values.
24+
25+
The result is a dict[table_name, list[pk_tuple]]
26+
"""
27+
metadata = MetaData()
28+
metadata.reflect(engine)
29+
30+
dump = {}
31+
for table_name, table in metadata.tables.items():
32+
# Get primary key columns and foreign key columns
33+
pk_columns = [
34+
column
35+
for column in table.columns
36+
if column.primary_key or len(column.foreign_keys) > 0
37+
]
38+
39+
if not pk_columns:
40+
# Skip tables without primary keys
41+
continue
42+
43+
# Select only primary key columns, cast to string at database level
44+
pk_columns_as_string = [func.cast(col, String) for col in pk_columns]
45+
result = conn.execute(select(*pk_columns_as_string))
46+
try:
47+
dump[table_name] = [tuple(row) for row in result.fetchall()]
48+
except Exception as ex:
49+
raise RuntimeError(f"Could not fetch rows from table {table_name}") from ex
50+
51+
return dump
52+
53+
54+
def purge_database_values(engine, conn, stored_values):
55+
"""Delete rows that are not in the stored values."""
56+
metadata = MetaData()
57+
metadata.reflect(engine)
58+
59+
# Build a list of (table_name, delete_condition) tuples
60+
to_be_deleted = []
61+
62+
for table_name, table in metadata.tables.items():
63+
stored_rows = stored_values.get(table_name, [])
64+
65+
# Get primary key columns and foreign key columns
66+
pk_columns = [
67+
column
68+
for column in table.columns
69+
if column.primary_key or len(column.foreign_keys) > 0
70+
]
71+
72+
if not pk_columns:
73+
logger.warning(f"Table {table_name} has no primary key. Skipping.")
74+
continue
75+
76+
# Convert stored rows to a set of primary key tuples for fast lookup
77+
stored_pk_set = set(stored_rows)
78+
79+
# create a select statement that would include only rows that are not present
80+
# in the stored values. It will be not (pk1 == val1 and pk2 == val2 and ...) and not (...)
81+
row_matcher_conditions = []
82+
for stored_pk in stored_pk_set:
83+
# Cast columns to string at database level for comparison
84+
condition = and_(
85+
*(
86+
func.cast(pk_col, String) == pk_val
87+
for pk_col, pk_val in zip(pk_columns, stored_pk)
88+
)
89+
)
90+
# negate the condition to match rows that are not equal
91+
row_matcher_conditions.append(~condition)
92+
93+
if row_matcher_conditions:
94+
non_matching_condition = and_(*row_matcher_conditions)
95+
to_be_deleted.append(
96+
(table_name, table, non_matching_condition, len(stored_pk_set))
97+
)
98+
else:
99+
# delete everything
100+
to_be_deleted.append((table_name, table, None, len(stored_pk_set)))
101+
102+
# Try to delete rows with retry mechanism for foreign key constraints
103+
while to_be_deleted:
104+
failed_deletions = []
105+
106+
for table_name, table, where_condition, expected_count in to_be_deleted:
107+
# Execute deletion in a transaction so that we can rollback on failure
108+
with conn.begin():
109+
try:
110+
delete_stmt = table.delete()
111+
if where_condition is not None:
112+
delete_stmt = delete_stmt.where(where_condition)
113+
114+
conn.execute(delete_stmt)
115+
116+
existing_count = conn.execute(
117+
select(func.count()).select_from(table)
118+
).scalar()
119+
conn.commit()
120+
if expected_count > existing_count:
121+
122+
where_str = where_condition.compile(
123+
dialect=conn.dialect,
124+
compile_kwargs={"literal_binds": True},
125+
)
126+
127+
raise InconsistentDatabaseError(
128+
f"Expected to have {expected_count} rows in table {table_name} "
129+
f"in test cleanup but only {existing_count} remain after the test. "
130+
f"The test must have removed rows from module-level fixtures, "
131+
f"thus making the database inconsistent for subsequent tests."
132+
f"The conditions for rows: {where_str}"
133+
)
134+
logger.debug(
135+
"Deleted rows from table: %s, expected: %s, remaining: %s",
136+
table_name,
137+
expected_count,
138+
existing_count,
139+
)
140+
if existing_count != expected_count:
141+
logger.warning(
142+
"Not all rows deleted as expected, will try again."
143+
)
144+
failed_deletions.append(
145+
(table_name, table, where_condition, expected_count)
146+
)
147+
except InconsistentDatabaseError:
148+
# Reraise as the database is in an inconsistent state which can not be fixed
149+
raise
150+
except Exception:
151+
# Rollback on failure and retry in next iteration
152+
conn.rollback()
153+
failed_deletions.append(
154+
(table_name, table, where_condition, expected_count)
155+
)
156+
157+
if len(failed_deletions) == len(to_be_deleted):
158+
table_names = [table_name for table_name, _, _, _ in failed_deletions]
159+
raise RuntimeError(
160+
f"Could not delete the remaining rows due to foreign key cycles in tables: {table_names}"
161+
)
162+
else:
163+
# Update the list with failed deletions for next iteration
164+
to_be_deleted = failed_deletions

pytest_invenio/fixtures.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Copyright (C) 2017-2025 CERN.
55
# Copyright (C) 2018 Esteban J. G. Garbancho.
66
# Copyright (C) 2024-2025 Graz University of Technology.
7+
# Copyright (C) 2025 CESNET i.l.e.
78
#
89
# pytest-invenio is free software; you can redistribute it and/or modify it
910
# under the terms of the MIT License; see LICENSE file for more details.
@@ -516,11 +517,29 @@ def _compile_unlogged(element, compiler, **kwargs):
516517

517518
db_.create_all()
518519

520+
# The test should use 1 connection (the same as the one here). If there is no
521+
# open connection at this point, assume 1 connection is used.
522+
connections_at_the_beginning = db_.engine.pool.checkedout() or 1
523+
519524
yield db_
520525

526+
# get the number of checked out connections at the end of the tests
527+
# if it is higher than at the beginning, some connections were not properly closed
528+
# i.e. we have potential connection leaks either in tests or in tested code
529+
connections_at_the_end = db_.engine.pool.checkedout() or 1
530+
if connections_at_the_beginning < connections_at_the_end:
531+
raise RuntimeError(
532+
"Database connections were not properly closed. "
533+
f"Connections at the beginning: {connections_at_the_beginning}, "
534+
f"at the end: {connections_at_the_end}."
535+
)
536+
521537
db_.session.remove()
522538
db_.drop_all()
523539

540+
# dispose of the engine to close the underlying connection pool
541+
db_.engine.dispose()
542+
524543

525544
@pytest.fixture(scope="function")
526545
def db_session_options():
@@ -551,6 +570,11 @@ def db(database, db_session_options):
551570
"""
552571
from flask_sqlalchemy.session import Session as FlaskSQLAlchemySession
553572

573+
from pytest_invenio.database_tools import (
574+
purge_database_values,
575+
store_database_values,
576+
)
577+
554578
class PytestInvenioSession(FlaskSQLAlchemySession):
555579
def get_bind(self, mapper=None, clause=None, bind=None, **kwargs):
556580
if self.bind:
@@ -563,27 +587,45 @@ def rollback(self) -> None:
563587
else:
564588
self._transaction.rollback(_to_root=False)
565589

566-
connection = database.engine.connect()
567-
connection.begin()
590+
# the session.rollback() does not always clean everything, if the test
591+
# used db.session.commit() and has not cleaned up after itself. We can not
592+
# use nested transactions because a lot of Invenio code would need to be updated
593+
# so that it is aware of the nested transaction concept. Instead, we store
594+
# the database values here and purge any new rows after the test.
595+
#
596+
# We do it in explicit connection to avoid issues in tests that drop all tables
597+
# (causes deadlock in alembic tests of invenio-pages on github actions, not
598+
# reproducible locally).
599+
with database.engine.connect() as connection:
600+
with connection.begin():
601+
stored_values = store_database_values(database.engine, connection)
602+
603+
with database.engine.connect() as connection:
604+
with connection.begin():
605+
606+
options = dict(
607+
bind=connection,
608+
binds={},
609+
**db_session_options,
610+
class_=PytestInvenioSession,
611+
)
568612

569-
options = dict(
570-
bind=connection,
571-
binds={},
572-
**db_session_options,
573-
class_=PytestInvenioSession,
574-
)
575-
session = database._make_scoped_session(options=options)
613+
session = database._make_scoped_session(options=options)
614+
615+
old_session = database.session
616+
database.session = session
576617

577-
session.begin_nested()
618+
yield database
578619

579-
old_session = database.session
580-
database.session = session
620+
session.rollback()
621+
database.session = old_session
581622

582-
yield database
623+
# use a brand new connection for the purge operation
624+
with database.engine.connect() as connection:
625+
purge_database_values(database.engine, connection, stored_values)
583626

584-
session.rollback()
585-
connection.close()
586-
database.session = old_session
627+
# expire all as there might be some stale data in the original database session
628+
database.session.expire_all()
587629

588630

589631
@pytest.fixture(scope="function")

0 commit comments

Comments
 (0)