33
44from sqlalchemy import or_
55from sqlalchemy import select
6- from sqlalchemy .orm import joinedload
6+ from sqlalchemy .orm import joinedload , Session
77from sqlalchemy .orm .query import Query
88
9- from database .database import Database
9+ from database .database import Database , with_db_session
1010from database_gen .sqlacodegen_models import (
1111 Feed ,
1212 Gtfsdataset ,
@@ -63,17 +63,15 @@ class FeedsApiImpl(BaseFeedsApi):
6363 def __init__ (self ) -> None :
6464 self .logger = Logger ("FeedsApiImpl" ).get_logger ()
6565
66- def get_feed (
67- self ,
68- id : str ,
69- ) -> BasicFeed :
66+ @with_db_session
67+ def get_feed (self , id : str , db_session : Session ) -> BasicFeed :
7068 """Get the specified feed from the Mobility Database."""
7169 is_email_restricted = is_user_email_restricted ()
7270 self .logger .info (f"User email is restricted: { is_email_restricted } " )
7371
7472 feed = (
7573 FeedFilter (stable_id = id , provider__ilike = None , producer_url__ilike = None , status = None )
76- .filter (Database ().get_query_model (Feed ))
74+ .filter (Database ().get_query_model (db_session , Feed ))
7775 .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
7876 .filter (
7977 or_ (
@@ -89,6 +87,7 @@ def get_feed(
8987 else :
9088 raise_http_error (404 , feed_not_found .format (id ))
9189
90+ @with_db_session
9291 def get_feeds (
9392 self ,
9493 limit : int ,
@@ -97,14 +96,15 @@ def get_feeds(
9796 provider : str ,
9897 producer_url : str ,
9998 is_official : bool ,
99+ db_session : Session ,
100100 ) -> List [BasicFeed ]:
101101 """Get some (or all) feeds from the Mobility Database."""
102102 is_email_restricted = is_user_email_restricted ()
103103 self .logger .info (f"User email is restricted: { is_email_restricted } " )
104104 feed_filter = FeedFilter (
105105 status = status , provider__ilike = provider , producer_url__ilike = producer_url , stable_id = None
106106 )
107- feed_query = feed_filter .filter (Database ().get_query_model (Feed ))
107+ feed_query = feed_filter .filter (Database ().get_query_model (db_session , Feed ))
108108 if is_official :
109109 feed_query = feed_query .filter (Feed .official )
110110 feed_query = feed_query .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
@@ -126,27 +126,25 @@ def get_feeds(
126126 results = feed_query .all ()
127127 return [BasicFeedImpl .from_orm (feed ) for feed in results ]
128128
129- def get_gtfs_feed (
130- self ,
131- id : str ,
132- ) -> GtfsFeed :
129+ @with_db_session
130+ def get_gtfs_feed (self , id : str , db_session : Session ) -> GtfsFeed :
133131 """Get the specified gtfs feed from the Mobility Database."""
134- feed , translations = self ._get_gtfs_feed (id )
132+ feed , translations = self ._get_gtfs_feed (id , db_session )
135133 if feed :
136134 return GtfsFeedImpl .from_orm (feed , translations )
137135 else :
138136 raise_http_error (404 , gtfs_feed_not_found .format (id ))
139137
140138 @staticmethod
141- def _get_gtfs_feed (stable_id : str ) -> tuple [Gtfsfeed | None , dict [str , LocationTranslation ]]:
139+ def _get_gtfs_feed (stable_id : str , db_session : Session ) -> tuple [Gtfsfeed | None , dict [str , LocationTranslation ]]:
142140 results = (
143141 FeedFilter (
144142 stable_id = stable_id ,
145143 status = None ,
146144 provider__ilike = None ,
147145 producer_url__ilike = None ,
148146 )
149- .filter (Database (). get_session () .query (Gtfsfeed , t_location_with_translations_en ))
147+ .filter (db_session .query (Gtfsfeed , t_location_with_translations_en ))
150148 .filter (
151149 or_ (
152150 Gtfsfeed .operational_status == None , # noqa: E711
@@ -168,6 +166,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
168166 return results [0 ].Gtfsfeed , translations
169167 return None , {}
170168
169+ @with_db_session
171170 def get_gtfs_feed_datasets (
172171 self ,
173172 gtfs_feed_id : str ,
@@ -176,6 +175,7 @@ def get_gtfs_feed_datasets(
176175 offset : int ,
177176 downloaded_after : str ,
178177 downloaded_before : str ,
178+ db_session : Session ,
179179 ) -> List [GtfsDataset ]:
180180 """Get a list of datasets related to a feed."""
181181 if downloaded_before and not valid_iso_date (downloaded_before ):
@@ -191,7 +191,7 @@ def get_gtfs_feed_datasets(
191191 provider__ilike = None ,
192192 producer_url__ilike = None ,
193193 )
194- .filter (Database ().get_query_model (Gtfsfeed ))
194+ .filter (Database ().get_query_model (db_session , Gtfsfeed ))
195195 .filter (
196196 or_ (
197197 Feed .operational_status == None , # noqa: E711
@@ -208,19 +208,20 @@ def get_gtfs_feed_datasets(
208208 # Replace Z with +00:00 to make the datetime object timezone aware
209209 # Due to https://github.com/python/cpython/issues/80010, once migrate to Python 3.11, we can use fromisoformat
210210 query = GtfsDatasetFilter (
211- downloaded_at__lte = datetime . fromisoformat ( downloaded_before . replace ( "Z" , "+00:00" ))
212- if downloaded_before
213- else None ,
214- downloaded_at__gte = datetime . fromisoformat ( downloaded_after . replace ( "Z" , "+00:00" ))
215- if downloaded_after
216- else None ,
211+ downloaded_at__lte = (
212+ datetime . fromisoformat ( downloaded_before . replace ( "Z" , "+00:00" )) if downloaded_before else None
213+ ) ,
214+ downloaded_at__gte = (
215+ datetime . fromisoformat ( downloaded_after . replace ( "Z" , "+00:00" )) if downloaded_after else None
216+ ) ,
217217 ).filter (DatasetsApiImpl .create_dataset_query ().filter (Feed .stable_id == gtfs_feed_id ))
218218
219219 if latest :
220220 query = query .filter (Gtfsdataset .latest )
221221
222- return DatasetsApiImpl .get_datasets_gtfs (query , limit = limit , offset = offset )
222+ return DatasetsApiImpl .get_datasets_gtfs (query , session = db_session , limit = limit , offset = offset )
223223
224+ @with_db_session
224225 def get_gtfs_feeds (
225226 self ,
226227 limit : int ,
@@ -234,6 +235,7 @@ def get_gtfs_feeds(
234235 dataset_longitudes : str ,
235236 bounding_filter_method : str ,
236237 is_official : bool ,
238+ db_session : Session ,
237239 ) -> List [GtfsFeed ]:
238240 """Get some (or all) GTFS feeds from the Mobility Database."""
239241 gtfs_feed_filter = GtfsFeedFilter (
@@ -255,9 +257,7 @@ def get_gtfs_feeds(
255257 is_email_restricted = is_user_email_restricted ()
256258 self .logger .info (f"User email is restricted: { is_email_restricted } " )
257259 feed_query = (
258- Database ()
259- .get_session ()
260- .query (Gtfsfeed )
260+ db_session .query (Gtfsfeed )
261261 .filter (Gtfsfeed .id .in_ (subquery ))
262262 .filter (
263263 or_ (
@@ -277,12 +277,10 @@ def get_gtfs_feeds(
277277 if is_official :
278278 feed_query = feed_query .filter (Feed .official )
279279 feed_query = feed_query .limit (limit ).offset (offset )
280- return self ._get_response (feed_query , GtfsFeedImpl )
280+ return self ._get_response (feed_query , GtfsFeedImpl , db_session )
281281
282- def get_gtfs_rt_feed (
283- self ,
284- id : str ,
285- ) -> GtfsRTFeed :
282+ @with_db_session
283+ def get_gtfs_rt_feed (self , id : str , db_session : Session ) -> GtfsRTFeed :
286284 """Get the specified GTFS Realtime feed from the Mobility Database."""
287285 gtfs_rt_feed_filter = GtfsRtFeedFilter (
288286 stable_id = id ,
@@ -292,9 +290,7 @@ def get_gtfs_rt_feed(
292290 location = None ,
293291 )
294292 results = gtfs_rt_feed_filter .filter (
295- Database ()
296- .get_session ()
297- .query (Gtfsrealtimefeed , t_location_with_translations_en )
293+ db_session .query (Gtfsrealtimefeed , t_location_with_translations_en )
298294 .filter (
299295 or_ (
300296 Gtfsrealtimefeed .operational_status == None , # noqa: E711
@@ -317,6 +313,7 @@ def get_gtfs_rt_feed(
317313 else :
318314 raise_http_error (404 , gtfs_rt_feed_not_found .format (id ))
319315
316+ @with_db_session
320317 def get_gtfs_rt_feeds (
321318 self ,
322319 limit : int ,
@@ -328,6 +325,7 @@ def get_gtfs_rt_feeds(
328325 subdivision_name : str ,
329326 municipality : str ,
330327 is_official : bool ,
328+ db_session : Session ,
331329 ) -> List [GtfsRTFeed ]:
332330 """Get some (or all) GTFS Realtime feeds from the Mobility Database."""
333331 entity_types_list = entity_types .split ("," ) if entity_types else None
@@ -359,9 +357,7 @@ def get_gtfs_rt_feeds(
359357 .join (Entitytype , Gtfsrealtimefeed .entitytypes )
360358 ).subquery ()
361359 feed_query = (
362- Database ()
363- .get_session ()
364- .query (Gtfsrealtimefeed )
360+ db_session .query (Gtfsrealtimefeed )
365361 .filter (Gtfsrealtimefeed .id .in_ (subquery ))
366362 .filter (
367363 or_ (
@@ -380,22 +376,20 @@ def get_gtfs_rt_feeds(
380376 if is_official :
381377 feed_query = feed_query .filter (Feed .official )
382378 feed_query = feed_query .limit (limit ).offset (offset )
383- return self ._get_response (feed_query , GtfsRTFeedImpl )
379+ return self ._get_response (feed_query , GtfsRTFeedImpl , db_session )
384380
385381 @staticmethod
386- def _get_response (feed_query : Query , impl_cls : type [T ]) -> List [T ]:
382+ def _get_response (feed_query : Query , impl_cls : type [T ], db_session : "Session" ) -> List [T ]:
387383 """Get the response for the feed query."""
388384 results = feed_query .all ()
389- location_translations = get_feeds_location_translations (results )
385+ location_translations = get_feeds_location_translations (results , db_session )
390386 response = [impl_cls .from_orm (feed , location_translations ) for feed in results ]
391387 return list ({feed .id : feed for feed in response }.values ())
392388
393- def get_gtfs_feed_gtfs_rt_feeds (
394- self ,
395- id : str ,
396- ) -> List [GtfsRTFeed ]:
389+ @with_db_session
390+ def get_gtfs_feed_gtfs_rt_feeds (self , id : str , db_session : Session ) -> List [GtfsRTFeed ]:
397391 """Get a list of GTFS Realtime related to a GTFS feed."""
398- feed , translations = self ._get_gtfs_feed (id )
392+ feed , translations = self ._get_gtfs_feed (id , db_session )
399393 if feed :
400394 return [GtfsRTFeedImpl .from_orm (gtfs_rt_feed , translations ) for gtfs_rt_feed in feed .gtfs_rt_feeds ]
401395 else :
0 commit comments