66from sqlalchemy .orm import joinedload , Session
77from sqlalchemy .orm .query import Query
88
9- from database .database import Database , with_db_session
10- from database_gen .sqlacodegen_models import (
9+ from shared .common .db_utils import get_gtfs_feeds_query , get_gtfs_rt_feeds_query , get_joinedload_options
10+ from shared .database .database import Database , with_db_session
11+ from shared .database_gen .sqlacodegen_models import (
1112 Feed ,
1213 Gtfsdataset ,
1314 Gtfsfeed ,
1718 t_location_with_translations_en ,
1819 Entitytype ,
1920)
20- from feeds . filters .feed_filter import FeedFilter
21- from feeds . filters .gtfs_dataset_filter import GtfsDatasetFilter
22- from feeds . filters .gtfs_feed_filter import GtfsFeedFilter , LocationFilter
23- from feeds . filters .gtfs_rt_feed_filter import GtfsRtFeedFilter , EntityTypeFilter
21+ from shared . feed_filters .feed_filter import FeedFilter
22+ from shared . feed_filters .gtfs_dataset_filter import GtfsDatasetFilter
23+ from shared . feed_filters .gtfs_feed_filter import LocationFilter
24+ from shared . feed_filters .gtfs_rt_feed_filter import GtfsRtFeedFilter , EntityTypeFilter
2425from feeds .impl .datasets_api_impl import DatasetsApiImpl
25- from feeds .impl .error_handling import (
26- raise_http_validation_error ,
26+ from shared .common .error_handling import (
2727 invalid_date_message ,
28- raise_http_error ,
2928 feed_not_found ,
3029 gtfs_feed_not_found ,
3130 gtfs_rt_feed_not_found ,
31+ InternalHTTPException ,
3232)
3333from feeds .impl .models .basic_feed_impl import BasicFeedImpl
3434from feeds .impl .models .entity_type_enum import EntityType
3939from feeds_gen .models .gtfs_dataset import GtfsDataset
4040from feeds_gen .models .gtfs_feed import GtfsFeed
4141from feeds_gen .models .gtfs_rt_feed import GtfsRTFeed
42+ from feeds .impl .error_handling import raise_http_error , raise_http_validation_error , convert_exception
4243from middleware .request_context import is_user_email_restricted
4344from utils .date_utils import valid_iso_date
4445from utils .location_translation import (
@@ -116,7 +117,7 @@ def get_feeds(
116117 )
117118 # Results are sorted by provider
118119 feed_query = feed_query .order_by (Feed .provider , Feed .stable_id )
119- feed_query = feed_query .options (* BasicFeedImpl . get_joinedload_options ())
120+ feed_query = feed_query .options (* get_joinedload_options ())
120121 if limit is not None :
121122 feed_query = feed_query .limit (limit )
122123 if offset is not None :
@@ -155,7 +156,7 @@ def _get_gtfs_feed(stable_id: str, db_session: Session) -> Optional[Gtfsfeed]:
155156 joinedload (Gtfsfeed .gtfsdatasets )
156157 .joinedload (Gtfsdataset .validation_reports )
157158 .joinedload (Validationreport .notices ),
158- * BasicFeedImpl . get_joinedload_options (),
159+ * get_joinedload_options (),
159160 )
160161 ).all ()
161162 if len (results ) == 0 :
@@ -233,46 +234,29 @@ def get_gtfs_feeds(
233234 is_official : bool ,
234235 db_session : Session ,
235236 ) -> List [GtfsFeed ]:
236- """Get some (or all) GTFS feeds from the Mobility Database."""
237- gtfs_feed_filter = GtfsFeedFilter (
238- stable_id = None ,
239- provider__ilike = provider ,
240- producer_url__ilike = producer_url ,
241- location = LocationFilter (
237+ try :
238+ include_wip = not is_user_email_restricted ()
239+ feed_query = get_gtfs_feeds_query (
240+ limit = limit ,
241+ offset = offset ,
242+ provider = provider ,
243+ producer_url = producer_url ,
242244 country_code = country_code ,
243- subdivision_name__ilike = subdivision_name ,
244- municipality__ilike = municipality ,
245- ),
246- )
247-
248- subquery = gtfs_feed_filter .filter (select (Gtfsfeed .id ).join (Location , Gtfsfeed .locations ))
249- subquery = DatasetsApiImpl .apply_bounding_filtering (
250- subquery , dataset_latitudes , dataset_longitudes , bounding_filter_method
251- ).subquery ()
252-
253- is_email_restricted = is_user_email_restricted ()
254- self .logger .info (f"User email is restricted: { is_email_restricted } " )
255- feed_query = (
256- db_session .query (Gtfsfeed )
257- .filter (Gtfsfeed .id .in_ (subquery ))
258- .filter (
259- or_ (
260- Gtfsfeed .operational_status == None , # noqa: E711
261- Gtfsfeed .operational_status != "wip" ,
262- not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
263- )
245+ subdivision_name = subdivision_name ,
246+ municipality = municipality ,
247+ dataset_latitudes = dataset_latitudes ,
248+ dataset_longitudes = dataset_longitudes ,
249+ bounding_filter_method = bounding_filter_method ,
250+ is_official = is_official ,
251+ include_wip = include_wip ,
252+ db_session = db_session ,
264253 )
265- .options (
266- joinedload (Gtfsfeed .gtfsdatasets )
267- .joinedload (Gtfsdataset .validation_reports )
268- .joinedload (Validationreport .notices ),
269- * BasicFeedImpl .get_joinedload_options (),
270- )
271- .order_by (Gtfsfeed .provider , Gtfsfeed .stable_id )
272- )
273- if is_official :
274- feed_query = feed_query .filter (Feed .official )
275- feed_query = feed_query .limit (limit ).offset (offset )
254+ except InternalHTTPException as e :
255+ # get_gtfs_feeds_query cannot throw HTTPException since it's part of fastapi and it's
256+ # not necessarily deployed (e.g. for python functions). Instead it throws an InternalHTTPException
257+ # that needs to be converted to HTTPException before being thrown.
258+ raise convert_exception (e )
259+
276260 return self ._get_response (feed_query , GtfsFeedImpl , db_session )
277261
278262 @with_db_session
@@ -299,7 +283,7 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
299283 .options (
300284 joinedload (Gtfsrealtimefeed .entitytypes ),
301285 joinedload (Gtfsrealtimefeed .gtfs_feeds ),
302- * BasicFeedImpl . get_joinedload_options (),
286+ * get_joinedload_options (),
303287 )
304288 ).all ()
305289
@@ -324,6 +308,26 @@ def get_gtfs_rt_feeds(
324308 db_session : Session ,
325309 ) -> List [GtfsRTFeed ]:
326310 """Get some (or all) GTFS Realtime feeds from the Mobility Database."""
311+ try :
312+ include_wip = not is_user_email_restricted ()
313+ feed_query = get_gtfs_rt_feeds_query (
314+ limit = limit ,
315+ offset = offset ,
316+ provider = provider ,
317+ producer_url = producer_url ,
318+ entity_types = entity_types ,
319+ country_code = country_code ,
320+ subdivision_name = subdivision_name ,
321+ municipality = municipality ,
322+ is_official = is_official ,
323+ include_wip = include_wip ,
324+ db_session = db_session ,
325+ )
326+ except InternalHTTPException as e :
327+ raise convert_exception (e )
328+
329+ return self ._get_response (feed_query , GtfsRTFeedImpl , db_session )
330+
327331 entity_types_list = entity_types .split ("," ) if entity_types else None
328332
329333 # Validate entity types using the EntityType enum
@@ -365,7 +369,7 @@ def get_gtfs_rt_feeds(
365369 .options (
366370 joinedload (Gtfsrealtimefeed .entitytypes ),
367371 joinedload (Gtfsrealtimefeed .gtfs_feeds ),
368- * BasicFeedImpl . get_joinedload_options (),
372+ * get_joinedload_options (),
369373 )
370374 .order_by (Gtfsrealtimefeed .provider , Gtfsrealtimefeed .stable_id )
371375 )
0 commit comments