22from typing import List , Union , TypeVar
33
44from sqlalchemy import select
5- from sqlalchemy .orm import joinedload
5+ from sqlalchemy .orm import joinedload , Session
66from sqlalchemy .orm .query import Query
77
8- from database .database import Database
8+ from database .database import Database , with_db_session
99from database_gen .sqlacodegen_models import (
1010 Feed ,
1111 Gtfsdataset ,
@@ -59,20 +59,19 @@ class FeedsApiImpl(BaseFeedsApi):
5959
6060 APIFeedType = Union [BasicFeed , GtfsFeed , GtfsRTFeed ]
6161
62- def get_feed (
63- self ,
64- id : str ,
65- ) -> BasicFeed :
62+ @with_db_session
63+ def get_feed (self , id : str , db_session : Session ) -> BasicFeed :
6664 """Get the specified feed from the Mobility Database."""
6765 feed = (
6866 FeedFilter (stable_id = id , provider__ilike = None , producer_url__ilike = None , status = None )
69- .filter (Database ().get_query_model (Feed ))
67+ .filter (Database ().get_query_model (db_session , Feed ))
7068 .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
7169 .filter (
7270 or_ (
7371 Feed .operational_status == None , # noqa: E711
7472 Feed .operational_status != "wip" ,
75- not is_user_email_restricted (), # Allow all feeds to be returned if the user is not restricted
73+ # Allow all feeds to be returned if the user is not restricted
74+ not is_user_email_restricted (),
7675 )
7776 )
7877 .first ()
@@ -82,19 +81,15 @@ def get_feed(
8281 else :
8382 raise_http_error (404 , feed_not_found .format (id ))
8483
84+ @with_db_session
8585 def get_feeds (
86- self ,
87- limit : int ,
88- offset : int ,
89- status : str ,
90- provider : str ,
91- producer_url : str ,
86+ self , limit : int , offset : int , status : str , provider : str , producer_url : str , db_session : Session
9287 ) -> List [BasicFeed ]:
9388 """Get some (or all) feeds from the Mobility Database."""
9489 feed_filter = FeedFilter (
9590 status = status , provider__ilike = provider , producer_url__ilike = producer_url , stable_id = None
9691 )
97- feed_query = feed_filter .filter (Database ().get_query_model (Feed ))
92+ feed_query = feed_filter .filter (Database ().get_query_model (db_session , Feed ))
9893 feed_query = feed_query .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
9994 feed_query = feed_query .filter (
10095 or_ (
@@ -114,27 +109,25 @@ def get_feeds(
114109 results = feed_query .all ()
115110 return [BasicFeedImpl .from_orm (feed ) for feed in results ]
116111
117- def get_gtfs_feed (
118- self ,
119- id : str ,
120- ) -> GtfsFeed :
112+ @with_db_session
113+ def get_gtfs_feed (self , id : str , db_session : Session ) -> GtfsFeed :
121114 """Get the specified gtfs feed from the Mobility Database."""
122- feed , translations = self ._get_gtfs_feed (id )
115+ feed , translations = self ._get_gtfs_feed (id , db_session )
123116 if feed :
124117 return GtfsFeedImpl .from_orm (feed , translations )
125118 else :
126119 raise_http_error (404 , gtfs_feed_not_found .format (id ))
127120
128121 @staticmethod
129- def _get_gtfs_feed (stable_id : str ) -> tuple [Gtfsfeed | None , dict [str , LocationTranslation ]]:
122+ def _get_gtfs_feed (stable_id : str , db_session : Session ) -> tuple [Gtfsfeed | None , dict [str , LocationTranslation ]]:
130123 results = (
131124 FeedFilter (
132125 stable_id = stable_id ,
133126 status = None ,
134127 provider__ilike = None ,
135128 producer_url__ilike = None ,
136129 )
137- .filter (Database (). get_session () .query (Gtfsfeed , t_location_with_translations_en ))
130+ .filter (db_session .query (Gtfsfeed , t_location_with_translations_en ))
138131 .filter (
139132 or_ (
140133 Gtfsfeed .operational_status == None , # noqa: E711
@@ -156,6 +149,7 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
156149 return results [0 ].Gtfsfeed , translations
157150 return None , {}
158151
152+ @with_db_session
159153 def get_gtfs_feed_datasets (
160154 self ,
161155 gtfs_feed_id : str ,
@@ -164,6 +158,7 @@ def get_gtfs_feed_datasets(
164158 offset : int ,
165159 downloaded_after : str ,
166160 downloaded_before : str ,
161+ db_session : Session ,
167162 ) -> List [GtfsDataset ]:
168163 """Get a list of datasets related to a feed."""
169164 if downloaded_before and not valid_iso_date (downloaded_before ):
@@ -179,7 +174,7 @@ def get_gtfs_feed_datasets(
179174 provider__ilike = None ,
180175 producer_url__ilike = None ,
181176 )
182- .filter (Database ().get_query_model (Gtfsfeed ))
177+ .filter (Database ().get_query_model (db_session , Gtfsfeed ))
183178 .filter (
184179 or_ (
185180 Feed .operational_status == None , # noqa: E711
@@ -196,19 +191,20 @@ def get_gtfs_feed_datasets(
196191 # Replace Z with +00:00 to make the datetime object timezone aware
197192 # Due to https://github.com/python/cpython/issues/80010, once migrate to Python 3.11, we can use fromisoformat
198193 query = GtfsDatasetFilter (
199- downloaded_at__lte = datetime . fromisoformat ( downloaded_before . replace ( "Z" , "+00:00" ))
200- if downloaded_before
201- else None ,
202- downloaded_at__gte = datetime . fromisoformat ( downloaded_after . replace ( "Z" , "+00:00" ))
203- if downloaded_after
204- else None ,
194+ downloaded_at__lte = (
195+ datetime . fromisoformat ( downloaded_before . replace ( "Z" , "+00:00" )) if downloaded_before else None
196+ ) ,
197+ downloaded_at__gte = (
198+ datetime . fromisoformat ( downloaded_after . replace ( "Z" , "+00:00" )) if downloaded_after else None
199+ ) ,
205200 ).filter (DatasetsApiImpl .create_dataset_query ().filter (Feed .stable_id == gtfs_feed_id ))
206201
207202 if latest :
208203 query = query .filter (Gtfsdataset .latest )
209204
210- return DatasetsApiImpl .get_datasets_gtfs (query , limit = limit , offset = offset )
205+ return DatasetsApiImpl .get_datasets_gtfs (query , session = db_session , limit = limit , offset = offset )
211206
207+ @with_db_session
212208 def get_gtfs_feeds (
213209 self ,
214210 limit : int ,
@@ -221,6 +217,7 @@ def get_gtfs_feeds(
221217 dataset_latitudes : str ,
222218 dataset_longitudes : str ,
223219 bounding_filter_method : str ,
220+ db_session : Session ,
224221 ) -> List [GtfsFeed ]:
225222 """Get some (or all) GTFS feeds from the Mobility Database."""
226223 gtfs_feed_filter = GtfsFeedFilter (
@@ -240,9 +237,7 @@ def get_gtfs_feeds(
240237 ).subquery ()
241238
242239 feed_query = (
243- Database ()
244- .get_session ()
245- .query (Gtfsfeed )
240+ db_session .query (Gtfsfeed )
246241 .filter (Gtfsfeed .id .in_ (subquery ))
247242 .filter (
248243 or_ (
@@ -261,12 +256,10 @@ def get_gtfs_feeds(
261256 .limit (limit )
262257 .offset (offset )
263258 )
264- return self ._get_response (feed_query , GtfsFeedImpl )
259+ return self ._get_response (feed_query , GtfsFeedImpl , db_session )
265260
266- def get_gtfs_rt_feed (
267- self ,
268- id : str ,
269- ) -> GtfsRTFeed :
261+ @with_db_session
262+ def get_gtfs_rt_feed (self , id : str , db_session : Session ) -> GtfsRTFeed :
270263 """Get the specified GTFS Realtime feed from the Mobility Database."""
271264 gtfs_rt_feed_filter = GtfsRtFeedFilter (
272265 stable_id = id ,
@@ -276,9 +269,7 @@ def get_gtfs_rt_feed(
276269 location = None ,
277270 )
278271 results = gtfs_rt_feed_filter .filter (
279- Database ()
280- .get_session ()
281- .query (Gtfsrealtimefeed , t_location_with_translations_en )
272+ db_session .query (Gtfsrealtimefeed , t_location_with_translations_en )
282273 .filter (
283274 or_ (
284275 Gtfsrealtimefeed .operational_status == None , # noqa: E711
@@ -301,6 +292,7 @@ def get_gtfs_rt_feed(
301292 else :
302293 raise_http_error (404 , gtfs_rt_feed_not_found .format (id ))
303294
295+ @with_db_session
304296 def get_gtfs_rt_feeds (
305297 self ,
306298 limit : int ,
@@ -311,6 +303,7 @@ def get_gtfs_rt_feeds(
311303 country_code : str ,
312304 subdivision_name : str ,
313305 municipality : str ,
306+ db_session : Session ,
314307 ) -> List [GtfsRTFeed ]:
315308 """Get some (or all) GTFS Realtime feeds from the Mobility Database."""
316309 entity_types_list = entity_types .split ("," ) if entity_types else None
@@ -342,9 +335,7 @@ def get_gtfs_rt_feeds(
342335 .join (Entitytype , Gtfsrealtimefeed .entitytypes )
343336 ).subquery ()
344337 feed_query = (
345- Database ()
346- .get_session ()
347- .query (Gtfsrealtimefeed )
338+ db_session .query (Gtfsrealtimefeed )
348339 .filter (Gtfsrealtimefeed .id .in_ (subquery ))
349340 .filter (
350341 or_ (
@@ -362,22 +353,20 @@ def get_gtfs_rt_feeds(
362353 .limit (limit )
363354 .offset (offset )
364355 )
365- return self ._get_response (feed_query , GtfsRTFeedImpl )
356+ return self ._get_response (feed_query , GtfsRTFeedImpl , db_session )
366357
367358 @staticmethod
368- def _get_response (feed_query : Query , impl_cls : type [T ]) -> List [T ]:
359+ def _get_response (feed_query : Query , impl_cls : type [T ], db_session : "Session" ) -> List [T ]:
369360 """Get the response for the feed query."""
370361 results = feed_query .all ()
371- location_translations = get_feeds_location_translations (results )
362+ location_translations = get_feeds_location_translations (results , db_session )
372363 response = [impl_cls .from_orm (feed , location_translations ) for feed in results ]
373364 return list ({feed .id : feed for feed in response }.values ())
374365
375- def get_gtfs_feed_gtfs_rt_feeds (
376- self ,
377- id : str ,
378- ) -> List [GtfsRTFeed ]:
366+ @with_db_session
367+ def get_gtfs_feed_gtfs_rt_feeds (self , id : str , db_session : Session ) -> List [GtfsRTFeed ]:
379368 """Get a list of GTFS Realtime related to a GTFS feed."""
380- feed , translations = self ._get_gtfs_feed (id )
369+ feed , translations = self ._get_gtfs_feed (id , db_session )
381370 if feed :
382371 return [GtfsRTFeedImpl .from_orm (gtfs_rt_feed , translations ) for gtfs_rt_feed in feed .gtfs_rt_feeds ]
383372 else :
0 commit comments