Skip to content

Commit 22af3e6

Browse files
authored
feat: add create_gtfs_feed endpoint to operations API (#1427)
1 parent d6ca0a3 commit 22af3e6

File tree

28 files changed

+1192
-164
lines changed

28 files changed

+1192
-164
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,7 @@ functions-python/**/*.csv
8383
.cloudstorage
8484

8585
# Project files
86-
*.code-workspace
86+
*.code-workspace
87+
88+
# Ignore OpenAPI local backup files
89+
*.yaml.bak
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import logging
2+
from typing import Optional
3+
4+
from sqlalchemy.orm import Session
5+
6+
from shared.database_gen.sqlacodegen_models import Location
7+
8+
9+
def get_country_code(country_name: str) -> Optional[str]:
    """
    Get ISO 3166 country code from country name

    Surrounding whitespace is stripped, then the name is resolved with an exact
    pycountry name lookup first and a fuzzy search second.

    Args:
        country_name (str): Full country name

    Returns:
        Optional[str]: Two-letter ISO country code or None if not found
    """
    # Normalize once so padded input ("  Canada ") can still hit the exact-match path
    normalized_name = country_name.strip() if country_name else ""
    if not normalized_name:
        logging.error("Could not find country code for: empty string")
        return None

    # Imported lazily, and only once there is a name worth resolving
    import pycountry

    try:
        # Try exact match first
        country = pycountry.countries.get(name=normalized_name)
        if country:
            return country.alpha_2

        # Fall back to fuzzy matching
        matches = pycountry.countries.search_fuzzy(normalized_name)
    except LookupError:
        # The lookup reports "no match" by raising; treat it as an empty result
        matches = []

    if matches:
        return matches[0].alpha_2

    logging.error("Could not find country code for: %s", normalized_name)
    return None
40+
41+
42+
def create_or_get_location(
    session: Session,
    country: Optional[str],
    state_province: Optional[str],
    city_name: Optional[str],
    country_code: Optional[str] = None,
) -> Optional[Location]:
    """
    Create a new location or get existing one

    The location id is the "-"-joined sequence of country code, state/province
    and city name, skipping whichever parts are missing. When a country code is
    given and pycountry knows it, the country name is replaced by pycountry's
    name for that code; otherwise the code is derived from the country name.

    Args:
        session: Database session
        country: Country name
        state_province: State/province name
        city_name: City name
        country_code: ISO 3166-1 alpha-2 country code

    Returns:
        Optional[Location]: Location object or None if creation failed
    """
    if not any([country, state_province, city_name]):
        return None

    import pycountry

    # Strip the code once so the location id, the stored column and the
    # pycountry lookup all use the same value (" CA" and "CA" must not
    # produce two different locations).
    if country_code:
        country_code = country_code.strip()

    # Generate location_id using the specified pattern
    location_components = []
    if country_code:
        location_components.append(country_code)
        try:
            py_country = pycountry.countries.get(alpha_2=country_code)
            country = py_country.name if py_country else country
        except LookupError:
            logging.warning("Could not find country code for: %s", country_code)
    elif country:
        country_code = get_country_code(country)
        if country_code:
            location_components.append(country_code)
        else:
            logging.error("Could not determine country code for %s", country)
            return None

    if state_province:
        location_components.append(state_province)
    if city_name:
        location_components.append(city_name)

    location_id = "-".join(location_components)

    # First check if location already exists
    existing_location = session.query(Location).filter(Location.id == location_id).first()
    if existing_location:
        logging.debug("Using existing location: %s", location_id)
        return existing_location

    # Create new location; flush so the row is sent to the database within the
    # caller's transaction (committing is left to the caller).
    location = Location(
        id=location_id,
        country_code=country_code,
        country=country,
        subdivision_name=state_province,
        municipality=city_name,
    )
    session.add(location)
    session.flush()
    logging.debug("Created new location: %s", location_id)

    return location

api/src/shared/db_models/external_id_impl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,12 @@ def from_orm(cls, external_id: Externalid | None) -> ExternalId | None:
2121
external_id=external_id.associated_id,
2222
source=external_id.source,
2323
)
24+
25+
@classmethod
def to_orm_from_dict(cls, external_id_dict: dict | None) -> Externalid | None:
    """Build an ``Externalid`` ORM row from its dictionary form.

    Returns None when the dictionary is None or empty. Otherwise the
    ``external_id`` and ``source`` keys are copied onto the row; a missing
    key becomes None.
    """
    if external_id_dict:
        read = external_id_dict.get
        return Externalid(associated_id=read("external_id"), source=read("source"))
    return None

