66from sqlalchemy .orm import joinedload , Session
77from sqlalchemy .orm .query import Query
88
9- from database .database import Database , with_db_session
10- from database_gen .sqlacodegen_models import (
9+ from shared .common .db_utils import get_gtfs_feeds_query , get_gtfs_rt_feeds_query , get_joinedload_options
10+ from shared .database .database import Database , with_db_session
11+ from shared .database_gen .sqlacodegen_models import (
1112 Feed ,
1213 Gtfsdataset ,
1314 Gtfsfeed ,
1718 t_location_with_translations_en ,
1819 Entitytype ,
1920)
20- from feeds . filters .feed_filter import FeedFilter
21- from feeds . filters .gtfs_dataset_filter import GtfsDatasetFilter
22- from feeds . filters .gtfs_feed_filter import GtfsFeedFilter , LocationFilter
23- from feeds . filters .gtfs_rt_feed_filter import GtfsRtFeedFilter , EntityTypeFilter
21+ from shared . feed_filters .feed_filter import FeedFilter
22+ from shared . feed_filters .gtfs_dataset_filter import GtfsDatasetFilter
23+ from shared . feed_filters .gtfs_feed_filter import LocationFilter
24+ from shared . feed_filters .gtfs_rt_feed_filter import GtfsRtFeedFilter , EntityTypeFilter
2425from feeds .impl .datasets_api_impl import DatasetsApiImpl
25- from feeds .impl .error_handling import (
26- raise_http_validation_error ,
26+ from shared .common .error_handling import (
2727 invalid_date_message ,
28- raise_http_error ,
2928 feed_not_found ,
3029 gtfs_feed_not_found ,
3130 gtfs_rt_feed_not_found ,
31+ InternalHTTPException ,
3232)
3333from feeds .impl .models .basic_feed_impl import BasicFeedImpl
3434from feeds .impl .models .entity_type_enum import EntityType
3939from feeds_gen .models .gtfs_dataset import GtfsDataset
4040from feeds_gen .models .gtfs_feed import GtfsFeed
4141from feeds_gen .models .gtfs_rt_feed import GtfsRTFeed
42+ from feeds .impl .error_handling import raise_http_error , raise_http_validation_error , convert_exception
4243from middleware .request_context import is_user_email_restricted
4344from utils .date_utils import valid_iso_date
4445from utils .location_translation import (
@@ -117,7 +118,7 @@ def get_feeds(
117118 )
118119 # Results are sorted by provider
119120 feed_query = feed_query .order_by (Feed .provider , Feed .stable_id )
120- feed_query = feed_query .options (* BasicFeedImpl . get_joinedload_options ())
121+ feed_query = feed_query .options (* get_joinedload_options ())
121122 if limit is not None :
122123 feed_query = feed_query .limit (limit )
123124 if offset is not None :
@@ -158,7 +159,7 @@ def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None
158159 joinedload (Gtfsfeed .gtfsdatasets )
159160 .joinedload (Gtfsdataset .validation_reports )
160161 .joinedload (Validationreport .notices ),
161- * BasicFeedImpl . get_joinedload_options (),
162+ * get_joinedload_options (),
162163 )
163164 ).all ()
164165 if len (results ) > 0 and results [0 ].Gtfsfeed :
@@ -237,46 +238,29 @@ def get_gtfs_feeds(
237238 is_official : bool ,
238239 db_session : Session ,
239240 ) -> List [GtfsFeed ]:
240- """Get some (or all) GTFS feeds from the Mobility Database."""
241- gtfs_feed_filter = GtfsFeedFilter (
242- stable_id = None ,
243- provider__ilike = provider ,
244- producer_url__ilike = producer_url ,
245- location = LocationFilter (
241+ try :
242+ include_wip = not is_user_email_restricted ()
243+ feed_query = get_gtfs_feeds_query (
244+ limit = limit ,
245+ offset = offset ,
246+ provider = provider ,
247+ producer_url = producer_url ,
246248 country_code = country_code ,
247- subdivision_name__ilike = subdivision_name ,
248- municipality__ilike = municipality ,
249- ),
250- )
251-
252- subquery = gtfs_feed_filter .filter (select (Gtfsfeed .id ).join (Location , Gtfsfeed .locations ))
253- subquery = DatasetsApiImpl .apply_bounding_filtering (
254- subquery , dataset_latitudes , dataset_longitudes , bounding_filter_method
255- ).subquery ()
256-
257- is_email_restricted = is_user_email_restricted ()
258- self .logger .info (f"User email is restricted: { is_email_restricted } " )
259- feed_query = (
260- db_session .query (Gtfsfeed )
261- .filter (Gtfsfeed .id .in_ (subquery ))
262- .filter (
263- or_ (
264- Gtfsfeed .operational_status == None , # noqa: E711
265- Gtfsfeed .operational_status != "wip" ,
266- not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
267- )
249+ subdivision_name = subdivision_name ,
250+ municipality = municipality ,
251+ dataset_latitudes = dataset_latitudes ,
252+ dataset_longitudes = dataset_longitudes ,
253+ bounding_filter_method = bounding_filter_method ,
254+ is_official = is_official ,
255+ include_wip = include_wip ,
256+ db_session = db_session ,
268257 )
269- .options (
270- joinedload (Gtfsfeed .gtfsdatasets )
271- .joinedload (Gtfsdataset .validation_reports )
272- .joinedload (Validationreport .notices ),
273- * BasicFeedImpl .get_joinedload_options (),
274- )
275- .order_by (Gtfsfeed .provider , Gtfsfeed .stable_id )
276- )
277- if is_official :
278- feed_query = feed_query .filter (Feed .official )
279- feed_query = feed_query .limit (limit ).offset (offset )
258+ except InternalHTTPException as e :
259+ # get_gtfs_feeds_query cannot throw HTTPException since it's part of fastapi and it's
260+ # not necessarily deployed (e.g. for python functions). Instead it throws an InternalHTTPException
261+ # that needs to be converted to HTTPException before being thrown.
262+ raise convert_exception (e )
263+
280264 return self ._get_response (feed_query , GtfsFeedImpl , db_session )
281265
282266 @with_db_session
@@ -303,7 +287,7 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
303287 .options (
304288 joinedload (Gtfsrealtimefeed .entitytypes ),
305289 joinedload (Gtfsrealtimefeed .gtfs_feeds ),
306- * BasicFeedImpl . get_joinedload_options (),
290+ * get_joinedload_options (),
307291 )
308292 ).all ()
309293
@@ -328,6 +312,26 @@ def get_gtfs_rt_feeds(
328312 db_session : Session ,
329313 ) -> List [GtfsRTFeed ]:
330314 """Get some (or all) GTFS Realtime feeds from the Mobility Database."""
315+ try :
316+ include_wip = not is_user_email_restricted ()
317+ feed_query = get_gtfs_rt_feeds_query (
318+ limit = limit ,
319+ offset = offset ,
320+ provider = provider ,
321+ producer_url = producer_url ,
322+ entity_types = entity_types ,
323+ country_code = country_code ,
324+ subdivision_name = subdivision_name ,
325+ municipality = municipality ,
326+ is_official = is_official ,
327+ include_wip = include_wip ,
328+ db_session = db_session ,
329+ )
330+ except InternalHTTPException as e :
331+ raise convert_exception (e )
332+
333+ return self ._get_response (feed_query , GtfsRTFeedImpl , db_session )
334+
331335 entity_types_list = entity_types .split ("," ) if entity_types else None
332336
333337 # Validate entity types using the EntityType enum
@@ -369,7 +373,7 @@ def get_gtfs_rt_feeds(
369373 .options (
370374 joinedload (Gtfsrealtimefeed .entitytypes ),
371375 joinedload (Gtfsrealtimefeed .gtfs_feeds ),
372- * BasicFeedImpl . get_joinedload_options (),
376+ * get_joinedload_options (),
373377 )
374378 .order_by (Gtfsrealtimefeed .provider , Gtfsrealtimefeed .stable_id )
375379 )
0 commit comments