Skip to content

Commit 9c6daad

Browse files
authored
feat: refactor latest dataset as a relationship of gtfsfeed (#1405)
1 parent 37183b8 commit 9c6daad

File tree

35 files changed, with 562 additions and 203 deletions.

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def get_gtfs_feed_datasets(
183183
).filter(DatasetsApiImpl.create_dataset_query().filter(FeedOrm.stable_id == gtfs_feed_id))
184184

185185
if latest:
186-
query = query.filter(Gtfsdataset.latest)
186+
query = query.join(Gtfsdataset.feed).filter(Gtfsdataset.id == FeedOrm.latest_dataset_id)
187187

188188
return DatasetsApiImpl.get_datasets_gtfs(query, session=db_session, limit=limit, offset=offset)
189189

api/src/feeds/impl/models/gtfs_feed_impl.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@ def from_orm(cls, feed: GtfsfeedOrm | None) -> GtfsFeed | None:
2323
if not gtfs_feed:
2424
return None
2525
gtfs_feed.locations = [LocationImpl.from_orm(item) for item in feed.locations]
26-
latest_dataset = next(
27-
(dataset for dataset in feed.gtfsdatasets if dataset is not None and dataset.latest), None
28-
)
29-
gtfs_feed.latest_dataset = LatestDatasetImpl.from_orm(latest_dataset)
26+
gtfs_feed.latest_dataset = LatestDatasetImpl.from_orm(feed.latest_dataset)
3027
gtfs_feed.bounding_box = BoundingBoxImpl.from_orm(feed.bounding_box)
3128
gtfs_feed.visualization_dataset_id = (
3229
feed.visualization_dataset.stable_id if feed.visualization_dataset else None

api/src/scripts/populate_db_test_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
7373
id=dataset["id"],
7474
feed_id=gtfsfeed[0].id,
7575
stable_id=dataset["id"],
76-
latest=dataset["latest"],
7776
hosted_url=dataset["hosted_url"],
7877
hash=dataset["hash"],
7978
downloaded_at=dataset["downloaded_at"],
@@ -82,6 +81,9 @@ def populate_test_datasets(self, filepath, db_session: "Session"):
8281
),
8382
validation_reports=[],
8483
)
84+
if dataset["latest"]:
85+
gtfsfeed[0].latest_dataset = gtfs_dataset
86+
8587
dataset_dict[dataset["id"]] = gtfs_dataset
8688
db_session.add(gtfs_dataset)
8789
db_session.commit()

