Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c6331ea
removed global_session
qcdyx Nov 5, 2024
215046a
updated batch-datasets cloud functions
qcdyx Nov 6, 2024
fb25b98
updated batch-process-datasets cloud functions
qcdyx Nov 6, 2024
3e3f1b2
modified extract_location
qcdyx Nov 12, 2024
c5d4c35
applied psycopg2 connection pooling
qcdyx Nov 14, 2024
130092a
code refactoring: implemented a with_db_session decorator to streamli…
qcdyx Nov 20, 2024
099cae4
removed SHOULD_CLOSE_DB_SESSION environment variable
qcdyx Nov 22, 2024
0092299
used with_db_session decorator to manage session in GCP functions
qcdyx Nov 24, 2024
edd6665
refactored cloud functions db session management
qcdyx Nov 25, 2024
0ccb413
fixed test
qcdyx Nov 25, 2024
82aaef6
more refactoring
qcdyx Nov 25, 2024
a8a62b0
updated FEEDS_DATABASE_URL
qcdyx Nov 29, 2024
aba9808
Merge branch 'main' into 293-Psycopg
qcdyx Nov 29, 2024
bae94e0
cleanup
qcdyx Nov 29, 2024
94884ae
fixed broken tests
qcdyx Dec 2, 2024
6662557
Merge branch 'main' into 293-Psycopg
qcdyx Dec 2, 2024
111b29e
fixed lint errors
qcdyx Dec 2, 2024
88c0528
Merge branch 'main' into 293-Psycopg
qcdyx Dec 17, 2024
996298c
resolved PR comments
qcdyx Dec 17, 2024
d1f7a4b
lint error fixes
qcdyx Dec 17, 2024
5707126
Merge branch 'main' into 293-Psycopg
qcdyx Dec 17, 2024
ecbd755
skip the test geocoding
qcdyx Dec 17, 2024
df0c70e
added pytest import
qcdyx Dec 17, 2024
f5d5d80
temporarily change coverage threshold to 80
qcdyx Dec 17, 2024
65962ea
added back await and use the with statement, no @with_db_session
qcdyx Dec 17, 2024
48a242e
used with statement
qcdyx Dec 17, 2024
b95e99b
fixed lint errors
qcdyx Dec 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 77 additions & 224 deletions api/src/database/database.py

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions api/src/feeds/impl/datasets_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Gtfsdataset,
Feed,
Expand Down Expand Up @@ -93,9 +93,10 @@ def apply_bounding_filtering(
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))

