Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c6331ea
removed global_session
qcdyx Nov 5, 2024
215046a
updated batch-datasets cloud functions
qcdyx Nov 6, 2024
fb25b98
updated batch-process-datasets cloud functions
qcdyx Nov 6, 2024
3e3f1b2
modified extract_location
qcdyx Nov 12, 2024
c5d4c35
applied psycopg2 connection pooling
qcdyx Nov 14, 2024
130092a
code refactoring: implemented a with_db_session decorator to streamli…
qcdyx Nov 20, 2024
099cae4
removed SHOULD_CLOSE_DB_SESSION environment variable
qcdyx Nov 22, 2024
0092299
used with_db_session decorator to manage session in GCP functions
qcdyx Nov 24, 2024
edd6665
refactored cloud functions db session management
qcdyx Nov 25, 2024
0ccb413
fixed test
qcdyx Nov 25, 2024
82aaef6
more refactoring
qcdyx Nov 25, 2024
a8a62b0
updated FEEDS_DATABASE_URL
qcdyx Nov 29, 2024
aba9808
Merge branch 'main' into 293-Psycopg
qcdyx Nov 29, 2024
bae94e0
cleanup
qcdyx Nov 29, 2024
94884ae
fixed broken tests
qcdyx Dec 2, 2024
6662557
Merge branch 'main' into 293-Psycopg
qcdyx Dec 2, 2024
111b29e
fixed lint errors
qcdyx Dec 2, 2024
88c0528
Merge branch 'main' into 293-Psycopg
qcdyx Dec 17, 2024
996298c
resolved PR comments
qcdyx Dec 17, 2024
d1f7a4b
lint error fixes
qcdyx Dec 17, 2024
5707126
Merge branch 'main' into 293-Psycopg
qcdyx Dec 17, 2024
ecbd755
skip the test geocoding
qcdyx Dec 17, 2024
df0c70e
added pytest import
qcdyx Dec 17, 2024
f5d5d80
temporarily change coverage threshold to 80
qcdyx Dec 17, 2024
65962ea
added back await and use the with statement, no @with_db_session
qcdyx Dec 17, 2024
48a242e
used with statement
qcdyx Dec 17, 2024
b95e99b
fixed lint errors
qcdyx Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 63 additions & 211 deletions api/src/database/database.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from contextlib import contextmanager
import itertools
import os
import threading
import uuid
from typing import Type, Callable
from dotenv import load_dotenv
from sqlalchemy import create_engine, inspect
from sqlalchemy import create_engine
from sqlalchemy.orm import load_only, Query, class_mapper, Session

from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed
from sqlalchemy.orm import sessionmaker
import logging
from typing import Final


SHOULD_CLOSE_DB_SESSION: Final[str] = "SHOULD_CLOSE_DB_SESSION"
lock = threading.Lock()
global_session = None


def generate_unique_id() -> str:
Expand Down Expand Up @@ -48,6 +44,37 @@ def configure_polymorphic_mappers():
gbfsfeed_mapper.polymorphic_identity = Gbfsfeed.__tablename__.lower()


def with_db_session(func):
    """
    Decorator that injects a managed database session into the wrapped function.

    If the caller does not supply a ``db_session`` keyword argument, a new
    session is opened via ``Database().start_db_session()`` — committed on
    success, rolled back on exception, and closed either way — and passed to
    the wrapped function as ``db_session``. If the caller already provides a
    ``db_session``, it is used as-is and its lifecycle is left to the caller.

    :param func: function accepting a ``db_session`` keyword argument
    :return: wrapped function with automatic session management
    """
    # Imported locally so the decorator stays self-contained.
    import functools

    @functools.wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        if kwargs.get("db_session") is not None:
            # Caller supplied a session: caller manages its lifecycle.
            return func(*args, **kwargs)
        # No session supplied: open one and hand it to the function.
        with Database().start_db_session() as session:
            kwargs["db_session"] = session
            return func(*args, **kwargs)

    return wrapper


