66from sqlalchemy .orm import joinedload , Session
77from sqlalchemy .orm .query import Query
88
9+ from feeds .impl .datasets_api_impl import DatasetsApiImpl
10+ from feeds .impl .error_handling import raise_http_error , raise_http_validation_error , convert_exception
11+ from feeds .impl .models .entity_type_enum import EntityType
12+ from feeds .impl .models .feed_impl import FeedImpl
13+ from feeds .impl .models .gbfs_feed_impl import GbfsFeedImpl
14+ from feeds .impl .models .gtfs_feed_impl import GtfsFeedImpl
15+ from feeds .impl .models .gtfs_rt_feed_impl import GtfsRTFeedImpl
16+ from feeds_gen .apis .feeds_api_base import BaseFeedsApi
17+ from feeds_gen .models .feed import Feed
18+ from feeds_gen .models .gbfs_feed import GbfsFeed
19+ from feeds_gen .models .gtfs_dataset import GtfsDataset
20+ from feeds_gen .models .gtfs_feed import GtfsFeed
21+ from feeds_gen .models .gtfs_rt_feed import GtfsRTFeed
22+ from middleware .request_context import is_user_email_restricted
923from shared .common .db_utils import (
1024 get_gtfs_feeds_query ,
1125 get_gtfs_rt_feeds_query ,
1226 get_joinedload_options ,
1327 add_official_filter ,
28+ get_gbfs_feeds_query ,
1429)
30+ from shared .common .error_handling import (
31+ invalid_date_message ,
32+ feed_not_found ,
33+ gtfs_feed_not_found ,
34+ gtfs_rt_feed_not_found ,
35+ InternalHTTPException ,
36+ gbfs_feed_not_found ,
37+ )
38+ from shared .common .logging_utils import Logger
1539from shared .database .database import Database , with_db_session
1640from shared .database_gen .sqlacodegen_models import (
17- Feed ,
41+ Feed as FeedOrm ,
1842 Gtfsdataset ,
1943 Gtfsfeed ,
2044 Gtfsrealtimefeed ,
2549from shared .feed_filters .gtfs_dataset_filter import GtfsDatasetFilter
2650from shared .feed_filters .gtfs_feed_filter import LocationFilter
2751from shared .feed_filters .gtfs_rt_feed_filter import GtfsRtFeedFilter , EntityTypeFilter
28- from feeds .impl .datasets_api_impl import DatasetsApiImpl
29- from shared .common .error_handling import (
30- invalid_date_message ,
31- feed_not_found ,
32- gtfs_feed_not_found ,
33- gtfs_rt_feed_not_found ,
34- InternalHTTPException ,
35- )
36- from feeds .impl .models .basic_feed_impl import BasicFeedImpl
37- from feeds .impl .models .entity_type_enum import EntityType
38- from feeds .impl .models .gtfs_feed_impl import GtfsFeedImpl
39- from feeds .impl .models .gtfs_rt_feed_impl import GtfsRTFeedImpl
40- from feeds_gen .apis .feeds_api_base import BaseFeedsApi
41- from feeds_gen .models .basic_feed import BasicFeed
42- from feeds_gen .models .gtfs_dataset import GtfsDataset
43- from feeds_gen .models .gtfs_feed import GtfsFeed
44- from feeds_gen .models .gtfs_rt_feed import GtfsRTFeed
45- from feeds .impl .error_handling import raise_http_error , raise_http_validation_error , convert_exception
46- from middleware .request_context import is_user_email_restricted
4752from utils .date_utils import valid_iso_date
48- from shared .common .logging_utils import Logger
4953
50- T = TypeVar ("T" , bound = "BasicFeed " )
54+ T = TypeVar ("T" , bound = "Feed " )
5155
5256
5357class FeedsApiImpl (BaseFeedsApi ):
@@ -57,31 +61,30 @@ class FeedsApiImpl(BaseFeedsApi):
5761 If a method is left blank the associated endpoint will return a 500 HTTP response.
5862 """
5963
60- APIFeedType = Union [BasicFeed , GtfsFeed , GtfsRTFeed ]
64+ APIFeedType = Union [FeedOrm , GtfsFeed , GtfsRTFeed ]
6165
6266 def __init__ (self ) -> None :
6367 self .logger = Logger ("FeedsApiImpl" ).get_logger ()
6468
6569 @with_db_session
66- def get_feed (self , id : str , db_session : Session ) -> BasicFeed :
70+ def get_feed (self , id : str , db_session : Session ) -> Feed :
6771 """Get the specified feed from the Mobility Database."""
6872 is_email_restricted = is_user_email_restricted ()
6973 self .logger .debug (f"User email is restricted: { is_email_restricted } " )
7074
7175 feed = (
7276 FeedFilter (stable_id = id , provider__ilike = None , producer_url__ilike = None , status = None )
73- .filter (Database ().get_query_model (db_session , Feed ))
74- .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
77+ .filter (Database ().get_query_model (db_session , FeedOrm ))
7578 .filter (
7679 or_ (
77- Feed .operational_status == "published" ,
80+ FeedOrm .operational_status == "published" ,
7881 not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
7982 )
8083 )
8184 .first ()
8285 )
8386 if feed :
84- return BasicFeedImpl .from_orm (feed )
87+ return FeedImpl .from_orm (feed )
8588 else :
8689 raise_http_error (404 , feed_not_found .format (id ))
8790
@@ -95,32 +98,31 @@ def get_feeds(
9598 producer_url : str ,
9699 is_official : bool ,
97100 db_session : Session ,
98- ) -> List [BasicFeed ]:
101+ ) -> List [Feed ]:
99102 """Get some (or all) feeds from the Mobility Database."""
100103 is_email_restricted = is_user_email_restricted ()
101104 self .logger .debug (f"User email is restricted: { is_email_restricted } " )
102105 feed_filter = FeedFilter (
103106 status = status , provider__ilike = provider , producer_url__ilike = producer_url , stable_id = None
104107 )
105- feed_query = feed_filter .filter (Database ().get_query_model (db_session , Feed ))
108+ feed_query = feed_filter .filter (Database ().get_query_model (db_session , FeedOrm ))
106109 feed_query = add_official_filter (feed_query , is_official )
107- feed_query = feed_query .filter (Feed .data_type != "gbfs" ) # Filter out GBFS feeds
108110 feed_query = feed_query .filter (
109111 or_ (
110- Feed .operational_status == "published" ,
112+ FeedOrm .operational_status == "published" ,
111113 not is_email_restricted , # Allow all feeds to be returned if the user is not restricted
112114 )
113115 )
114116 # Results are sorted by provider
115- feed_query = feed_query .order_by (Feed .provider , Feed .stable_id )
117+ feed_query = feed_query .order_by (FeedOrm .provider , FeedOrm .stable_id )
116118 feed_query = feed_query .options (* get_joinedload_options ())
117119 if limit is not None :
118120 feed_query = feed_query .limit (limit )
119121 if offset is not None :
120122 feed_query = feed_query .offset (offset )
121123
122124 results = feed_query .all ()
123- return [BasicFeedImpl .from_orm (feed ) for feed in results ]
125+ return [FeedImpl .from_orm (feed ) for feed in results ]
124126
125127 @with_db_session
126128 def get_gtfs_feed (self , id : str , db_session : Session ) -> GtfsFeed :
@@ -163,7 +165,7 @@ def get_gtfs_feed_datasets(
163165 feed = self ._get_gtfs_feed (gtfs_feed_id , db_session , include_options_for_joinedload = False )
164166
165167 if not feed :
166- raise_http_error (404 , f"Feed with id { gtfs_feed_id } not found" )
168+ raise_http_error (404 , f"FeedOrm with id { gtfs_feed_id } not found" )
167169
168170 # Replace Z with +00:00 to make the datetime object timezone aware
169171 # Due to https://github.com/python/cpython/issues/80010, once migrate to Python 3.11, we can use fromisoformat
@@ -174,7 +176,7 @@ def get_gtfs_feed_datasets(
174176 downloaded_at__gte = (
175177 datetime .fromisoformat (downloaded_after .replace ("Z" , "+00:00" )) if downloaded_after else None
176178 ),
177- ).filter (DatasetsApiImpl .create_dataset_query ().filter (Feed .stable_id == gtfs_feed_id ))
179+ ).filter (DatasetsApiImpl .create_dataset_query ().filter (FeedOrm .stable_id == gtfs_feed_id ))
178180
179181 if latest :
180182 query = query .filter (Gtfsdataset .latest )
@@ -352,3 +354,47 @@ def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[Gtfs
352354 return [GtfsRTFeedImpl .from_orm (gtfs_rt_feed ) for gtfs_rt_feed in feed .gtfs_rt_feeds ]
353355 else :
354356 raise_http_error (404 , gtfs_feed_not_found .format (id ))
357+
358+ @with_db_session
359+ def get_gbfs_feed (
360+ self ,
361+ id : str ,
362+ db_session : Session ,
363+ ) -> GbfsFeed :
364+ """Get the specified GBFS feed from the Mobility Database."""
365+ result = get_gbfs_feeds_query (db_session , stable_id = id ).one_or_none ()
366+ if result :
367+ return GbfsFeedImpl .from_orm (result )
368+ else :
369+ raise_http_error (404 , gbfs_feed_not_found .format (id ))
370+
371+ @with_db_session
372+ def get_gbfs_feeds (
373+ self ,
374+ limit : int ,
375+ offset : int ,
376+ provider : str ,
377+ producer_url : str ,
378+ country_code : str ,
379+ subdivision_name : str ,
380+ municipality : str ,
381+ system_id : str ,
382+ version : str ,
383+ db_session : Session ,
384+ ) -> List [GbfsFeed ]:
385+ query = get_gbfs_feeds_query (
386+ db_session = db_session ,
387+ provider = provider ,
388+ producer_url = producer_url ,
389+ country_code = country_code ,
390+ subdivision_name = subdivision_name ,
391+ municipality = municipality ,
392+ system_id = system_id ,
393+ version = version ,
394+ )
395+ if limit :
396+ query = query .limit (limit )
397+ if offset :
398+ query = query .offset (offset )
399+ results = query .all ()
400+ return [GbfsFeedImpl .from_orm (feed ) for feed in results ]
0 commit comments