api/src/shared/common/db_utils.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def get_gtfs_feeds_query(
6161
subquery = apply_bounding_filtering(
6262
subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method
6363
).subquery()
64-
feed_query = (
65-
db_session.query(Gtfsfeed)
66-
.outerjoin(Gtfsfeed.gtfsdatasets)
67-
.filter(Gtfsfeed.id.in_(subquery))
68-
.filter(or_(Gtfsdataset.latest, Gtfsdataset.id == None)) # noqa: E711
69-
)
64+
feed_query = db_session.query(Gtfsfeed).filter(Gtfsfeed.id.in_(subquery))
7065

7166
if country_code or subdivision_name or municipality:
7267
location_filter = LocationFilter(
@@ -84,7 +79,7 @@ def get_gtfs_feeds_query(
8479

8580
if include_options_for_joinedload:
8681
feed_query = feed_query.options(
87-
contains_eager(Gtfsfeed.gtfsdatasets)
82+
joinedload(Gtfsfeed.latest_dataset)
8883
.joinedload(Gtfsdataset.validation_reports)
8984
.joinedload(Validationreport.features),
9085
joinedload(Gtfsfeed.visualization_dataset),
@@ -172,14 +167,10 @@ def get_all_gtfs_feeds(
172167
for batch in batched(batch_query, batch_size):
173168
stable_ids = (f.stable_id for f in batch)
174169
if w_extracted_locations_only:
175-
feed_query = apply_most_common_location_filter(
176-
db_session.query(Gtfsfeed).outerjoin(Gtfsfeed.gtfsdatasets), db_session
177-
)
170+
feed_query = apply_most_common_location_filter(db_session.query(Gtfsfeed), db_session)
178171
yield from (
179-
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids))
180-
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
181-
.options(
182-
contains_eager(Gtfsfeed.gtfsdatasets)
172+
feed_query.filter(Gtfsfeed.stable_id.in_(stable_ids)).options(
173+
joinedload(Gtfsfeed.latest_dataset)
183174
.joinedload(Gtfsdataset.validation_reports)
184175
.joinedload(Validationreport.features),
185176
*get_joinedload_options(include_extracted_location_entities=True),
@@ -190,9 +181,8 @@ def get_all_gtfs_feeds(
190181
db_session.query(Gtfsfeed)
191182
.outerjoin(Gtfsfeed.gtfsdatasets)
192183
.filter(Gtfsfeed.stable_id.in_(stable_ids))
193-
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
194184
.options(
195-
contains_eager(Gtfsfeed.gtfsdatasets)
185+
joinedload(Gtfsfeed.latest_dataset)
196186
.joinedload(Gtfsdataset.validation_reports)
197187
.joinedload(Validationreport.features),
198188
*get_joinedload_options(include_extracted_location_entities=False),

api/tests/unittest/models/test_gtfs_feed_impl.py

Lines changed: 81 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import copy
21
import unittest
32
from datetime import datetime
43
from zoneinfo import ZoneInfo
@@ -49,13 +48,47 @@ def create_test_notice(notice_code: str, total_notices: int, severity: str):
4948
)
5049

5150

51+
gtfs_dataset_orm = Gtfsdataset(
52+
id="id",
53+
stable_id="dataset_stable_id",
54+
feed_id="feed_id",
55+
hosted_url="hosted_url",
56+
note="note",
57+
downloaded_at=datetime(year=2022, month=12, day=31, hour=13, minute=45, second=56),
58+
hash="hash",
59+
service_date_range_start=datetime(2024, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
60+
service_date_range_end=datetime(2025, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
61+
agency_timezone="Canada/Atlantic",
62+
bounding_box=WKTElement(POLYGON, srid=4326),
63+
validation_reports=[
64+
Validationreport(
65+
id="id",
66+
validator_version="validator_version",
67+
validated_at=datetime(year=2022, month=12, day=31, hour=13, minute=45, second=56),
68+
html_report="html_report",
69+
json_report="json_report",
70+
features=[Feature(name="feature")],
71+
notices=[
72+
create_test_notice("notice_code1", 1, "INFO"),
73+
create_test_notice("notice_code2", 3, "INFO"),
74+
create_test_notice("notice_code3", 7, "ERROR"),
75+
create_test_notice("notice_code4", 9, "ERROR"),
76+
create_test_notice("notice_code5", 11, "ERROR"),
77+
create_test_notice("notice_code6", 13, "WARNING"),
78+
create_test_notice("notice_code7", 15, "WARNING"),
79+
create_test_notice("notice_code8", 17, "WARNING"),
80+
create_test_notice("notice_code9", 19, "WARNING"),
81+
],
82+
)
83+
],
84+
)
5285
gtfs_feed_orm = Gtfsfeed(
5386
id="id",
5487
data_type="gtfs",
5588
feed_name="feed_name",
5689
note="note",
5790
producer_url="producer_url",
58-
authentication_type=1,
91+
authentication_type="1",
5992
authentication_info_url="authentication_info_url",
6093
api_key_parameter_name="api_key_parameter_name",
6194
license_url="license_url",
@@ -79,43 +112,8 @@ def create_test_notice(notice_code: str, total_notices: int, severity: str):
79112
source="source",
80113
)
81114
],
82-
gtfsdatasets=[
83-
Gtfsdataset(
84-
id="id",
85-
stable_id="dataset_stable_id",
86-
feed_id="feed_id",
87-
hosted_url="hosted_url",
88-
note="note",
89-
downloaded_at=datetime(year=2022, month=12, day=31, hour=13, minute=45, second=56),
90-
hash="hash",
91-
service_date_range_start=datetime(2024, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
92-
service_date_range_end=datetime(2025, 1, 1, 0, 0, 0, tzinfo=ZoneInfo("Canada/Atlantic")),
93-
agency_timezone="Canada/Atlantic",
94-
bounding_box=WKTElement(POLYGON, srid=4326),
95-
latest=True,
96-
validation_reports=[
97-
Validationreport(
98-
id="id",
99-
validator_version="validator_version",
100-
validated_at=datetime(year=2022, month=12, day=31, hour=13, minute=45, second=56),
101-
html_report="html_report",
102-
json_report="json_report",
103-
features=[Feature(name="feature")],
104-
notices=[
105-
create_test_notice("notice_code1", 1, "INFO"),
106-
create_test_notice("notice_code2", 3, "INFO"),
107-
create_test_notice("notice_code3", 7, "ERROR"),
108-
create_test_notice("notice_code4", 9, "ERROR"),
109-
create_test_notice("notice_code5", 11, "ERROR"),
110-
create_test_notice("notice_code6", 13, "WARNING"),
111-
create_test_notice("notice_code7", 15, "WARNING"),
112-
create_test_notice("notice_code8", 17, "WARNING"),
113-
create_test_notice("notice_code9", 19, "WARNING"),
114-
],
115-
)
116-
],
117-
)
118-
],
115+
latest_dataset=gtfs_dataset_orm,
116+
gtfsdatasets=[gtfs_dataset_orm],
119117
redirectingids=[
120118
Redirectingid(source_id="source_id", target_id="id1", redirect_comment="redirect_comment", target=targetFeed)
121119
],
@@ -198,24 +196,47 @@ def test_from_orm_all_fields(self):
198196

199197
def test_from_orm_empty_fields(self):
200198
"""Test the `from_orm` method with not provided fields."""
201-
# Test with empty fields and None values
202-
# No error should be raised
203-
# Target is set to None as deep copy is failing for unknown reasons
204-
# At the end of the test, the target is set back to the original value
205-
gtfs_feed_orm.redirectingids[0].target = None
206-
target_feed_orm = copy.deepcopy(gtfs_feed_orm)
207-
target_feed_orm.feed_name = ""
208-
target_feed_orm.provider = None
209-
target_feed_orm.externalids = []
210-
target_feed_orm.redirectingids = []
211-
212-
target_expected_gtfs_feed_result = copy.deepcopy(expected_gtfs_feed_result)
213-
target_expected_gtfs_feed_result.feed_name = ""
214-
target_expected_gtfs_feed_result.provider = None
215-
target_expected_gtfs_feed_result.external_ids = []
216-
target_expected_gtfs_feed_result.redirects = []
217-
218-
result = GtfsFeedImpl.from_orm(target_feed_orm)
219-
assert result == target_expected_gtfs_feed_result
220-
# Set the target back to the original value
221-
gtfs_feed_orm.redirectingids[0].target = targetFeed
199+
# Manually construct a minimal Gtfsfeed ORM object with empty/None fields
200+
minimal_feed_orm = Gtfsfeed(
201+
id="id",
202+
data_type="gtfs",
203+
feed_name="",
204+
note=None,
205+
producer_url=None,
206+
authentication_type=None,
207+
authentication_info_url=None,
208+
api_key_parameter_name=None,
209+
license_url=None,
210+
stable_id="stable_id",
211+
status=None,
212+
feed_contact_email=None,
213+
provider=None,
214+
locations=[],
215+
externalids=[],
216+
latest_dataset=None,
217+
gtfsdatasets=[],
218+
redirectingids=[],
219+
gtfs_rt_feeds=[],
220+
)
221+
minimal_expected_result = GtfsFeedImpl(
222+
id="stable_id",
223+
data_type="gtfs",
224+
status=None,
225+
external_ids=[],
226+
provider=None,
227+
feed_name="",
228+
note=None,
229+
feed_contact_email=None,
230+
source_info=SourceInfo(
231+
producer_url=None,
232+
authentication_type=None,
233+
authentication_info_url=None,
234+
api_key_parameter_name=None,
235+
license_url=None,
236+
),
237+
redirects=[],
238+
locations=[],
239+
latest_dataset=None,
240+
)
241+
result = GtfsFeedImpl.from_orm(minimal_feed_orm)
242+
assert result == minimal_expected_result

api/tests/unittest/test_feeds.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_gtfs_feeds_get_no_bounding_box(client: TestClient, mocker):
172172
"""
173173
mock_select = mocker.patch.object(Database(), "select")
174174
mock_feed = Feed(stable_id="test_gtfs_id")
175-
mock_latest_datasets = Gtfsdataset(stable_id="test_latest_dataset_id", hosted_url="test_hosted_url", latest=True)
175+
mock_latest_datasets = Gtfsdataset(stable_id="test_latest_dataset_id", hosted_url="test_hosted_url")
176176

177177
mock_select.return_value = [
178178
[
@@ -296,18 +296,17 @@ def assert_gtfs(gtfs_feed, response_gtfs_feed):
296296
), f'Response feed municipality was {response_gtfs_feed["locations"][0]["municipality"]} \
297297
instead of {gtfs_feed.locations[0].municipality}'
298298
# It seems the resulting are not always in the same order, so find the latest instead of using a hardcoded index
299-
latest_dataset = next((dataset for dataset in gtfs_feed.gtfsdatasets if dataset.latest), None)
300-
if latest_dataset is not None:
299+
# latest_dataset = next((dataset for dataset in gtfs_feed.gtfsdatasets if dataset.latest), None)
300+
if gtfs_feed.latest_dataset is not None:
301301
assert (
302-
response_gtfs_feed["latest_dataset"]["id"] == latest_dataset.stable_id
302+
response_gtfs_feed["latest_dataset"]["id"] == gtfs_feed.latest_dataset.stable_id
303303
), f'Response feed latest dataset id was {response_gtfs_feed["latest_dataset"]["id"]} \
304-
instead of {latest_dataset.stable_id}'
304+
instead of {gtfs_feed.latest_dataset.stable_id}'
305305
else:
306306
raise Exception("No latest dataset found")
307307

308-
latest_dataset = next(filter(lambda x: x.latest, gtfs_feed.gtfsdatasets))
309308
assert (
310-
response_gtfs_feed["latest_dataset"]["hosted_url"] == latest_dataset.hosted_url
309+
response_gtfs_feed["latest_dataset"]["hosted_url"] == gtfs_feed.latest_dataset.hosted_url
311310
), f'Response feed hosted url was {response_gtfs_feed["latest_dataset"]["hosted_url"]} \
312311
instead of test_hosted_url'
313312
assert response_gtfs_feed["latest_dataset"]["bounding_box"] is not None, "Response feed bounding_box was None"

functions-python/batch_datasets/src/main.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from google.cloud import pubsub_v1
2626
from google.cloud.pubsub_v1 import PublisherClient
2727
from google.cloud.pubsub_v1.futures import Future
28-
from sqlalchemy import or_
2928
from sqlalchemy.orm import Session
3029

3130
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsdataset
@@ -87,9 +86,8 @@ def get_non_deprecated_feeds(
8786
Gtfsdataset.hash.label("dataset_hash"),
8887
)
8988
.select_from(Gtfsfeed)
90-
.outerjoin(Gtfsdataset, (Gtfsdataset.feed_id == Gtfsfeed.id))
89+
.outerjoin(Gtfsdataset, (Gtfsfeed.latest_dataset_id == Gtfsdataset.id))
9190
.filter(Gtfsfeed.status != "deprecated")
92-
.filter(or_(Gtfsdataset.id.is_(None), Gtfsdataset.latest.is_(True)))
9391
)
9492
if feed_stable_ids:
9593
# If feed_stable_ids are provided, filter the query by stable IDs

functions-python/batch_datasets/tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,10 @@ def populate_database(db_session: Session | None = None):
8484
# GTFS datasets leaving one active feed without a dataset
8585
active_gtfs_feeds = db_session.query(Gtfsfeed).all()
8686
for i in range(1, 9):
87+
id = fake.uuid4()
8788
gtfs_dataset = Gtfsdataset(
88-
id=fake.uuid4(),
89+
id=id,
8990
feed_id=active_gtfs_feeds[i].id,
90-
latest=True,
9191
bounding_box="POLYGON((-180 -90, -180 90, 180 90, 180 -90, -180 -90))",
9292
hosted_url=fake.url(),
9393
note=fake.sentence(),
@@ -96,6 +96,8 @@ def populate_database(db_session: Session | None = None):
9696
stable_id=fake.uuid4(),
9797
)
9898
db_session.add(gtfs_dataset)
99+
db_session.flush()
100+
active_gtfs_feeds[i].latest_dataset_id = id
99101

100102
db_session.flush()
101103
# GTFS Realtime feeds

0 commit comments