class Database:
"""
This class represents a database instance
Expand All @@ -59,7 +86,7 @@ class Database:

def __new__(cls, *args, **kwargs):
if not isinstance(cls.instance, cls):
with lock:
with cls.lock:
if not isinstance(cls.instance, cls):
cls.instance = object.__new__(cls)
return cls.instance
Expand All @@ -77,82 +104,51 @@ def __init__(self, echo_sql=False):
with Database.lock:
if Database.initialized:
return

Database.initialized = True
load_dotenv()
self.engine = None
self.logger = logging.getLogger(__name__)
self.connection_attempts = 0
self.SQLALCHEMY_DATABASE_URL = os.getenv("FEEDS_DATABASE_URL")
self.echo_sql = echo_sql
self.start_session()
database_url = os.getenv("FEEDS_DATABASE_URL")
if database_url is None:
raise Exception("Database URL not provided.")
self.engine = create_engine(database_url, echo=echo_sql, pool_size=10, max_overflow=0)
# creates a session factory
self.Session = sessionmaker(bind=self.engine, autoflush=False)

def is_connected(self):
"""
Checks the connection status
:return: True if the database is accessible False otherwise
"""
return self.engine is not None or global_session is not None
return self.engine is not None or self.session is not None

def start_session(self):
"""
:return: Database singleton session
@contextmanager
def start_db_session(self):
"""
global global_session
try:
if global_session is not None:
logging.info("Database session reused.")
return global_session
if global_session is None or not global_session.is_active:
global_session = self.start_new_db_session()
logging.info("Global Singleton Database session started.")
return global_session
except Exception as error:
raise Exception(f"Error creating database session: {error}")

def start_new_db_session(self):
global global_session
try:
lock.acquire()
if global_session is not None and global_session.is_active:
logging.info("Database session reused.")
return global_session
if self.SQLALCHEMY_DATABASE_URL is None:
raise Exception("Database URL is not set")
else:
logging.info("Starting new global database session.")
self.engine = create_engine(self.SQLALCHEMY_DATABASE_URL, echo=self.echo_sql)
global_session = sessionmaker(bind=self.engine)()
global_session.expire_on_commit = False
self.session = global_session
return global_session
except Exception as error:
raise Exception(f"Error creating database session: {error}")
finally:
lock.release()

def should_close_db_session(self):
return os.getenv("%s" % SHOULD_CLOSE_DB_SESSION, "false").lower() == "true"
Context manager to start a database session with optional echo.

def close_session(self):
"""
Closes a session
:return: True if the session was started, False otherwise
This method manages the lifecycle of a database session, ensuring that the session is properly created,
committed, rolled back in case of an exception, and closed. The @contextmanager decorator simplifies
resource management by handling the setup and cleanup logic within a single function.
"""
session = self.Session()
try:
should_close = self.should_close_db_session()
if should_close and global_session is not None and global_session.is_active:
global_session.close()
logging.info("Database session closed.")
except Exception as e:
logging.error(f"Session closing failed with exception: \n {e}")
return self.is_connected()
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

def select(
self,
session: "Session",
model: Type[Base] = None,
query: Query = None,
conditions: list = None,
attributes: list = None,
update_session: bool = True,
limit: int = None,
offset: int = None,
group_by: Callable = None,
Expand All @@ -171,10 +167,8 @@ def select(
:return: None if database is inaccessible, the results of the query otherwise
"""
try:
if update_session:
self.start_session()
if query is None:
query = global_session.query(model)
query = session.query(model)
if conditions:
for condition in conditions:
query = query.filter(condition)
Expand All @@ -184,159 +178,17 @@ def select(
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
results = global_session.execute(query).all()
results = session.execute(query).all()
if group_by:
return [list(group) for _, group in itertools.groupby(results, group_by)]
return results
except Exception as e:
logging.error(f"SELECT query failed with exception: \n{e}")
if global_session is not None:
global_session.rollback()
self.logger.error(f"SELECT query failed with exception: \n{e}")
return None
finally:
if update_session:
self.close_session()

def get_session(self) -> Session:
"""
:return: the current session
"""
return self.session

def get_query_model(self, model: Type[Base]) -> Query:
def get_query_model(self, session: Session, model: Type[Base]) -> Query:
"""
:param model: the sqlalchemy model to query
:return: the query model
"""
return self.get_session().query(model)

def select_from_active_session(self, model: Base, conditions: list = None, attributes: list = None):
    """
    Select objects among the pending (uncommitted) objects of the global session.

    Only inspects ``global_session.new`` — objects added to the session but not
    yet flushed/committed — so rows already persisted are never returned.

    :param model: the sqlalchemy model class to match with ``isinstance``
    :param conditions: list of filters; each condition must be a simple
        ``Model.column == literal`` binary expression (only equality is applied)
    :param attributes: attribute names to project into dicts; if omitted,
        the matching ORM objects themselves are returned
    :return: empty list on any failure (including an inactive session),
        otherwise the matching objects or attribute dicts
    """
    try:
        if not global_session or not global_session.is_active:
            raise Exception("Inactive session")
        # Pending (not yet committed) objects of the requested model.
        results = [obj for obj in global_session.new if isinstance(obj, model)]
        if conditions:
            for condition in conditions:
                # NOTE(review): assumes each condition is `column == value`;
                # other operators would be silently treated as equality.
                attribute_name = condition.left.name
                attribute_value = condition.right.value
                results = [result for result in results if getattr(result, attribute_name) == attribute_value]
        if attributes:
            # Project each matched object into a dict of the requested attributes.
            results = [{attr: getattr(obj, attr) for attr in attributes} for obj in results]
        return results
    except Exception as e:
        logging.error(f"Object selection within the uncommitted session objects failed with exception: \n{e}")
        return []

def merge(
    self,
    orm_object: Base,
    update_session: bool = False,
    auto_commit: bool = False,
    load: bool = True,
):
    """
    Update or insert (upsert) an object via the global database session.

    :param orm_object: the modeled object to merge into the session
    :param update_session: refresh/start the global session before merging
        (defaults to False)
    :param auto_commit: commit immediately after the merge (defaults to False)
    :param load: whether the database may be queried for the object being
        merged; passed through to ``Session.merge`` (defaults to True)
    :return: True if the merge (and optional commit) succeeded, False otherwise
    """
    try:
        if update_session:
            self.start_session()
        global_session.merge(orm_object, load=load)
        if auto_commit:
            global_session.commit()
        return True
    except Exception as e:
        logging.error(f"Merge query failed with exception: \n{e}")
        return False

def commit(self):
    """
    Synchronize pending changes with the database and close the session.

    Commits the global session only when it exists and is active; whatever
    happens, any existing session is closed in the ``finally`` clause.

    :return: True if the commit succeeded, False otherwise
    """
    session = None
    try:
        session = global_session
        if session is None or not session.is_active:
            return False
        session.commit()
        return True
    except Exception as e:
        logging.error(f"Commit failed with exception: \n{e}")
        return False
    finally:
        # Always release the connection, even after a failed commit.
        if session is not None:
            session.close()

def flush(self):
    """
    Push pending changes to the database while keeping the session active.

    Unlike ``commit``, this synchronizes state without ending the
    transaction, so the global session remains usable afterwards.

    :return: True if the flush succeeded, False otherwise
    """
    try:
        session = global_session
        if session is None or not session.is_active:
            return False
        session.flush()
        return True
    except Exception as e:
        logging.error(f"Flush failed with exception: \n{e}")
        return False

def merge_relationship(
    self,
    parent_model: Base.__class__,
    parent_key_values: dict,
    child: Base,
    relationship_name: str,
    update_session: bool = False,
    auto_commit: bool = False,
    uncommitted: bool = False,
):
    """
    Adds a child instance to a parent's related items. If the parent doesn't exist, it creates a new one.
    :param parent_model: the orm model class of the parent containing the relationship
    :param parent_key_values: the dictionary of primary keys and their values of the parent
    :param child: the child instance to be added
    :param relationship_name: the name of the attribute on the parent model that holds related children
    :param update_session: option to update the session before running the merge query (defaults to False)
    :param auto_commit: option to automatically commit merge (defaults to False)
    :param uncommitted: option to merge relationship with uncommitted objects in the session (defaults to False)
    :return: True if the operation was successful, False otherwise
    """
    try:
        # Build equality filters from the parent's primary-key columns.
        primary_keys = inspect(parent_model).primary_key
        conditions = [key == parent_key_values[key.name] for key in primary_keys]

        # Query for the existing parent using primary keys.
        # `uncommitted` searches pending session objects instead of the database.
        if uncommitted:
            parent = self.select_from_active_session(parent_model, conditions)
        else:
            parent = self.select(parent_model, conditions, update_session=update_session)
        if not parent:
            # Parent not found: nothing to attach the child to.
            return False
        else:
            parent = parent[0]

        # add child to the list of related children from the parent
        relationship_elements = getattr(parent, relationship_name)
        relationship_elements.append(child)
        if not uncommitted:
            # Persist the modified parent; pending objects need no merge.
            return self.merge(parent, update_session=update_session, auto_commit=auto_commit)
        return True
    except Exception as e:
        logging.error(f"Adding {child.__class__.__name__} to {parent_model.__name__} failed with exception: \n{e}")
        return False
return session.query(model)
15 changes: 7 additions & 8 deletions api/src/feeds/impl/datasets_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Gtfsdataset,
Feed,
Expand Down Expand Up @@ -93,9 +93,10 @@ def apply_bounding_filtering(
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))

@staticmethod
def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> List[GtfsDataset]:
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
# Results are sorted by stable_id because Database.select(group_by=) requires it so
dataset_groups = Database().select(
session=session,
query=query.order_by(Gtfsdataset.stable_id),
limit=limit,
offset=offset,
Expand All @@ -109,15 +110,13 @@ def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> Li
gtfs_datasets.append(GtfsDatasetImpl.from_orm(dataset_objects[0]))
return gtfs_datasets

def get_dataset_gtfs(
self,
id: str,
) -> GtfsDataset:
@with_db_session
def get_dataset_gtfs(self, id: str, db_session: Session) -> GtfsDataset:
"""Get the specified dataset from the Mobility Database."""

query = DatasetsApiImpl.create_dataset_query().filter(Gtfsdataset.stable_id == id)

if (ret := DatasetsApiImpl.get_datasets_gtfs(query)) and len(ret) == 1:
if (ret := DatasetsApiImpl.get_datasets_gtfs(query, db_session)) and len(ret) == 1:
return ret[0]
else:
raise_http_error(404, dataset_not_found.format(id))
Loading
Loading