Skip to content

Commit 2aec616

Browse files
authored
Merge pull request #837 from MobilityData/293-Psycopg
feat: Use psycopg2 for Connection Pooling and Implement Global Engine with Context Manager for Session Management
2 parents b0d5f89 + b95e99b commit 2aec616

File tree

42 files changed

+706
-875
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+706
-875
lines changed

api/src/database/database.py

Lines changed: 77 additions & 224 deletions
Large diffs are not rendered by default.

api/src/feeds/impl/datasets_api_impl.py

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,9 @@
33

44
from geoalchemy2 import WKTElement
55
from sqlalchemy import or_
6-
from sqlalchemy.orm import Query
6+
from sqlalchemy.orm import Query, Session
77

8-
from database.database import Database
8+
from database.database import Database, with_db_session
99
from database_gen.sqlacodegen_models import (
1010
Gtfsdataset,
1111
Feed,
@@ -93,9 +93,10 @@ def apply_bounding_filtering(
9393
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))
9494

9595
@staticmethod
96-
def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> List[GtfsDataset]:
96+
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
9797
# Results are sorted by stable_id because Database.select(group_by=) requires it so
9898
dataset_groups = Database().select(
99+
session=session,
99100
query=query.order_by(Gtfsdataset.stable_id),
100101
limit=limit,
101102
offset=offset,
@@ -109,15 +110,13 @@ def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> Li
109110
gtfs_datasets.append(GtfsDatasetImpl.from_orm(dataset_objects[0]))
110111
return gtfs_datasets
111112

112-
def get_dataset_gtfs(
113-
self,
114-
id: str,
115-
) -> GtfsDataset:
113+
@with_db_session
114+
def get_dataset_gtfs(self, id: str, db_session: Session) -> GtfsDataset:
116115
"""Get the specified dataset from the Mobility Database."""
117116

118117
query = DatasetsApiImpl.create_dataset_query().filter(Gtfsdataset.stable_id == id)
119118

120-
if (ret := DatasetsApiImpl.get_datasets_gtfs(query)) and len(ret) == 1:
119+
if (ret := DatasetsApiImpl.get_datasets_gtfs(query, db_session)) and len(ret) == 1:
121120
return ret[0]
122121
else:
123122
raise_http_error(404, dataset_not_found.format(id))

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 39 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,10 @@
33

44
from sqlalchemy import or_
55
from sqlalchemy import select
6-
from sqlalchemy.orm import joinedload
6+
from sqlalchemy.orm import joinedload, Session
77
from sqlalchemy.orm.query import Query
88

9-
from database.database import Database
9+
from database.database import Database, with_db_session
1010
from database_gen.sqlacodegen_models import (
1111
Feed,
1212
Gtfsdataset,
@@ -63,17 +63,15 @@ class FeedsApiImpl(BaseFeedsApi):
6363
def __init__(self) -> None:
6464
self.logger = Logger("FeedsApiImpl").get_logger()
6565

66-
def get_feed(
67-
self,
68-
id: str,
69-
) -> BasicFeed:
66+
@with_db_session
67+
def get_feed(self, id: str, db_session: Session) -> BasicFeed:
7068
"""Get the specified feed from the Mobility Database."""
7169
is_email_restricted = is_user_email_restricted()
7270
self.logger.info(f"User email is restricted: {is_email_restricted}")
7371

7472
feed = (
7573
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
76-
.filter(Database().get_query_model(Feed))
74+
.filter(Database().get_query_model(db_session, Feed))
7775
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
7876
.filter(
7977
or_(
@@ -89,6 +87,7 @@ def get_feed(
8987
else:
9088
raise_http_error(404, feed_not_found.format(id))
9189

90+
@with_db_session
9291
def get_feeds(
9392
self,
9493
limit: int,
@@ -97,14 +96,15 @@ def get_feeds(
9796
provider: str,
9897
producer_url: str,
9998
is_official: bool,
99+
db_session: Session,
100100
) -> List[BasicFeed]:
101101
"""Get some (or all) feeds from the Mobility Database."""
102102
is_email_restricted = is_user_email_restricted()
103103
self.logger.info(f"User email is restricted: {is_email_restricted}")
104104
feed_filter = FeedFilter(
105105
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
106106
)
107-
feed_query = feed_filter.filter(Database().get_query_model(Feed))
107+
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
108108
if is_official:
109109
feed_query = feed_query.filter(Feed.official)
110110
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
@@ -126,27 +126,25 @@ def get_feeds(
126126
results = feed_query.all()
127127
return [BasicFeedImpl.from_orm(feed) for feed in results]
128128

129-
def get_gtfs_feed(
130-
self,
131-
id: str,
132-
) -> GtfsFeed:
129+
@with_db_session
130+
def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
133131
"""Get the specified gtfs feed from the Mobility Database."""
134-
feed, translations = self._get_gtfs_feed(id)
132+
feed, translations = self._get_gtfs_feed(id, db_session)
135133
if feed:
136134
return GtfsFeedImpl.from_orm(feed, translations)
137135
else:
138136
raise_http_error(404, gtfs_feed_not_found.format(id))
139137

140138
@staticmethod
141-
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
139+
def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
142140
results = (
143141
FeedFilter(
144142
stable_id=stable_id,
145143
status=None,
146144
provider__ilike=None,
147145
producer_url__ilike=None,
148146
)
149-
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
147+
.filter(db_session.query(Gtfsfeed, t_location_with_translations_en))
150148
.filter(
151149
or_(
152150
Gtfsfeed.operational_status == None, # noqa: E711
@@ -168,6 +166,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
168166
return results[0].Gtfsfeed, translations
169167
return None, {}
170168

169+
@with_db_session
171170
def get_gtfs_feed_datasets(
172171
self,
173172
gtfs_feed_id: str,
@@ -176,6 +175,7 @@ def get_gtfs_feed_datasets(
176175
offset: int,
177176
downloaded_after: str,
178177
downloaded_before: str,
178+
db_session: Session,
179179
) -> List[GtfsDataset]:
180180
"""Get a list of datasets related to a feed."""
181181
if downloaded_before and not valid_iso_date(downloaded_before):
@@ -191,7 +191,7 @@ def get_gtfs_feed_datasets(
191191
provider__ilike=None,
192192
producer_url__ilike=None,
193193
)
194-
.filter(Database().get_query_model(Gtfsfeed))
194+
.filter(Database().get_query_model(db_session, Gtfsfeed))
195195
.filter(
196196
or_(
197197
Feed.operational_status == None, # noqa: E711
@@ -208,19 +208,20 @@ def get_gtfs_feed_datasets(
208208
# Replace Z with +00:00 to make the datetime object timezone aware
209209
# Due to https://github.com/python/cpython/issues/80010, once migrate to Python 3.11, we can use fromisoformat
210210
query = GtfsDatasetFilter(
211-
downloaded_at__lte=datetime.fromisoformat(downloaded_before.replace("Z", "+00:00"))
212-
if downloaded_before
213-
else None,
214-
downloaded_at__gte=datetime.fromisoformat(downloaded_after.replace("Z", "+00:00"))
215-
if downloaded_after
216-
else None,
211+
downloaded_at__lte=(
212+
datetime.fromisoformat(downloaded_before.replace("Z", "+00:00")) if downloaded_before else None
213+
),
214+
downloaded_at__gte=(
215+
datetime.fromisoformat(downloaded_after.replace("Z", "+00:00")) if downloaded_after else None
216+
),
217217
).filter(DatasetsApiImpl.create_dataset_query().filter(Feed.stable_id == gtfs_feed_id))
218218

219219
if latest:
220220
query = query.filter(Gtfsdataset.latest)
221221

222-
return DatasetsApiImpl.get_datasets_gtfs(query, limit=limit, offset=offset)
222+
return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)
223223

224+
@with_db_session
224225
def get_gtfs_feeds(
225226
self,
226227
limit: int,
@@ -234,6 +235,7 @@ def get_gtfs_feeds(
234235
dataset_longitudes: str,
235236
bounding_filter_method: str,
236237
is_official: bool,
238+
db_session: Session,
237239
) -> List[GtfsFeed]:
238240
"""Get some (or all) GTFS feeds from the Mobility Database."""
239241
gtfs_feed_filter = GtfsFeedFilter(
@@ -255,9 +257,7 @@ def get_gtfs_feeds(
255257
is_email_restricted = is_user_email_restricted()
256258
self.logger.info(f"User email is restricted: {is_email_restricted}")
257259
feed_query = (
258-
Database()
259-
.get_session()
260-
.query(Gtfsfeed)
260+
db_session.query(Gtfsfeed)
261261
.filter(Gtfsfeed.id.in_(subquery))
262262
.filter(
263263
or_(
@@ -277,12 +277,10 @@ def get_gtfs_feeds(
277277
if is_official:
278278
feed_query = feed_query.filter(Feed.official)
279279
feed_query = feed_query.limit(limit).offset(offset)
280-
return self._get_response(feed_query, GtfsFeedImpl)
280+
return self._get_response(feed_query, GtfsFeedImpl, db_session)
281281

282-
def get_gtfs_rt_feed(
283-
self,
284-
id: str,
285-
) -> GtfsRTFeed:
282+
@with_db_session
283+
def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
286284
"""Get the specified GTFS Realtime feed from the Mobility Database."""
287285
gtfs_rt_feed_filter = GtfsRtFeedFilter(
288286
stable_id=id,
@@ -292,9 +290,7 @@ def get_gtfs_rt_feed(
292290
location=None,
293291
)
294292
results = gtfs_rt_feed_filter.filter(
295-
Database()
296-
.get_session()
297-
.query(Gtfsrealtimefeed, t_location_with_translations_en)
293+
db_session.query(Gtfsrealtimefeed, t_location_with_translations_en)
298294
.filter(
299295
or_(
300296
Gtfsrealtimefeed.operational_status == None, # noqa: E711
@@ -317,6 +313,7 @@ def get_gtfs_rt_feed(
317313
else:
318314
raise_http_error(404, gtfs_rt_feed_not_found.format(id))
319315

316+
@with_db_session
320317
def get_gtfs_rt_feeds(
321318
self,
322319
limit: int,
@@ -328,6 +325,7 @@ def get_gtfs_rt_feeds(
328325
subdivision_name: str,
329326
municipality: str,
330327
is_official: bool,
328+
db_session: Session,
331329
) -> List[GtfsRTFeed]:
332330
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
333331
entity_types_list = entity_types.split(",") if entity_types else None
@@ -359,9 +357,7 @@ def get_gtfs_rt_feeds(
359357
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
360358
).subquery()
361359
feed_query = (
362-
Database()
363-
.get_session()
364-
.query(Gtfsrealtimefeed)
360+
db_session.query(Gtfsrealtimefeed)
365361
.filter(Gtfsrealtimefeed.id.in_(subquery))
366362
.filter(
367363
or_(
@@ -380,22 +376,20 @@ def get_gtfs_rt_feeds(
380376
if is_official:
381377
feed_query = feed_query.filter(Feed.official)
382378
feed_query = feed_query.limit(limit).offset(offset)
383-
return self._get_response(feed_query, GtfsRTFeedImpl)
379+
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)
384380

385381
@staticmethod
386-
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
382+
def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -> List[T]:
387383
"""Get the response for the feed query."""
388384
results = feed_query.all()
389-
location_translations = get_feeds_location_translations(results)
385+
location_translations = get_feeds_location_translations(results, db_session)
390386
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
391387
return list({feed.id: feed for feed in response}.values())
392388

393-
def get_gtfs_feed_gtfs_rt_feeds(
394-
self,
395-
id: str,
396-
) -> List[GtfsRTFeed]:
389+
@with_db_session
390+
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
397391
"""Get a list of GTFS Realtime related to a GTFS feed."""
398-
feed, translations = self._get_gtfs_feed(id)
392+
feed, translations = self._get_gtfs_feed(id, db_session)
399393
if feed:
400394
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
401395
else:

api/src/feeds/impl/search_api_impl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,9 @@
11
from typing import List
22

33
from sqlalchemy import func, select
4-
from sqlalchemy.orm import Query
4+
from sqlalchemy.orm import Query, Session
55

6-
from database.database import Database
6+
from database.database import Database, with_db_session
77
from database.sql_functions.unaccent import unaccent
88
from database_gen.sqlacodegen_models import t_feedsearch
99
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
@@ -93,6 +93,7 @@ def create_search_query(
9393
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
9494
return query.order_by(rank_expression.desc())
9595

96+
@with_db_session
9697
def search_feeds(
9798
self,
9899
limit: int,
@@ -102,15 +103,18 @@ def search_feeds(
102103
data_type: str,
103104
is_official: bool,
104105
search_query: str,
106+
db_session: "Session",
105107
) -> SearchFeeds200Response:
106108
"""Search feeds using full-text search on feed, location and provider's information."""
107109
query = self.create_search_query(status, feed_id, data_type, is_official, search_query)
108110
feed_rows = Database().select(
111+
session=db_session,
109112
query=query,
110113
limit=limit,
111114
offset=offset,
112115
)
113116
feed_total_count = Database().select(
117+
session=db_session,
114118
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query),
115119
)
116120
if feed_rows is None or feed_total_count is None:

api/src/scripts/populate_db.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44
from pathlib import Path
5-
from typing import Type
5+
from typing import Type, TYPE_CHECKING
66

77
import pandas
88
from dotenv import load_dotenv
@@ -11,6 +11,9 @@
1111
from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
1212
from utils.logger import Logger
1313

14+
if TYPE_CHECKING:
15+
from sqlalchemy.orm import Session
16+
1417
logging.basicConfig()
1518
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)
1619

@@ -56,12 +59,14 @@ def __init__(self, filepaths):
5659

5760
self.filter_data()
5861

59-
def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
62+
def query_feed_by_stable_id(
63+
self, session: "Session", stable_id: str, data_type: str | None
64+
) -> Gtfsrealtimefeed | Gtfsfeed | None:
6065
"""
6166
Query the feed by stable id
6267
"""
6368
model = self.get_model(data_type)
64-
return self.db.session.query(model).filter(model.stable_id == stable_id).first()
69+
return session.query(model).filter(model.stable_id == stable_id).first()
6570

6671
@staticmethod
6772
def get_model(data_type: str | None) -> Type[Feed]:

api/src/scripts/populate_db_gbfs.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -36,14 +36,16 @@ def deprecate_feeds(self, deprecated_feeds):
3636
if deprecated_feeds is None or deprecated_feeds.empty:
3737
self.logger.info("No feeds to deprecate.")
3838
return
39+
3940
self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).")
40-
for index, row in deprecated_feeds.iterrows():
41-
stable_id = self.get_stable_id(row)
42-
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
43-
if gbfs_feed:
44-
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
45-
gbfs_feed.status = "deprecated"
46-
self.db.session.flush()
41+
with self.db.start_db_session() as session:
42+
for index, row in deprecated_feeds.iterrows():
43+
stable_id = self.get_stable_id(row)
44+
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
45+
if gbfs_feed:
46+
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
47+
gbfs_feed.status = "deprecated"
48+
session.flush()
4749

4850
def populate_db(self):
4951
"""Populate the database with the GBFS feeds"""

0 commit comments

Comments
 (0)