Skip to content

Commit a77ee30

Browse files
authored
fix: removed redundant location extraction and refactored code (#985)
1 parent fbf3040 commit a77ee30

30 files changed

+246
-2001
lines changed

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
Gtfsrealtimefeed,
1616
Location,
1717
Validationreport,
18-
t_location_with_translations_en,
1918
Entitytype,
2019
)
2120
from shared.feed_filters.feed_filter import FeedFilter
@@ -42,10 +41,6 @@
4241
from feeds.impl.error_handling import raise_http_error, raise_http_validation_error, convert_exception
4342
from middleware.request_context import is_user_email_restricted
4443
from utils.date_utils import valid_iso_date
45-
from utils.location_translation import (
46-
create_location_translation_object,
47-
get_feeds_location_translations,
48-
)
4944
from utils.logger import Logger
5045

5146
T = TypeVar("T", bound="BasicFeed")
@@ -253,7 +248,7 @@ def get_gtfs_feeds(
253248
# that needs to be converted to HTTPException before being thrown.
254249
raise convert_exception(e)
255250

256-
return self._get_response(feed_query, GtfsFeedImpl, db_session)
251+
return self._get_response(feed_query, GtfsFeedImpl)
257252

258253
@with_db_session
259254
def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
@@ -266,25 +261,23 @@ def get_gtfs_rt_feed(self, id: str, db_session: Session) -> GtfsRTFeed:
266261
location=None,
267262
)
268263
results = gtfs_rt_feed_filter.filter(
269-
db_session.query(Gtfsrealtimefeed, t_location_with_translations_en)
264+
db_session.query(Gtfsrealtimefeed)
270265
.filter(
271266
or_(
272267
Gtfsrealtimefeed.operational_status == "published",
273268
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
274269
)
275270
)
276271
.outerjoin(Location, Gtfsrealtimefeed.locations)
277-
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
278272
.options(
279273
joinedload(Gtfsrealtimefeed.entitytypes),
280274
joinedload(Gtfsrealtimefeed.gtfs_feeds),
281275
*get_joinedload_options(),
282276
)
283277
).all()
284278

285-
if len(results) > 0 and results[0].Gtfsrealtimefeed:
286-
translations = {result[1]: create_location_translation_object(result) for result in results}
287-
return GtfsRTFeedImpl.from_orm(results[0].Gtfsrealtimefeed, translations)
279+
if len(results) > 0 and results[0]:
280+
return GtfsRTFeedImpl.from_orm(results[0])
288281
else:
289282
raise_http_error(404, gtfs_rt_feed_not_found.format(id))
290283

@@ -321,7 +314,7 @@ def get_gtfs_rt_feeds(
321314
except InternalHTTPException as e:
322315
raise convert_exception(e)
323316

324-
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)
317+
return self._get_response(feed_query, GtfsRTFeedImpl)
325318

326319
entity_types_list = entity_types.split(",") if entity_types else None
327320

@@ -370,14 +363,13 @@ def get_gtfs_rt_feeds(
370363
if is_official:
371364
feed_query = feed_query.filter(Feed.official)
372365
feed_query = feed_query.limit(limit).offset(offset)
373-
return self._get_response(feed_query, GtfsRTFeedImpl, db_session)
366+
return self._get_response(feed_query, GtfsRTFeedImpl)
374367

375368
@staticmethod
376-
def _get_response(feed_query: Query, impl_cls: type[T], db_session: "Session") -> List[T]:
369+
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
377370
"""Get the response for the feed query."""
378371
results = feed_query.all()
379-
location_translations = get_feeds_location_translations(results, db_session)
380-
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
372+
response = [impl_cls.from_orm(feed) for feed in results]
381373
return list({feed.id: feed for feed in response}.values())
382374

383375
@with_db_session

api/src/feeds/impl/models/basic_feed_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Config:
1717
from_attributes = True
1818

1919
@classmethod
20-
def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
20+
def from_orm(cls, feed: Feed | None) -> BasicFeed | None:
2121
if not feed:
2222
return None
2323
return cls(

api/src/feeds/impl/models/gtfs_feed_impl.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
from typing import Dict
2-
31
from shared.database_gen.sqlacodegen_models import Gtfsfeed as GtfsfeedOrm
42
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
53
from feeds.impl.models.latest_dataset_impl import LatestDatasetImpl
64
from feeds.impl.models.location_impl import LocationImpl
75
from feeds_gen.models.gtfs_feed import GtfsFeed
8-
from utils.location_translation import LocationTranslation, translate_feed_locations
96

107

118
class GtfsFeedImpl(BaseFeedImpl, GtfsFeed):
@@ -20,11 +17,7 @@ class Config:
2017
from_attributes = True
2118

2219
@classmethod
23-
def from_orm(
24-
cls, feed: GtfsfeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
25-
) -> GtfsFeed | None:
26-
if location_translations is not None:
27-
translate_feed_locations(feed, location_translations)
20+
def from_orm(cls, feed: GtfsfeedOrm | None) -> GtfsFeed | None:
2821
gtfs_feed: GtfsFeed = super().from_orm(feed)
2922
if not gtfs_feed:
3023
return None

api/src/feeds/impl/models/gtfs_rt_feed_impl.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from typing import Dict
2-
31
from shared.database_gen.sqlacodegen_models import Gtfsrealtimefeed as GtfsRTFeedOrm
42
from feeds.impl.models.basic_feed_impl import BaseFeedImpl
53
from feeds.impl.models.location_impl import LocationImpl
64
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
7-
from utils.location_translation import LocationTranslation, translate_feed_locations
85

96

107
class GtfsRTFeedImpl(BaseFeedImpl, GtfsRTFeed):
@@ -17,11 +14,7 @@ class Config:
1714
from_attributes = True
1815

1916
@classmethod
20-
def from_orm(
21-
cls, feed: GtfsRTFeedOrm | None, location_translations: Dict[str, LocationTranslation] = None
22-
) -> GtfsRTFeed | None:
23-
if location_translations is not None:
24-
translate_feed_locations(feed, location_translations)
17+
def from_orm(cls, feed: GtfsRTFeedOrm | None) -> GtfsRTFeed | None:
2518
gtfs_rt_feed: GtfsRTFeed = super().from_orm(feed)
2619
if not gtfs_rt_feed:
2720
return None

api/src/feeds/impl/models/location_impl.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from feeds_gen.models.location import Location
2+
import pycountry
23
from shared.database_gen.sqlacodegen_models import Location as LocationOrm
34

45

@@ -14,9 +15,15 @@ def from_orm(cls, location: LocationOrm | None) -> Location | None:
1415
"""Create a model instance from a SQLAlchemy a Location row object."""
1516
if not location:
1617
return None
18+
country_name = location.country
19+
if not country_name:
20+
try:
21+
country_name = pycountry.countries.get(alpha_2=location.country_code).name
22+
except AttributeError:
23+
pass
1724
return cls(
1825
country_code=location.country_code,
19-
country=location.country,
26+
country=country_name,
2027
subdivision_name=location.subdivision_name,
2128
municipality=location.municipality,
2229
)

api/src/feeds/impl/models/search_feed_item_result_impl.py

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def from_orm_gtfs(cls, feed_search_row):
3838
license_url=feed_search_row.license_url,
3939
),
4040
redirects=feed_search_row.redirect_ids,
41-
locations=feed_search_row.locations,
41+
locations=cls.resolve_locations(feed_search_row.locations),
4242
latest_dataset=LatestDataset(
4343
id=feed_search_row.latest_dataset_id,
4444
hosted_url=feed_search_row.latest_dataset_hosted_url,
@@ -74,55 +74,39 @@ def from_orm_gtfs_rt(cls, feed_search_row):
7474
license_url=feed_search_row.license_url,
7575
),
7676
redirects=feed_search_row.redirect_ids,
77-
locations=feed_search_row.locations,
77+
locations=cls.resolve_locations(feed_search_row.locations),
7878
entity_types=feed_search_row.entities,
7979
feed_references=feed_search_row.feed_reference_ids,
8080
)
8181

