Skip to content

Commit 8789ae3

Browse files
committed
code refactoring: implemented a with_db_session decorator to streamline session management.
1 parent 83268aa commit 8789ae3

File tree

14 files changed

+420
-449
lines changed

14 files changed

+420
-449
lines changed

api/src/database/database.py

Lines changed: 196 additions & 190 deletions
Large diffs are not rendered by default.

api/src/feeds/impl/datasets_api_impl.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
from geoalchemy2 import WKTElement
55
from sqlalchemy import or_
6-
from sqlalchemy.orm import Query
6+
from sqlalchemy.orm import Query, Session
77

8-
from database.database import Database
8+
from database.database import Database, with_db_session
99
from database_gen.sqlacodegen_models import (
1010
Gtfsdataset,
1111
Feed,
@@ -93,9 +93,10 @@ def apply_bounding_filtering(
9393
raise_http_validation_error(invalid_bounding_method.format(bounding_filter_method))
9494

9595
@staticmethod
96-
def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> List[GtfsDataset]:
96+
def get_datasets_gtfs(query: Query, session: Session, limit: int = None, offset: int = None) -> List[GtfsDataset]:
9797
# Results are sorted by stable_id because Database.select(group_by=) requires it so
9898
dataset_groups = Database().select(
99+
session=session,
99100
query=query.order_by(Gtfsdataset.stable_id),
100101
limit=limit,
101102
offset=offset,
@@ -109,15 +110,13 @@ def get_datasets_gtfs(query: Query, limit: int = None, offset: int = None) -> Li
109110
gtfs_datasets.append(GtfsDatasetImpl.from_orm(dataset_objects[0]))
110111
return gtfs_datasets
111112

112-
def get_dataset_gtfs(
113-
self,
114-
id: str,
115-
) -> GtfsDataset:
113+
@with_db_session
114+
def get_dataset_gtfs(self, id: str, db_session: Session) -> GtfsDataset:
116115
"""Get the specified dataset from the Mobility Database."""
117116

118117
query = DatasetsApiImpl.create_dataset_query().filter(Gtfsdataset.stable_id == id)
119118

120-
if (ret := DatasetsApiImpl.get_datasets_gtfs(query)) and len(ret) == 1:
119+
if (ret := DatasetsApiImpl.get_datasets_gtfs(query, db_session)) and len(ret) == 1:
121120
return ret[0]
122121
else:
123122
raise_http_error(404, dataset_not_found.format(id))

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 41 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
from typing import List, Union, TypeVar
33

44
from sqlalchemy import select
5-
from sqlalchemy.orm import joinedload
5+
from sqlalchemy.orm import joinedload, Session
66
from sqlalchemy.orm.query import Query
77

8-
from database.database import Database
8+
from database.database import Database, with_db_session
99
from database_gen.sqlacodegen_models import (
1010
Feed,
1111
Gtfsdataset,
@@ -59,20 +59,19 @@ class FeedsApiImpl(BaseFeedsApi):
5959

6060
APIFeedType = Union[BasicFeed, GtfsFeed, GtfsRTFeed]
6161

62-
def get_feed(
63-
self,
64-
id: str,
65-
) -> BasicFeed:
62+
@with_db_session
63+
def get_feed(self, id: str, db_session: Session) -> BasicFeed:
6664
"""Get the specified feed from the Mobility Database."""
6765
feed = (
6866
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
69-
.filter(Database().get_query_model(Feed))
67+
.filter(Database().get_query_model(db_session, Feed))
7068
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
7169
.filter(
7270
or_(
7371
Feed.operational_status == None, # noqa: E711
7472
Feed.operational_status != "wip",
75-
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
73+
# Allow all feeds to be returned if the user is not restricted
74+
not is_user_email_restricted(),
7675
)
7776
)
7877
.first()
@@ -82,19 +81,15 @@ def get_feed(
8281
else:
8382
raise_http_error(404, feed_not_found.format(id))
8483

84+
@with_db_session
8585
def get_feeds(
86-
self,
87-
limit: int,
88-
offset: int,
89-
status: str,
90-
provider: str,
91-
producer_url: str,
86+
self, limit: int, offset: int, status: str, provider: str, producer_url: str, db_session: Session
9287
) -> List[BasicFeed]:
9388
"""Get some (or all) feeds from the Mobility Database."""
9489
feed_filter = FeedFilter(
9590
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
9691
)
97-
feed_query = feed_filter.filter(Database().get_query_model(Feed))
92+
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
9893
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
9994
feed_query = feed_query.filter(
10095
or_(
@@ -114,27 +109,25 @@ def get_feeds(
114109
results = feed_query.all()
115110
return [BasicFeedImpl.from_orm(feed) for feed in results]
116111

117-
def get_gtfs_feed(
118-
self,
119-
id: str,
120-
) -> GtfsFeed:
112+
@with_db_session
113+
def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
121114
"""Get the specified gtfs feed from the Mobility Database."""
122-
feed, translations = self._get_gtfs_feed(id)
115+
feed, translations = self._get_gtfs_feed(id, db_session)
123116
if feed:
124117
return GtfsFeedImpl.from_orm(feed, translations)
125118
else:
126119
raise_http_error(404, gtfs_feed_not_found.format(id))
127120

128121
@staticmethod
129-
def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
122+
def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]:
130123
results = (
131124
FeedFilter(
132125
stable_id=stable_id,
133126
status=None,
134127
provider__ilike=None,
135128
producer_url__ilike=None,
136129
)
137-
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
130+
.filter(db_session.query(Gtfsfeed, t_location_with_translations_en))
138131
.filter(
139132
or_(
140133
Gtfsfeed.operational_status == None, # noqa: E711
@@ -156,6 +149,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
156149
return results[0].Gtfsfeed, translations
157150
return None, {}
158151

152+
@with_db_session
159153
def get_gtfs_feed_datasets(
160154
self,
161155
gtfs_feed_id: str,
@@ -164,6 +158,7 @@ def get_gtfs_feed_datasets(
164158
offset: int,
165159
downloaded_after: str,
166160
downloaded_before: str,
161+
db_session: Session,
167162
) -> List[GtfsDataset]:
168163
"""Get a list of datasets related to a feed."""
169164
if downloaded_before and not valid_iso_date(downloaded_before):
@@ -179,7 +174,7 @@ def get_gtfs_feed_datasets(
179174
provider__ilike=None,
180175
producer_url__ilike=None,
181176
)
182-
.filter(Database().get_query_model(Gtfsfeed))
177+
.filter(Database().get_query_model(db_session, Gtfsfeed))
183178
.filter(
184179
or_(
185180
Feed.operational_status == None, # noqa: E711
@@ -196,19 +191,20 @@ def get_gtfs_feed_datasets(
196191
# Replace Z with +00:00 to make the datetime object timezone aware
197192
# Due to https://github.com/python/cpython/issues/80010, once migrate to Python 3.11, we can use fromisoformat
198193
query = GtfsDatasetFilter(
199-
downloaded_at__lte=datetime.fromisoformat(downloaded_before.replace("Z", "+00:00"))
200-
if downloaded_before
201-
else None,
202-
downloaded_at__gte=datetime.fromisoformat(downloaded_after.replace("Z", "+00:00"))
203-
if downloaded_after
204-
else None,
194+
downloaded_at__lte=(
195+
datetime.fromisoformat(downloaded_before.replace("Z", "+00:00")) if downloaded_before else None
196+
),
197+
downloaded_at__gte=(
198+
datetime.fromisoformat(downloaded_after.replace("Z", "+00:00")) if downloaded_after else None
199+
),
205200
).filter(DatasetsApiImpl.create_dataset_query().filter(Feed.stable_id == gtfs_feed_id))
206201

207202
if latest:
208203
query = query.filter(Gtfsdataset.latest)
209204

210-
return DatasetsApiImpl.get_datasets_gtfs(query, limit=limit, offset=offset)
205+
return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)
211206

207+
@with_db_session
212208
def get_gtfs_feeds(
213209
self,
214210
limit: int,
@@ -221,6 +217,7 @@ def get_gtfs_feeds(
221217
dataset_latitudes: str,
222218
dataset_longitudes: str,
223219
bounding_filter_method: str,
220+
db_session: Session,
224221
) -> List[GtfsFeed]:
225222
"""Get some (or all) GTFS feeds from the Mobility Database."""
226223
gtfs_feed_filter = GtfsFeedFilter(
@@ -240,9 +237,7 @@ def get_gtfs_feeds(
240237
).subquery()
241238

242239
feed_query = (
243-
Database()
244-
.get_session()
245-
.query(Gtfsfeed)
240+
db_session.query(Gtfsfeed)
246241
.filter(Gtfsfeed.id.in_(subquery))
247242
.filter(
248243
or_(
@@ -261,12 +256,10 @@ def get_gtfs_feeds(
261256
.limit(limit)
262257
.offset(offset)
263258
)
264-
return self._get_response(feed_query, GtfsFeedImpl)
259+
return self._get_response(feed_query, GtfsFeedImpl, db_session)
265260

266-
def get_gtfs_rt_feed(
267-
self,
268-
id: str,
269-
) -> GtfsRTFeed:
261+
@with_db_session
262+
def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
270263
"""Get the specified GTFS Realtime feed from the Mobility Database."""
271264
gtfs_rt_feed_filter = GtfsRtFeedFilter(
272265
stable_id=id,
@@ -276,9 +269,7 @@ def get_gtfs_rt_feed(
276269
location=None,
277270
)
278271
results = gtfs_rt_feed_filter.filter(
279-
Database()
280-
.get_session()
281-
.query(Gtfsrealtimefeed, t_location_with_translations_en)
272+
db_session.query(Gtfsrealtimefeed, t_location_with_translations_en)
282273
.filter(
283274
or_(
284275
Gtfsrealtimefeed.operational_status == None, # noqa: E711
@@ -301,6 +292,7 @@ def get_gtfs_rt_feed(
301292
else:
302293
raise_http_error(404, gtfs_rt_feed_not_found.format(id))
303294

295+
@with_db_session
304296
def get_gtfs_rt_feeds(
305297
self,
306298
limit: int,
@@ -311,6 +303,7 @@ def get_gtfs_rt_feeds(
311303
country_code: str,
312304
subdivision_name: str,
313305
municipality: str,
306+
db_session: Session,
314307
) -> List[GtfsRTFeed]:
315308
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
316309
entity_types_list = entity_types.split(",") if entity_types else None
@@ -342,9 +335,7 @@ def get_gtfs_rt_feeds(
342335
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
343336
).subquery()
344337
feed_query = (
345-
Database()
346-
.get_session()
347-
.query(Gtfsrealtimefeed)
338+
db_session.query(Gtfsrealtimefeed)
348339
.filter(Gtfsrealtimefeed.id.in_(subquery))
349340
.filter(
350341
or_(
@@ -362,22 +353,20 @@ def get_gtfs_rt_feeds(
362353
.limit(limit)
363354
.offset(offset)
364355
)
365-
return self._get_response(feed_query, GtfsRTFeedImpl)
356+
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)
366357

367358
@staticmethod
368-
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
359+
def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -> List[T]:
369360
"""Get the response for the feed query."""
370361
results = feed_query.all()
371-
location_translations = get_feeds_location_translations(results)
362+
location_translations = get_feeds_location_translations(results, db_session)
372363
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
373364
return list({feed.id: feed for feed in response}.values())
374365

375-
def get_gtfs_feed_gtfs_rt_feeds(
376-
self,
377-
id: str,
378-
) -> List[GtfsRTFeed]:
366+
@with_db_session
367+
def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]:
379368
"""Get a list of GTFS Realtime related to a GTFS feed."""
380-
feed, translations = self._get_gtfs_feed(id)
369+
feed, translations = self._get_gtfs_feed(id, db_session)
381370
if feed:
382371
return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds]
383372
else:

api/src/feeds/impl/search_api_impl.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from typing import List
22

33
from sqlalchemy import func, select
4-
from sqlalchemy.orm import Query
4+
from sqlalchemy.orm import Query, Session
55

6-
from database.database import Database
6+
from database.database import Database, with_db_session
77
from database.sql_functions.unaccent import unaccent
88
from database_gen.sqlacodegen_models import t_feedsearch
99
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
@@ -83,6 +83,7 @@ def create_search_query(status: List[str], feed_id: str, data_type: str, search_
8383
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status)
8484
return query.order_by(rank_expression.desc())
8585

86+
@with_db_session
8687
def search_feeds(
8788
self,
8889
limit: int,
@@ -91,15 +92,18 @@ def search_feeds(
9192
feed_id: str,
9293
data_type: str,
9394
search_query: str,
95+
db_session: "Session",
9496
) -> SearchFeeds200Response:
9597
"""Search feeds using full-text search on feed, location and provider's information."""
9698
query = self.create_search_query(status, feed_id, data_type, search_query)
9799
feed_rows = Database().select(
100+
session=db_session,
98101
query=query,
99102
limit=limit,
100103
offset=offset,
101104
)
102105
feed_total_count = Database().select(
106+
session=db_session,
103107
query=self.create_count_search_query(status, feed_id, data_type, search_query),
104108
)
105109
if feed_rows is None or feed_total_count is None:

api/src/scripts/populate_db.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44
from pathlib import Path
5-
from typing import Type
5+
from typing import Type, TYPE_CHECKING
66

77
import pandas
88
from dotenv import load_dotenv
@@ -11,6 +11,9 @@
1111
from database_gen.sqlacodegen_models import Feed, Gtfsrealtimefeed, Gtfsfeed, Gbfsfeed
1212
from utils.logger import Logger
1313

14+
if TYPE_CHECKING:
15+
from sqlalchemy.orm import Session
16+
1417
logging.basicConfig()
1518
logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)
1619

@@ -56,12 +59,14 @@ def __init__(self, filepaths):
5659

5760
self.filter_data()
5861

59-
def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
62+
def query_feed_by_stable_id(
63+
self, session: "Session", stable_id: str, data_type: str | None
64+
) -> Gtfsrealtimefeed | Gtfsfeed | None:
6065
"""
6166
Query the feed by stable id
6267
"""
6368
model = self.get_model(data_type)
64-
return self.db.session.query(model).filter(model.stable_id == stable_id).first()
69+
return session.query(model).filter(model.stable_id == stable_id).first()
6570

6671
@staticmethod
6772
def get_model(data_type: str | None) -> Type[Feed]:

api/src/scripts/populate_db_gbfs.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,16 @@ def deprecate_feeds(self, deprecated_feeds):
3636
if deprecated_feeds is None or deprecated_feeds.empty:
3737
self.logger.info("No feeds to deprecate.")
3838
return
39+
3940
self.logger.info(f"Deprecating {len(deprecated_feeds)} feed(s).")
40-
for index, row in deprecated_feeds.iterrows():
41-
stable_id = self.get_stable_id(row)
42-
gbfs_feed = self.query_feed_by_stable_id(stable_id, "gbfs")
43-
if gbfs_feed:
44-
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
45-
gbfs_feed.status = "deprecated"
46-
self.db.session.flush()
41+
with self.db.start_db_session() as session:
42+
for index, row in deprecated_feeds.iterrows():
43+
stable_id = self.get_stable_id(row)
44+
gbfs_feed = self.query_feed_by_stable_id(session, stable_id, "gbfs")
45+
if gbfs_feed:
46+
self.logger.info(f"Deprecating feed with stable_id={stable_id}")
47+
gbfs_feed.status = "deprecated"
48+
session.flush()
4749

4850
def populate_db(self):
4951
"""Populate the database with the GBFS feeds"""

0 commit comments

Comments
 (0)