|
1 | 1 | import json |
2 | 2 | from configparser import ConfigParser |
3 | 3 | from pathlib import Path |
4 | | -from typing import Generator |
| 4 | +from typing import Any, Generator, Type, TypeVar |
5 | 5 |
|
6 | 6 | import ispyb |
7 | 7 | import pytest |
8 | 8 | from ispyb.sqlalchemy import BLSession, Person, Proposal, url |
9 | | -from sqlalchemy import create_engine |
| 9 | +from sqlalchemy import and_, create_engine, select |
| 10 | +from sqlalchemy.ext.declarative import DeclarativeMeta |
10 | 11 | from sqlalchemy.orm import Session as SQLAlchemySession |
11 | 12 | from sqlalchemy.orm import scoped_session, sessionmaker |
12 | 13 |
|
@@ -117,50 +118,86 @@ def ispyb_engine(mock_ispyb_credentials): |
117 | 118 | ispyb_engine.dispose() |
118 | 119 |
|
119 | 120 |
|
| 121 | +SQLAlchemyTable = TypeVar("SQLAlchemyTable", bound=DeclarativeMeta) |
| 122 | + |
| 123 | + |
| 124 | +def get_or_create_db_entry( |
| 125 | + session: SQLAlchemySession, |
| 126 | + table: Type[SQLAlchemyTable], |
| 127 | + lookup_kwargs: dict[str, Any] = {}, |
| 128 | + insert_kwargs: dict[str, Any] = {}, |
| 129 | +) -> SQLAlchemyTable: |
| 130 | + """ |
| 131 | + Helper function to facilitate looking up SQLAlchemy tables for |
| 132 | + matching entries. Returns the entry if it exists and creates it |
| 133 | + if it doesn't. |
| 134 | + """ |
| 135 | + |
| 136 | + # if lookup kwargs are provided, check if entry exists |
| 137 | + if lookup_kwargs: |
| 138 | + conditions = [ |
| 139 | + getattr(table, key) == value for key, value in lookup_kwargs.items() |
| 140 | + ] |
| 141 | + entry = session.execute(select(table).where(and_(*conditions))) |
| 142 | + if entry: |
| 143 | + return entry |
| 144 | + # If not present, create and return new entry |
| 145 | + # Use new kwargs if provided; otherwise, use lookup kwargs |
| 146 | + insert_kwargs = insert_kwargs or lookup_kwargs |
| 147 | + entry = table(**insert_kwargs) |
| 148 | + session.add(entry) |
| 149 | + session.commit() |
| 150 | + return entry |
| 151 | + |
| 152 | + |
120 | 153 | @pytest.fixture(scope="session") |
121 | 154 | def ispyb_session_factory(ispyb_engine): |
122 | | - return scoped_session(sessionmaker(bind=ispyb_engine)) |
| 155 | + factory = scoped_session(sessionmaker(bind=ispyb_engine)) |
| 156 | + ispyb_db = factory() |
| 157 | + |
| 158 | + # Populate the ISPyB table with some initial values |
| 159 | + # Return existing table entry if already present |
| 160 | + person_db_entry = get_or_create_db_entry( |
| 161 | + session=ispyb_db, |
| 162 | + table=Person, |
| 163 | + lookup_kwargs={ |
| 164 | + "givenName": ExampleVisit.given_name, |
| 165 | + "familyName": ExampleVisit.family_name, |
| 166 | + "login": ExampleVisit.login, |
| 167 | + }, |
| 168 | + ) |
| 169 | + proposal_db_entry = get_or_create_db_entry( |
| 170 | + session=ispyb_db, |
| 171 | + table=Proposal, |
| 172 | + lookup_kwargs={ |
| 173 | + "personId": person_db_entry.personId, |
| 174 | + "proposalCode": ExampleVisit.proposal_code, |
| 175 | + "proposalNumber": str(ExampleVisit.proposal_number), |
| 176 | + }, |
| 177 | + ) |
| 178 | + bl_session_db_entry = get_or_create_db_entry( |
| 179 | + session=ispyb_db, |
| 180 | + table=BLSession, |
| 181 | + lookup_kwargs={ |
| 182 | + "proposalId": proposal_db_entry.proposalId, |
| 183 | + "beamLineName": ExampleVisit.instrument_name, |
| 184 | + "visit_number": ExampleVisit.visit_number, |
| 185 | + }, |
| 186 | + ) |
| 187 | + ispyb_db.add(bl_session_db_entry) |
| 188 | + ispyb_db.commit() |
| 189 | + |
| 190 | + ispyb_db.close() |
| 191 | + return factory # Return its current state |
123 | 192 |
|
124 | 193 |
|
125 | 194 | @pytest.fixture |
126 | 195 | def ispyb_db(ispyb_session_factory) -> Generator[SQLAlchemySession, None, None]: |
127 | 196 | # Get a new session from the session factory |
128 | 197 | ispyb_db: SQLAlchemySession = ispyb_session_factory() |
129 | | - save_point = ispyb_db.begin_nested() # Checkpoint to roll back database to |
130 | | - |
131 | | - try: |
132 | | - # Populate the ISPyB table with some default values |
133 | | - person_db_entry = Person( |
134 | | - givenName=ExampleVisit.given_name, |
135 | | - familyName=ExampleVisit.family_name, |
136 | | - login=ExampleVisit.login, |
137 | | - ) |
138 | | - ispyb_db.add(person_db_entry) |
139 | | - ispyb_db.commit() |
140 | | - |
141 | | - proposal_db_entry = Proposal( |
142 | | - personId=person_db_entry.personId, |
143 | | - proposalCode=ExampleVisit.proposal_code, |
144 | | - proposalNumber=str(ExampleVisit.proposal_number), |
145 | | - ) |
146 | | - ispyb_db.add(proposal_db_entry) |
147 | | - ispyb_db.commit() |
148 | | - |
149 | | - bl_session_db_entry = BLSession( |
150 | | - proposalId=proposal_db_entry.proposalId, |
151 | | - beamLineName=ExampleVisit.instrument_name, |
152 | | - visit_number=ExampleVisit.visit_number, |
153 | | - ) |
154 | | - ispyb_db.add(bl_session_db_entry) |
155 | | - ispyb_db.commit() |
156 | | - |
157 | | - # Yield the Session and pass processing over to other function |
158 | | - yield ispyb_db |
159 | | - |
160 | | - # Tidying up |
161 | | - finally: |
162 | | - save_point.rollback() |
163 | | - ispyb_db.close() |
| 198 | + yield ispyb_db |
| 199 | + ispyb_db.rollback() |
| 200 | + ispyb_db.close() |
164 | 201 |
|
165 | 202 |
|
166 | 203 | """ |
|
0 commit comments