Skip to content

Commit 06db6fa

Browse files
committed
Moved the insertion of starting ISPyB values to the session factory fixture, and made the insertions idempotent
1 parent 3c656de commit 06db6fa

File tree

1 file changed

+75
-38
lines changed

1 file changed

+75
-38
lines changed

tests/conftest.py

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import json
22
from configparser import ConfigParser
33
from pathlib import Path
4-
from typing import Generator
4+
from typing import Any, Generator, Type, TypeVar
55

66
import ispyb
77
import pytest
88
from ispyb.sqlalchemy import BLSession, Person, Proposal, url
9-
from sqlalchemy import create_engine
9+
from sqlalchemy import and_, create_engine, select
10+
from sqlalchemy.ext.declarative import DeclarativeMeta
1011
from sqlalchemy.orm import Session as SQLAlchemySession
1112
from sqlalchemy.orm import scoped_session, sessionmaker
1213

@@ -117,50 +118,86 @@ def ispyb_engine(mock_ispyb_credentials):
117118
ispyb_engine.dispose()
118119

119120

121+
SQLAlchemyTable = TypeVar("SQLAlchemyTable", bound=DeclarativeMeta)
122+
123+
124+
def get_or_create_db_entry(
125+
session: SQLAlchemySession,
126+
table: Type[SQLAlchemyTable],
127+
lookup_kwargs: dict[str, Any] = {},
128+
insert_kwargs: dict[str, Any] = {},
129+
) -> SQLAlchemyTable:
130+
"""
131+
Helper function to facilitate looking up SQLAlchemy tables for
132+
matching entries. Returns the entry if it exists and creates it
133+
if it doesn't.
134+
"""
135+
136+
# if lookup kwargs are provided, check if entry exists
137+
if lookup_kwargs:
138+
conditions = [
139+
getattr(table, key) == value for key, value in lookup_kwargs.items()
140+
]
141+
entry = session.execute(select(table).where(and_(*conditions)))
142+
if entry:
143+
return entry
144+
# If not present, create and return new entry
145+
# Use new kwargs if provided; otherwise, use lookup kwargs
146+
insert_kwargs = insert_kwargs or lookup_kwargs
147+
entry = table(**insert_kwargs)
148+
session.add(entry)
149+
session.commit()
150+
return entry
151+
152+
120153
@pytest.fixture(scope="session")
121154
def ispyb_session_factory(ispyb_engine):
122-
return scoped_session(sessionmaker(bind=ispyb_engine))
155+
factory = scoped_session(sessionmaker(bind=ispyb_engine))
156+
ispyb_db = factory()
157+
158+
# Populate the ISPyB table with some initial values
159+
# Return existing table entry if already present
160+
person_db_entry = get_or_create_db_entry(
161+
session=ispyb_db,
162+
table=Person,
163+
lookup_kwargs={
164+
"givenName": ExampleVisit.given_name,
165+
"familyName": ExampleVisit.family_name,
166+
"login": ExampleVisit.login,
167+
},
168+
)
169+
proposal_db_entry = get_or_create_db_entry(
170+
session=ispyb_db,
171+
table=Proposal,
172+
lookup_kwargs={
173+
"personId": person_db_entry.personId,
174+
"proposalCode": ExampleVisit.proposal_code,
175+
"proposalNumber": str(ExampleVisit.proposal_number),
176+
},
177+
)
178+
bl_session_db_entry = get_or_create_db_entry(
179+
session=ispyb_db,
180+
table=BLSession,
181+
lookup_kwargs={
182+
"proposalId": proposal_db_entry.proposalId,
183+
"beamLineName": ExampleVisit.instrument_name,
184+
"visit_number": ExampleVisit.visit_number,
185+
},
186+
)
187+
ispyb_db.add(bl_session_db_entry)
188+
ispyb_db.commit()
189+
190+
ispyb_db.close()
191+
return factory # Return its current state
123192

124193

125194
@pytest.fixture
126195
def ispyb_db(ispyb_session_factory) -> Generator[SQLAlchemySession, None, None]:
127196
# Get a new session from the session factory
128197
ispyb_db: SQLAlchemySession = ispyb_session_factory()
129-
save_point = ispyb_db.begin_nested() # Checkpoint to roll back database to
130-
131-
try:
132-
# Populate the ISPyB table with some default values
133-
person_db_entry = Person(
134-
givenName=ExampleVisit.given_name,
135-
familyName=ExampleVisit.family_name,
136-
login=ExampleVisit.login,
137-
)
138-
ispyb_db.add(person_db_entry)
139-
ispyb_db.commit()
140-
141-
proposal_db_entry = Proposal(
142-
personId=person_db_entry.personId,
143-
proposalCode=ExampleVisit.proposal_code,
144-
proposalNumber=str(ExampleVisit.proposal_number),
145-
)
146-
ispyb_db.add(proposal_db_entry)
147-
ispyb_db.commit()
148-
149-
bl_session_db_entry = BLSession(
150-
proposalId=proposal_db_entry.proposalId,
151-
beamLineName=ExampleVisit.instrument_name,
152-
visit_number=ExampleVisit.visit_number,
153-
)
154-
ispyb_db.add(bl_session_db_entry)
155-
ispyb_db.commit()
156-
157-
# Yield the Session and pass processing over to other function
158-
yield ispyb_db
159-
160-
# Tidying up
161-
finally:
162-
save_point.rollback()
163-
ispyb_db.close()
198+
yield ispyb_db
199+
ispyb_db.rollback()
200+
ispyb_db.close()
164201

165202

166203
"""

0 commit comments

Comments
 (0)