|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | import json |
2 | 4 | import os |
3 | 5 | from configparser import ConfigParser |
|
11 | 13 | from sqlalchemy.ext.declarative import DeclarativeMeta |
12 | 14 | from sqlalchemy.orm import Session as SQLAlchemySession |
13 | 15 | from sqlalchemy.orm import sessionmaker |
| 16 | +from sqlmodel import Session as SQLModelSession |
14 | 17 | from sqlmodel import SQLModel |
15 | 18 |
|
16 | 19 | from murfey.util.db import Session as MurfeySession |
@@ -108,7 +111,7 @@ class ExampleVisit: |
108 | 111 |
|
109 | 112 |
|
110 | 113 | def get_or_create_db_entry( |
111 | | - session: SQLAlchemySession, |
| 114 | + session: SQLAlchemySession | SQLModelSession, |
112 | 115 | table: Type[SQLAlchemyTable], |
113 | 116 | lookup_kwargs: dict[str, Any] = {}, |
114 | 117 | insert_kwargs: dict[str, Any] = {}, |
@@ -139,7 +142,9 @@ def get_or_create_db_entry( |
139 | 142 | return entry |
140 | 143 |
|
141 | 144 |
|
142 | | -def restart_savepoint(session: SQLAlchemySession, transaction: RootTransaction): |
| 145 | +def restart_savepoint( |
| 146 | + session: SQLAlchemySession | SQLModelSession, transaction: RootTransaction |
| 147 | +): |
143 | 148 | """ |
144 | 149 | Re-establish a SAVEPOINT after a nested transaction is committed or rolled back. |
145 | 150 | This helps to maintain isolation across different test cases. |
@@ -260,37 +265,39 @@ def murfey_db_engine(): |
260 | 265 |
|
261 | 266 | @pytest.fixture(scope="session") |
262 | 267 | def murfey_db_session_factory(murfey_db_engine): |
263 | | - return sessionmaker(bind=murfey_db_engine, expire_on_commit=False) |
| 268 | + return sessionmaker( |
| 269 | + bind=murfey_db_engine, expire_on_commit=False, class_=SQLModelSession |
| 270 | + ) |
264 | 271 |
|
265 | 272 |
|
266 | 273 | @pytest.fixture(scope="session") |
267 | 274 | def seed_murfey_db(murfey_db_session_factory): |
268 | 275 | # Populate Murfey database with initial values |
269 | | - murfey_session: SQLAlchemySession = murfey_db_session_factory() |
| 276 | + session: SQLModelSession = murfey_db_session_factory() |
270 | 277 | _ = get_or_create_db_entry( |
271 | | - session=murfey_session, |
| 278 | + session=session, |
272 | 279 | table=MurfeySession, |
273 | 280 | lookup_kwargs={ |
274 | 281 | "id": ExampleVisit.murfey_session_id, |
275 | 282 | "name": f"{ExampleVisit.proposal_code}{ExampleVisit.proposal_number}-{ExampleVisit.visit_number}", |
276 | 283 | }, |
277 | 284 | ) |
278 | | - murfey_session.close() |
| 285 | + session.close() |
279 | 286 |
|
280 | 287 |
|
281 | 288 | @pytest.fixture |
282 | 289 | def murfey_db_session( |
283 | 290 | murfey_db_session_factory, |
284 | 291 | murfey_db_engine: Engine, |
285 | 292 | seed_murfey_db, |
286 | | -) -> Generator[SQLAlchemySession, None, None]: |
| 293 | +) -> Generator[SQLModelSession, None, None]: |
287 | 294 | """ |
288 | 295 | Returns a test-safe session that wraps each test in a rollback-safe save point |
289 | 296 | """ |
290 | 297 | connection = murfey_db_engine.connect() |
291 | 298 | transaction = connection.begin() |
292 | 299 |
|
293 | | - session: SQLAlchemySession = murfey_db_session_factory(bind=connection) |
| 300 | + session: SQLModelSession = murfey_db_session_factory(bind=connection) |
294 | 301 | session.begin_nested() # Save point for test |
295 | 302 |
|
296 | 303 | # Trigger the restart_savepoint function after the end of the transaction |
|
0 commit comments