api/src/shared/db_models/feed_impl.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
from shared.db_models.basic_feed_impl import BaseFeedImpl
22
from feeds_gen.models.feed import Feed
33
from shared.database_gen.sqlacodegen_models import Feed as FeedOrm
4+
from shared.db_models.external_id_impl import ExternalIdImpl
45
from shared.db_models.feed_related_link_impl import FeedRelatedLinkImpl
56

7+
from shared.db_models.location_impl import LocationImpl
8+
from shared.db_models.redirect_impl import RedirectImpl
9+
610

711
class FeedImpl(BaseFeedImpl, Feed):
812
"""Base implementation of the feeds models.
@@ -27,3 +31,48 @@ def from_orm(cls, feed_orm: FeedOrm | None) -> Feed | None:
2731
feed.related_links = [FeedRelatedLinkImpl.from_orm(related_link) for related_link in feed_orm.feedrelatedlinks]
2832
feed.note = feed_orm.note
2933
return feed
34+
35+
@classmethod
def to_orm_from_dict(cls, feed_dict: dict | None) -> FeedOrm | None:
    """Convert a dictionary representation of a feed to a SQLAlchemy Feed ORM object.

    Scalar columns are copied with ``dict.get`` (missing keys become None);
    ``authentication_type`` is stringified unless it is None. Nested
    collections are converted item by item, and any item the nested converter
    rejects (it returns None, e.g. an empty dict or a redirect whose target
    feed does not exist) is dropped. External ids are sorted by
    ``associated_id`` and redirects by ``target_id``.

    Returns:
        The populated ORM object, or None when ``feed_dict`` is None or empty.
    """
    if not feed_dict:
        return None

    def convert_items(key: str, converter) -> list:
        """Run ``converter`` over ``feed_dict[key]`` and drop items it could not convert."""
        converted = (converter(item) for item in feed_dict.get(key) or [])
        return [row for row in converted if row is not None]

    authentication_type = feed_dict.get("authentication_type")
    result: FeedOrm = FeedOrm(
        id=feed_dict.get("id"),
        stable_id=feed_dict.get("stable_id"),
        data_type=feed_dict.get("data_type"),
        created_at=feed_dict.get("created_at"),
        provider=feed_dict.get("provider"),
        feed_contact_email=feed_dict.get("feed_contact_email"),
        producer_url=feed_dict.get("producer_url"),
        authentication_type=None if authentication_type is None else str(authentication_type),
        authentication_info_url=feed_dict.get("authentication_info_url"),
        api_key_parameter_name=feed_dict.get("api_key_parameter_name"),
        license_url=feed_dict.get("license_url"),
        status=feed_dict.get("status"),
        official=feed_dict.get("official"),
        official_updated_at=feed_dict.get("official_updated_at"),
        feed_name=feed_dict.get("feed_name"),
        note=feed_dict.get("note"),
        externalids=sorted(
            convert_items("externalids", ExternalIdImpl.to_orm_from_dict),
            # A payload item without "external_id" yields associated_id=None,
            # which cannot be ordered against strings; sort it first instead.
            key=lambda external_id: external_id.associated_id or "",
        ),
        redirectingids=sorted(
            convert_items("redirectingids", RedirectImpl.to_orm_from_dict),
            key=lambda redirect: redirect.target_id,
        ),
        feedrelatedlinks=convert_items("feedrelatedlinks", FeedRelatedLinkImpl.to_orm_from_dict),
        locations=convert_items("locations", LocationImpl.to_orm_from_dict),
    )
    return result

api/src/shared/db_models/feed_related_link_impl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,15 @@ def from_orm(cls, feed_related_link_orm: Feedrelatedlink) -> FeedRelatedLink | N
2121
description=feed_related_link_orm.description,
2222
created_at=feed_related_link_orm.created_at,
2323
)
24+
25+
@classmethod
def to_orm_from_dict(cls, feedrelaticlink_dict: dict) -> Feedrelatedlink | None:
    """Build a ``Feedrelatedlink`` row from a dict.

    A falsy input (None or empty dict) yields None. Otherwise ``code``,
    ``url`` and ``description`` are copied, with missing keys becoming None.
    """
    if feedrelaticlink_dict:
        columns = {name: feedrelaticlink_dict.get(name) for name in ("code", "url", "description")}
        return Feedrelatedlink(**columns)
    return None

api/src/shared/db_models/location_impl.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
from sqlalchemy.orm import Session
2+
13
from feeds_gen.models.location import Location
24
import pycountry
5+
6+
from shared.common.locations_utils import create_or_get_location
7+
from shared.database.database import with_db_session
38
from shared.database_gen.sqlacodegen_models import Location as LocationOrm
49

510

@@ -27,3 +32,15 @@ def from_orm(cls, location: LocationOrm | None) -> Location | None:
2732
subdivision_name=location.subdivision_name,
2833
municipality=location.municipality,
2934
)
35+
36+
@classmethod
@with_db_session
def to_orm_from_dict(cls, location_dict: dict | None, db_session: Session = None) -> LocationOrm | None:
    """Convert a location dictionary to a SQLAlchemy Location row object.

    Delegates to ``create_or_get_location`` using the ``country``,
    ``subdivision_name``, ``municipality`` and ``country_code`` keys. Keys
    are read with ``dict.get`` so a partial payload does not raise KeyError.

    Args:
        location_dict: Dictionary form of a location; None or empty yields None.
        db_session: Database session (presumably injected by
            ``with_db_session`` when omitted - confirm against the decorator).

    Returns:
        Whatever ``create_or_get_location`` returns for these fields (a
        Location row or None), or None when ``location_dict`` is None or empty.
    """
    if not location_dict:
        return None
    return create_or_get_location(
        session=db_session,
        country=location_dict.get("country"),
        state_province=location_dict.get("subdivision_name"),
        city_name=location_dict.get("municipality"),
        country_code=location_dict.get("country_code"),
    )

api/src/shared/db_models/redirect_impl.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from shared.database_gen.sqlacodegen_models import Redirectingid
1+
from sqlalchemy.orm import Session
2+
3+
from shared.database.database import with_db_session
4+
from shared.database_gen.sqlacodegen_models import Redirectingid, Feed
25
from feeds_gen.models.redirect import Redirect
36

47

@@ -21,3 +24,18 @@ def from_orm(cls, redirect: Redirectingid | None) -> Redirect | None:
2124
target_id=redirect.target.stable_id,
2225
comment=redirect.redirect_comment,
2326
)
27+
28+
@classmethod
@with_db_session
def to_orm_from_dict(cls, redirect_dict: dict | None, db_session: Session = None) -> Redirectingid | None:
    """Build a ``Redirectingid`` row from a redirect dictionary.

    The dictionary's ``target_id`` is a feed stable id; it is resolved to
    the target ``Feed`` row through ``db_session``. Returns None when the
    payload is missing or empty, has no ``target_id``, or no feed with that
    stable id is found.
    """
    stable_id = redirect_dict.get("target_id") if redirect_dict else None
    if stable_id is None:
        return None

    target_feed = db_session.query(Feed).filter_by(stable_id=stable_id).first()
    if target_feed:
        return Redirectingid(
            target_id=target_feed.id,
            target=target_feed,
            redirect_comment=redirect_dict.get("comment"),
        )
    return None

api/tests/unittest/models/test_basic_feed_impl.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
feed_id="feed_id",
6363
hosted_url="hosted_url",
6464
note="note",
65-
downloaded_at="downloaded_at",
65+
downloaded_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("UTC")),
6666
hash="hash",
6767
bounding_box="bounding_box",
6868
service_date_range_start=datetime(2024, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
@@ -160,3 +160,97 @@ def test_from_orm_empty_fields(self):
160160
def test_from_orm_none(self):
161161
"""Test the `from_orm` method with None."""
162162
assert FeedImpl.from_orm(None) is None
163+
164+
def test_to_orm_from_dict_none(self):
    """A missing payload (None) or an empty dict converts to None."""
    for empty_payload in (None, {}):
        assert FeedImpl.to_orm_from_dict(empty_payload) is None
168+
169+
def test_to_orm_from_dict_full_payload(self):
    """A full payload maps scalars one-to-one, stringifies authentication_type,
    sorts externalids by associated_id and keeps feedrelatedlinks in input order."""
    created = datetime(2024, 5, 1, 12, 30, 0)
    official_updated = datetime(2024, 6, 1, 8, 0, 0)

    # Fields expected to be copied onto the ORM object unchanged
    copied_fields = {
        "id": "feed-123",
        "stable_id": "stable-123",
        "data_type": "gtfs",
        "created_at": created,
        "provider": "Provider A",
        "feed_contact_email": "[email protected]",
        "producer_url": "https://producer.example.com",
        "authentication_info_url": "https://auth.example.com",
        "api_key_parameter_name": "api_key",
        "license_url": "https://license.example.com",
        "status": "active",
        "official_updated_at": official_updated,
        "feed_name": "Feed Name",
        "note": "Some note",
    }
    external_ids = [
        {"external_id": "b-id", "source": "src"},
        {"external_id": "a-id", "source": "src"},
    ]
    related_links = [
        {"code": "docs", "url": "https://docs.example.com", "description": "Docs"},
        {"code": "home", "url": "https://home.example.com", "description": "Home"},
    ]
    # locations and redirectingids are left out on purpose: they need a database
    payload = {
        **copied_fields,
        "authentication_type": 1,  # should be converted to string
        "official": True,
        "externalids": external_ids,
        "feedrelatedlinks": related_links,
    }

    feed_orm = FeedImpl.to_orm_from_dict(payload)

    # Basic type
    assert isinstance(feed_orm, Feed)

    # Primitives
    for field_name, expected in copied_fields.items():
        assert getattr(feed_orm, field_name) == expected
    # authentication_type coerced to string per implementation
    assert feed_orm.authentication_type == "1"
    assert feed_orm.official is True

    # Nested: externalids should be sorted by associated_id
    assert [type(row).__name__ for row in feed_orm.externalids] == ["Externalid", "Externalid"]
    assert [(row.source, row.associated_id) for row in feed_orm.externalids] == [
        ("src", "a-id"),
        ("src", "b-id"),
    ]

    # Nested: feedrelatedlinks preserved order (no explicit sort in impl)
    assert len(feed_orm.feedrelatedlinks) == 2
    assert [link.code for link in feed_orm.feedrelatedlinks] == ["docs", "home"]
    assert [link.url for link in feed_orm.feedrelatedlinks] == [
        "https://docs.example.com",
        "https://home.example.com",
    ]

    # Relationships not provided should be empty lists
    assert feed_orm.redirectingids == []
    assert feed_orm.locations == []
240+
241+
def test_to_orm_from_dict_empty_collections(self):
    """Relationship keys given as empty lists come back as empty collections."""
    relationships = ("externalids", "feedrelatedlinks", "redirectingids", "locations")
    payload = {"stable_id": "s", "data_type": "gtfs", **{name: [] for name in relationships}}

    feed_orm = FeedImpl.to_orm_from_dict(payload)

    assert isinstance(feed_orm, Feed)
    for name in relationships:
        assert getattr(feed_orm, name) == []

0 commit comments

Comments
 (0)