8282
@classmethod
83-
def _translate_locations(cls, feed_search_row):
84-
"""Translate location information in the feed search row.
85-
This method modifies the locations in the feed search row in place."""
86-
if feed_search_row.locations is None:
87-
return
88-
country_translations = cls._create_translation_dict(feed_search_row.country_translations)
89-
subdivision_translations = cls._create_translation_dict(feed_search_row.subdivision_name_translations)
90-
municipality_translations = cls._create_translation_dict(feed_search_row.municipality_translations)
91-
92-
for location in feed_search_row.locations:
93-
location["country"] = country_translations.get(location["country"], location["country"])
94-
if location["country"] is None or len(location["country"]) == 0:
95-
location["country"] = SearchFeedItemResultImpl.resolve_country_by_code(location)
96-
location["subdivision_name"] = subdivision_translations.get(
97-
location["subdivision_name"], location["subdivision_name"]
98-
)
99-
location["municipality"] = municipality_translations.get(location["municipality"], location["municipality"])
83+
def resolve_locations(cls, locations):
84+
"""Resolve locations by country code."""
85+
return [
86+
{
87+
**location,
88+
"country": location.get("country")
89+
if location.get("country")
90+
else cls.resolve_country_by_code(location),
91+
}
92+
for location in locations
93+
]
10094

10195
@classmethod
10296
def resolve_country_by_code(cls, location):
10397
"""Resolve country name by country code.
10498
If the country code is not found, return the original country name."""
105-
country = pycountry.countries.get(alpha_2=location["country_code"])
106-
return country.name if country else location["country"]
107-
108-
@staticmethod
109-
def _create_translation_dict(translations):
110-
"""Helper method to create a translation dictionary."""
111-
if translations:
112-
return {
113-
elem.get("key"): elem.get("value") for elem in translations if elem.get("key") and elem.get("value")
114-
}
115-
return {}
99+
try:
100+
country = pycountry.countries.get(alpha_2=location.get("country_code"))
101+
return country.name if country else location.get("country")
102+
except AttributeError:
103+
return location.get("country")
116104

117105
@classmethod
118106
def from_orm(cls, feed_search_row):
119107
"""Create a model instance from a SQLAlchemy row object."""
120108
if feed_search_row is None:
121109
return None
122-
123-
# Translate location data
124-
cls._translate_locations(feed_search_row)
125-
126110
match feed_search_row.data_type:
127111
case "gtfs":
128112
return cls.from_orm_gtfs(feed_search_row)

api/src/utils/location_translation.py

Lines changed: 0 additions & 128 deletions
This file was deleted.

api/tests/unittest/models/test_gtfs_feed_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class TestGtfsFeedImpl(unittest.TestCase):
182182

183183
def test_from_orm_all_fields(self):
184184
"""Test the `from_orm` method with all fields."""
185-
result = GtfsFeedImpl.from_orm(gtfs_feed_orm, {})
185+
result = GtfsFeedImpl.from_orm(gtfs_feed_orm)
186186
assert result == expected_gtfs_feed_result
187187

188188
def test_from_orm_empty_fields(self):

functions-python/extract_location/.coveragerc

Lines changed: 0 additions & 10 deletions
This file was deleted.

functions-python/extract_location/.env.rename_me

Lines changed: 0 additions & 2 deletions
This file was deleted.

0 commit comments

Comments
 (0)