|
6 | 6 | import ispyb |
7 | 7 | import pytest |
8 | 8 | from ispyb.sqlalchemy import BLSession, Person, Proposal, url |
9 | | -from sqlalchemy import and_, create_engine, select |
| 9 | +from sqlalchemy import Engine, RootTransaction, and_, create_engine, event, select |
10 | 10 | from sqlalchemy.ext.declarative import DeclarativeMeta |
11 | 11 | from sqlalchemy.orm import Session as SQLAlchemySession |
12 | | -from sqlalchemy.orm import scoped_session, sessionmaker |
| 12 | +from sqlalchemy.orm import sessionmaker |
13 | 13 |
|
14 | 14 | from murfey.util.db import Session as MurfeySession |
15 | 15 | from murfey.util.db import clear, setup |
@@ -118,6 +118,11 @@ def ispyb_engine(mock_ispyb_credentials): |
118 | 118 | ispyb_engine.dispose() |
119 | 119 |
|
120 | 120 |
|
| 121 | +@pytest.fixture(scope="session") |
| 122 | +def ispyb_db_session_factory(ispyb_engine): |
| 123 | + return sessionmaker(bind=ispyb_engine, expire_on_commit=False) |
| 124 | + |
| 125 | + |
121 | 126 | SQLAlchemyTable = TypeVar("SQLAlchemyTable", bound=DeclarativeMeta) |
122 | 127 |
|
123 | 128 |
|
@@ -154,12 +159,11 @@ def get_or_create_db_entry( |
154 | 159 |
|
155 | 160 |
|
156 | 161 | @pytest.fixture(scope="session") |
157 | | -def ispyb_db_session_factory(ispyb_engine): |
158 | | - factory = scoped_session(sessionmaker(bind=ispyb_engine)) |
| 162 | +def seed_ispyb_db(ispyb_db_session_factory): |
159 | 163 |
|
160 | 164 | # Populate the ISPyB table with some initial values |
161 | 165 | # Return existing table entry if already present |
162 | | - ispyb_db_session = factory() |
| 166 | + ispyb_db_session: SQLAlchemySession = ispyb_db_session_factory() |
163 | 167 | person_db_entry = get_or_create_db_entry( |
164 | 168 | session=ispyb_db_session, |
165 | 169 | table=Person, |
@@ -188,23 +192,48 @@ def ispyb_db_session_factory(ispyb_engine): |
188 | 192 | }, |
189 | 193 | ) |
190 | 194 | ispyb_db_session.close() |
191 | | - return factory # Return its current state |
| 195 | + |
| 196 | + |
| 197 | +def restart_savepoint(session: SQLAlchemySession, transaction: RootTransaction): |
| 198 | + """ |
| 199 | + Re-establish a SAVEPOINT after a nested transaction is committed or rolled back. |
| 200 | + This helps to maintain isolation across different test cases. |
| 201 | + """ |
| 202 | + if transaction.nested and not transaction._parent.nested: |
| 203 | + session.begin_nested() |
| 204 | + |
| 205 | + |
| 206 | +def attach_event_listener(session: SQLAlchemySession): |
| 207 | + """ |
| 208 | + Attach the restart_savepoint function as an event listener for after_transaction_end |
| 209 | + """ |
| 210 | + event.listen(session, "after_transaction_end", restart_savepoint) |
192 | 211 |
|
193 | 212 |
|
194 | 213 | @pytest.fixture |
195 | 214 | def ispyb_db_session( |
196 | 215 | ispyb_db_session_factory, |
| 216 | + ispyb_engine: Engine, |
| 217 | + seed_ispyb_db, |
197 | 218 | ) -> Generator[SQLAlchemySession, None, None]: |
| 219 | + """ |
| 220 | + Returns a test-safe session that wraps each test in a rollback-safe SAVEPOINT. |
| 221 | + """ |
| 222 | + connection = ispyb_engine.connect() |
| 223 | + transaction = connection.begin() # Outer transaction |
198 | 224 |
|
199 | | - # Get a new session from the session factory |
200 | | - ispyb_db_session: SQLAlchemySession = ispyb_db_session_factory() |
| 225 | + session: SQLAlchemySession = ispyb_db_session_factory(bind=connection) |
| 226 | + session.begin_nested() # Save point for test |
201 | 227 |
|
202 | | - # Let other function run |
203 | | - yield ispyb_db_session |
| 228 | + # Attach the listener to the session for this connection |
| 229 | + attach_event_listener(session) |
204 | 230 |
|
205 | | - # Tidy up after function is complete |
206 | | - ispyb_db_session.rollback() |
207 | | - ispyb_db_session.close() |
| 231 | + try: |
| 232 | + yield session |
| 233 | + finally: |
| 234 | + session.close() |
| 235 | + transaction.rollback() |
| 236 | + connection.close() |
208 | 237 |
|
209 | 238 |
|
210 | 239 | """ |
|
0 commit comments