11from datetime import datetime
22from typing import List , Union , TypeVar
33
4+ from sqlalchemy import or_
45from sqlalchemy import select
56from sqlalchemy .orm import joinedload , Session
67from sqlalchemy .orm .query import Query
3940from feeds_gen .models .gtfs_feed import GtfsFeed
4041from feeds_gen .models .gtfs_rt_feed import GtfsRTFeed
4142from middleware .request_context import is_user_email_restricted
42- from sqlalchemy import or_
4343from utils .date_utils import valid_iso_date
4444from utils .location_translation import (
4545 create_location_translation_object ,
4646 LocationTranslation ,
4747 get_feeds_location_translations ,
4848)
49+ from utils .logger import Logger
4950
5051T = TypeVar ("T" , bound = "BasicFeed" )
5152
@@ -59,9 +60,15 @@ class FeedsApiImpl(BaseFeedsApi):
5960
6061 APIFeedType = Union [BasicFeed , GtfsFeed , GtfsRTFeed ]
6162
63+ def __init__ (self ) -> None :
64+ self .logger = Logger ("FeedsApiImpl" ).get_logger ()
65+
6266 @with_db_session
6367 def get_feed (self , id : str , db_session : Session ) -> BasicFeed :
6468 """Get the specified feed from the Mobility Database."""
69+ is_email_restricted = is_user_email_restricted ()
70+ self .logger .info (f"User email is restricted: { is_email_restricted } " )
71+
6572 feed = (
6673 FeedFilter (stable_id = id , provider__ilike = None , producer_url__ilike = None , status = None )
6774 .filter (Database ().get_query_model (db_session , Feed ))
@@ -70,8 +77,7 @@ def get_feed(self, id: str, db_session: Session) -> BasicFeed:
7077 or_ (
7178 Feed .operational_status == None , # noqa: E711
7279 Feed .operational_status != "wip" ,
73- # Allow all feeds to be returned if the user is not restricted
74- not is_user_email_restricted (),
80+ not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
7581 )
7682 )
7783 .first ()
@@ -83,19 +89,30 @@ def get_feed(self, id: str, db_session: Session) -> BasicFeed:
8389
8490 @with_db_session
8591 def get_feeds (
86- self , limit : int , offset : int , status : str , provider : str , producer_url : str , db_session : Session
92+ self ,
93+ limit : int ,
94+ offset : int ,
95+ status : str ,
96+ provider : str ,
97+ producer_url : str ,
98+ is_official : bool ,
99+ db_session : Session ,
87100 ) -> List [BasicFeed ]:
88101 """Get some (or all) feeds from the Mobility Database."""
102+ is_email_restricted = is_user_email_restricted ()
103+ self .logger .info (f"User email is restricted: { is_email_restricted } " )
89104 feed_filter = FeedFilter (
90105 status = status , provider__ilike = provider , producer_url__ilike = producer_url , stable_id = None
91106 )
92107 feed_query = feed_filter .filter (Database ().get_query_model (db_session , Feed ))
108+ if is_official :
109+ feed_query = feed_query .filter (Feed .official )
93110 feed_query = feed_query .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
94111 feed_query = feed_query .filter (
95112 or_ (
96113 Feed .operational_status == None , # noqa: E711
97114 Feed .operational_status != "wip" ,
98- not is_user_email_restricted () , # Allow all feeds to be returned if the user is not restricted
115+ not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
99116 )
100117 )
101118 # Results are sorted by provider
@@ -217,6 +234,7 @@ def get_gtfs_feeds(
217234 dataset_latitudes : str ,
218235 dataset_longitudes : str ,
219236 bounding_filter_method : str ,
237+ is_official : bool ,
220238 db_session : Session ,
221239 ) -> List [GtfsFeed ]:
222240 """Get some (or all) GTFS feeds from the Mobility Database."""
@@ -236,14 +254,16 @@ def get_gtfs_feeds(
236254 subquery , dataset_latitudes , dataset_longitudes , bounding_filter_method
237255 ).subquery ()
238256
257+ is_email_restricted = is_user_email_restricted ()
258+ self .logger .info (f"User email is restricted: { is_email_restricted } " )
239259 feed_query = (
240260 db_session .query (Gtfsfeed )
241261 .filter (Gtfsfeed .id .in_ (subquery ))
242262 .filter (
243263 or_ (
244264 Gtfsfeed .operational_status == None , # noqa: E711
245265 Gtfsfeed .operational_status != "wip" ,
246- not is_user_email_restricted () , # Allow all feeds to be returned if the user is not restricted
266+ not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
247267 )
248268 )
249269 .options (
@@ -253,9 +273,10 @@ def get_gtfs_feeds(
253273 * BasicFeedImpl .get_joinedload_options (),
254274 )
255275 .order_by (Gtfsfeed .provider , Gtfsfeed .stable_id )
256- .limit (limit )
257- .offset (offset )
258276 )
277+ if is_official :
278+ feed_query = feed_query .filter (Feed .official )
279+ feed_query = feed_query .limit (limit ).offset (offset )
259280 return self ._get_response (feed_query , GtfsFeedImpl , db_session )
260281
261282 @with_db_session
@@ -303,6 +324,7 @@ def get_gtfs_rt_feeds(
303324 country_code : str ,
304325 subdivision_name : str ,
305326 municipality : str ,
327+ is_official : bool ,
306328 db_session : Session ,
307329 ) -> List [GtfsRTFeed ]:
308330 """Get some (or all) GTFS Realtime feeds from the Mobility Database."""
@@ -350,9 +372,10 @@ def get_gtfs_rt_feeds(
350372 * BasicFeedImpl .get_joinedload_options (),
351373 )
352374 .order_by (Gtfsrealtimefeed .provider , Gtfsrealtimefeed .stable_id )
353- .limit (limit )
354- .offset (offset )
355375 )
376+ if is_official :
377+ feed_query = feed_query .filter (Feed .official )
378+ feed_query = feed_query .limit (limit ).offset (offset )
356379 return self ._get_response (feed_query , GtfsRTFeedImpl , db_session )
357380
358381 @staticmethod
0 commit comments