|
1 | 1 | from datetime import datetime |
2 | | -from typing import List, Union, TypeVar |
| 2 | +from typing import List, Union, TypeVar, Optional |
3 | 3 |
|
4 | 4 | from sqlalchemy import or_ |
5 | 5 | from sqlalchemy import select |
|
43 | 43 | from utils.date_utils import valid_iso_date |
44 | 44 | from utils.location_translation import ( |
45 | 45 | create_location_translation_object, |
46 | | - LocationTranslation, |
47 | 46 | get_feeds_location_translations, |
48 | 47 | ) |
49 | 48 | from utils.logger import Logger |
@@ -129,42 +128,39 @@ def get_feeds( |
129 | 128 | @with_db_session |
130 | 129 | def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed: |
131 | 130 | """Get the specified gtfs feed from the Mobility Database.""" |
132 | | - feed, translations = self._get_gtfs_feed(id, db_session) |
| 131 | + feed = self._get_gtfs_feed(id, db_session) |
133 | 132 | if feed: |
134 | | - return GtfsFeedImpl.from_orm(feed, translations) |
| 133 | + return GtfsFeedImpl.from_orm(feed) |
135 | 134 | else: |
136 | 135 | raise_http_error(404, gtfs_feed_not_found.format(id)) |
137 | 136 |
|
138 | 137 | @staticmethod |
139 | | - def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]: |
| 138 | + def _get_gtfs_feed(stable_id: str, db_session: Session) -> Optional[Gtfsfeed]: |
140 | 139 | results = ( |
141 | 140 | FeedFilter( |
142 | 141 | stable_id=stable_id, |
143 | 142 | status=None, |
144 | 143 | provider__ilike=None, |
145 | 144 | producer_url__ilike=None, |
146 | 145 | ) |
147 | | - .filter(db_session.query(Gtfsfeed, t_location_with_translations_en)) |
| 146 | + .filter(db_session.query(Gtfsfeed)) |
148 | 147 | .filter( |
149 | 148 | or_( |
150 | 149 | Gtfsfeed.operational_status == None, # noqa: E711 |
151 | 150 | Gtfsfeed.operational_status != "wip", |
152 | 151 | not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted |
153 | 152 | ) |
154 | 153 | ) |
155 | | - .outerjoin(Location, Feed.locations) |
156 | | - .outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id) |
157 | 154 | .options( |
158 | 155 | joinedload(Gtfsfeed.gtfsdatasets) |
159 | 156 | .joinedload(Gtfsdataset.validation_reports) |
160 | 157 | .joinedload(Validationreport.notices), |
161 | 158 | *BasicFeedImpl.get_joinedload_options(), |
162 | 159 | ) |
163 | 160 | ).all() |
164 | | - if len(results) > 0 and results[0].Gtfsfeed: |
165 | | - translations = {result[1]: create_location_translation_object(result) for result in results} |
166 | | - return results[0].Gtfsfeed, translations |
167 | | - return None, {} |
| 161 | + if len(results) == 0: |
| 162 | + return None |
| 163 | + return results[0] |
168 | 164 |
|
169 | 165 | @with_db_session |
170 | 166 | def get_gtfs_feed_datasets( |
@@ -389,8 +385,8 @@ def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") - |
389 | 385 | @with_db_session |
390 | 386 | def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]: |
391 | 387 | """Get a list of GTFS Realtime related to a GTFS feed.""" |
392 | | - feed, translations = self._get_gtfs_feed(id, db_session) |
| 388 | + feed = self._get_gtfs_feed(id, db_session) |
393 | 389 | if feed: |
394 | | - return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds] |
| 390 | + return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds] |
395 | 391 | else: |
396 | 392 | raise_http_error(404, gtfs_feed_not_found.format(id)) |
0 commit comments