@staticmethod
def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> List[GtfsDataset]:
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
# Results are sorted by stable_id because Database.select(group_by=) requires them to be
dataset_groups = Database().select(
session=session,
query=query.order_by(Gtfsdataset.stable_id),
limit=limit,
offset=offset,
Expand All @@ -109,15 +110,13 @@ def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> Li
gtfs_datasets.append(GtfsDatasetImpl.from_orm(dataset_objects[0]))
return gtfs_datasets

def get_dataset_gtfs(
self,
id: str,
) -> GtfsDataset:
@with_db_session
def get_dataset_gtfs(self, id: str, db_session: Session) -> GtfsDataset:
"""Get the specified dataset from the Mobility Database."""

query = DatasetsApiImpl.create_dataset_query().filter(Gtfsdataset.stable_id == id)

if (ret := DatasetsApiImpl.get_datasets_gtfs(query)) and len(ret) == 1:
if (ret := DatasetsApiImpl.get_datasets_gtfs(query, db_session)) and len(ret) == 1:
return ret[0]
else:
raise_http_error(404, dataset_not_found.format(id))
84 changes: 39 additions & 45 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import joinedload, Session
from sqlalchemy.orm.query import Query

from database.database import Database
from database.database import Database, with_db_session
from database_gen.sqlacodegen_models import (
Feed,
Gtfsdataset,
Expand Down Expand Up @@ -63,17 +63,15 @@ class FeedsApiImpl(BaseFeedsApi):
def __init__(self) -> None:
self.logger = Logger("FeedsApiImpl").get_logger()

def get_feed(
self,
id: str,
) -> BasicFeed:
@with_db_session
def get_feed(self, id: str, db_session: Session) -> BasicFeed:
"""Get the specified feed from the Mobility Database."""
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")

feed = (
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
.filter(Database().get_query_model(db_session, Feed))
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
.filter(
or_(
Expand All @@ -89,6 +87,7 @@ def get_feed(
else:
raise_http_error(404, feed_not_found.format(id))

@with_db_session
def get_feeds(
self,
limit: int,
Expand All @@ -97,14 +96,15 @@ def get_feeds(
provider: str,
producer_url: str,
is_official: bool,
db_session: Session,
) -> List[BasicFeed]:
"""Get some (or all) feeds from the Mobility Database."""
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_filter = FeedFilter(
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
Expand All @@ -126,27 +126,25 @@ def get_feeds(
results = feed_query.all()
return [BasicFeedImpl.from_orm(feed) for feed in results]

def get_gtfs_feed(
self,
id: str,
) -> GtfsFeed:
@with_db_session
def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
"""Get the specified gtfs feed from the Mobility Database."""
feed, translations = self._get_gtfs_feed(id)
feed, translations = self._get_gtfs_feed(id, db_session)
if feed:
return GtfsFeedImpl.from_orm(feed, translations)
else:
raise_http_error(404, gtfs_feed_not_found.format(id))

@staticmethod
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
results = (
FeedFilter(
stable_id=stable_id,
status=None,
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
.filter(db_session.query(Gtfsfeed, t_location_with_translations_en))
.filter(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Expand All @@ -168,6 +166,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
return results[0].Gtfsfeed, translations
return None, {}

@with_db_session
def get_gtfs_feed_datasets(
self,
gtfs_feed_id: str,
Expand All @@ -176,6 +175,7 @@ def get_gtfs_feed_datasets(
offset: int,
downloaded_after: str,
downloaded_before: str,
db_session: Session,
) -> List[GtfsDataset]:
"""Get a list of datasets related to a feed."""
if downloaded_before and not valid_iso_date(downloaded_before):
Expand All @@ -191,7 +191,7 @@ def get_gtfs_feed_datasets(
provider__ilike=None,
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.filter(Database().get_query_model(db_session, Gtfsfeed))
.filter(
or_(
Feed.operational_status == None, # noqa: E711
Expand All @@ -208,19 +208,20 @@ def get_gtfs_feed_datasets(
# Replace Z with +00:00 to make the datetime object timezone aware
# Due to https://github.com/python/cpython/issues/80010; once we migrate to Python 3.11, we can use fromisoformat directly
query = GtfsDatasetFilter(
downloaded_at__lte=datetime.fromisoformat(downloaded_before.replace("Z", "+00:00"))
if downloaded_before
else None,
downloaded_at__gte=datetime.fromisoformat(downloaded_after.replace("Z", "+00:00"))
if downloaded_after
else None,
downloaded_at__lte=(
datetime.fromisoformat(downloaded_before.replace("Z", "+00:00")) if downloaded_before else None
),
downloaded_at__gte=(
datetime.fromisoformat(downloaded_after.replace("Z", "+00:00")) if downloaded_after else None
),
).filter(DatasetsApiImpl.create_dataset_query().filter(Feed.stable_id == gtfs_feed_id))

if latest:
query = query.filter(Gtfsdataset.latest)

return DatasetsApiImpl.get_datasets_gtfs(query, limit=limit, offset=offset)
return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)

@with_db_session
def get_gtfs_feeds(
self,
limit: int,
Expand All @@ -234,6 +235,7 @@ def get_gtfs_feeds(
dataset_longitudes: str,
bounding_filter_method: str,
is_official: bool,
db_session: Session,
) -> List[GtfsFeed]:
"""Get some (or all) GTFS feeds from the Mobility Database."""
gtfs_feed_filter = GtfsFeedFilter(
Expand All @@ -255,9 +257,7 @@ def get_gtfs_feeds(
is_email_restricted = is_user_email_restricted()
self.logger.info(f"User email is restricted: {is_email_restricted}")
feed_query = (
Database()
.get_session()
.query(Gtfsfeed)
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.id.in_(subquery))
.filter(
or_(
Expand All @@ -277,12 +277,10 @@ def get_gtfs_feeds(
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsFeedImpl)
return self._get_response(feed_query, GtfsFeedImpl, db_session)

def get_gtfs_rt_feed(
self,
id: str,
) -> GtfsRTFeed:
@with_db_session
def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
"""Get the specified GTFS Realtime feed from the Mobility Database."""
gtfs_rt_feed_filter = GtfsRtFeedFilter(
stable_id=id,
Expand All @@ -292,9 +290,7 @@ def get_gtfs_rt_feed(
location=None,
)
results = gtfs_rt_feed_filter.filter(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
db_session.query(Gtfsrealtimefeed, t_location_with_translations_en)
.filter(
or_(
Gtfsrealtimefeed.operational_status == None, # noqa: E711
Expand All @@ -317,6 +313,7 @@ def get_gtfs_rt_feed(
else:
raise_http_error(404, gtfs_rt_feed_not_found.format(id))

@with_db_session
def get_gtfs_rt_feeds(
self,
limit: int,
Expand All @@ -328,6 +325,7 @@ def get_gtfs_rt_feeds(
subdivision_name: str,
municipality: str,
is_official: bool,
db_session: Session,
) -> List[GtfsRTFeed]:
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
entity_types_list = entity_types.split(",") if entity_types else None
Expand Down Expand Up @@ -359,9 +357,7 @@ def get_gtfs_rt_feeds(
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
).subquery()
feed_query = (
Database()
.get_session()
.query(Gtfsrealtimefeed)
db_session.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.id.in_(subquery))
.filter(
or_(
Expand All @@ -380,22 +376,20 @@ def get_gtfs_rt_feeds(
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsRTFeedImpl)
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)

@staticmethod
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -> List[T]:
"""Get the response for the feed query."""
results = feed_query.all()
location_translations = get_feeds_location_translations(results)
location_translations = get_feeds_location_translations(results, db_session)
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
return list({feed.id: feed for feed in response}.values())

def get_gtfs_feed_gtfs_rt_feeds(
self,
id: str,
) -> List[GtfsRTFeed]:
@with_db_session
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
"""Get a list of GTFS Realtime related to a GTFS feed."""
feed, translations = self._get_gtfs_feed(id)
feed, translations = self._get_gtfs_feed(id, db_session)
if feed:
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
else:
Expand Down
8 changes: 6 additions & 2 deletions api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List

from sqlalchemy import func, select
from sqlalchemy.orm import Query
from sqlalchemy.orm import Query, Session

from database.database import Database
from database.database import Database, with_db_session
from database.sql_functions.unaccent import unaccent
from database_gen.sqlacodegen_models import t_feedsearch
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
Expand Down Expand Up @@ -93,6 +93,7 @@ def create_search_query(
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
return query.order_by(rank_expression.desc())

@with_db_session
def search_feeds(
self,
limit: int,
Expand All @@ -102,15 +103,18 @@ def search_feeds(
data_type: str,
is_official: bool,
search_query: str,
db_session: "Session",
) -> SearchFeeds200Response:
"""Search feeds using full-text search on feed, location and provider's information."""
query = self.create_search_query(status, feed_id, data_type, is_official, search_query)
feed_rows = Database().select(
session=db_session,
query=query,
limit=limit,
offset=offset,
)
feed_total_count = Database().select(
session=db_session,
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query),
)
if feed_rows is None or feed_total_count is None:
Expand Down
11 changes: 8 additions & 3 deletions api/src/scripts/populate_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from pathlib import Path
from typing import Type
from typing import Type, TYPE_CHECKING

import pandas
from dotenv import load_dotenv
Expand All @@ -11,6 +11,9 @@
from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
from utils.logger import Logger

if TYPE_CHECKING:
from sqlalchemy.orm import Session

logging.basicConfig()
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)

Expand Down Expand Up @@ -56,12 +59,14 @@ def __init__(self, filepaths):

self.filter_data()

def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
def query_feed_by_stable_id(
self, session: "Session", stable_id: str, data_type: str | None
) -> Gtfsrealtimefeed | Gtfsfeed | None:
"""
Query the feed by stable id
"""
model = self.get_model(data_type)
return self.db.session.query(model).filter(model.stable_id == stable_id).first()
return session.query(model).filter(model.stable_id == stable_id).first()

@staticmethod
def get_model(data_type: str | None) -> Type[Feed]:
Expand Down
16 changes: 9 additions & 7 deletions api/src/scripts/populate_db_gbfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,16 @@ def deprecate_feeds(self, deprecated_feeds):
if deprecated_feeds is None or deprecated_feeds.empty:
self.logger.info("No feeds to deprecate.")
return

self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).")
for index, row in deprecated_feeds.iterrows():
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
self.db.session.flush()
with self.db.start_db_session() as session:
for index, row in deprecated_feeds.iterrows():
stable_id = self.get_stable_id(row)
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
if gbfs_feed:
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
gbfs_feed.status = "deprecated"
session.flush()

def populate_db(self):
"""Populate the database with the GBFS feeds"""
Expand Down
Loading
Loading