|
1 | 1 | from datetime import datetime |
2 | | -from typing import List, Union, TypeVar |
| 2 | +from typing import List, Union, TypeVar, Optional |
3 | 3 |
|
4 | 4 | from sqlalchemy import or_ |
5 | 5 | from sqlalchemy import select |
|
44 | 44 | from utils.date_utils import valid_iso_date |
45 | 45 | from utils.location_translation import ( |
46 | 46 | create_location_translation_object, |
47 | | - LocationTranslation, |
48 | 47 | get_feeds_location_translations, |
49 | 48 | ) |
50 | 49 | from utils.logger import Logger |
@@ -130,42 +129,39 @@ def get_feeds( |
130 | 129 | @with_db_session |
131 | 130 | def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed: |
132 | 131 | """Get the specified gtfs feed from the Mobility Database.""" |
133 | | - feed, translations = self._get_gtfs_feed(id, db_session) |
| 132 | + feed = self._get_gtfs_feed(id, db_session) |
134 | 133 | if feed: |
135 | | - return GtfsFeedImpl.from_orm(feed, translations) |
| 134 | + return GtfsFeedImpl.from_orm(feed) |
136 | 135 | else: |
137 | 136 | raise_http_error(404, gtfs_feed_not_found.format(id)) |
138 | 137 |
|
139 | 138 | @staticmethod |
140 | | - def _get_gtfs_feed(stable_id: str, db_session: Session) -> tuple[Gtfsfeed | None, dict[str, LocationTranslation]]: |
| 139 | + def _get_gtfs_feed(stable_id: str, db_session: Session) -> Optional[Gtfsfeed]: |
141 | 140 | results = ( |
142 | 141 | FeedFilter( |
143 | 142 | stable_id=stable_id, |
144 | 143 | status=None, |
145 | 144 | provider__ilike=None, |
146 | 145 | producer_url__ilike=None, |
147 | 146 | ) |
148 | | - .filter(db_session.query(Gtfsfeed, t_location_with_translations_en)) |
| 147 | + .filter(db_session.query(Gtfsfeed)) |
149 | 148 | .filter( |
150 | 149 | or_( |
151 | 150 | Gtfsfeed.operational_status == None, # noqa: E711 |
152 | 151 | Gtfsfeed.operational_status != "wip", |
153 | 152 | not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted |
154 | 153 | ) |
155 | 154 | ) |
156 | | - .outerjoin(Location, Feed.locations) |
157 | | - .outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id) |
158 | 155 | .options( |
159 | 156 | joinedload(Gtfsfeed.gtfsdatasets) |
160 | 157 | .joinedload(Gtfsdataset.validation_reports) |
161 | 158 | .joinedload(Validationreport.notices), |
162 | 159 | *get_joinedload_options(), |
163 | 160 | ) |
164 | 161 | ).all() |
165 | | - if len(results) > 0 and results[0].Gtfsfeed: |
166 | | - translations = {result[1]: create_location_translation_object(result) for result in results} |
167 | | - return results[0].Gtfsfeed, translations |
168 | | - return None, {} |
| 162 | + if len(results) == 0: |
| 163 | + return None |
| 164 | + return results[0] |
169 | 165 |
|
170 | 166 | @with_db_session |
171 | 167 | def get_gtfs_feed_datasets( |
@@ -393,8 +389,8 @@ def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") - |
393 | 389 | @with_db_session |
394 | 390 | def get_gtfs_feed_gtfs_rt_feeds(self, id: str, db_session: Session) -> List[GtfsRTFeed]: |
395 | 391 | """Get a list of GTFS Realtime related to a GTFS feed.""" |
396 | | - feed, translations = self._get_gtfs_feed(id, db_session) |
| 392 | + feed = self._get_gtfs_feed(id, db_session) |
397 | 393 | if feed: |
398 | | - return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed, translations) for gtfs_rt_feed in feed.gtfs_rt_feeds] |
| 394 | + return [GtfsRTFeedImpl.from_orm(gtfs_rt_feed) for gtfs_rt_feed in feed.gtfs_rt_feeds] |
399 | 395 | else: |
400 | 396 | raise_http_error(404, gtfs_feed_not_found.format(id)) |
0 commit comments