Skip to content

Commit e613ad7

Browse files
authored
fix: limit and offset behaviour for feeds with multiple locations (#773)
1 parent e11fd88 commit e613ad7

File tree

8 files changed

+345
-104
lines changed

8 files changed

+345
-104
lines changed

.github/workflows/db-update-qa.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@ name: Database Update - QA
33
on:
44
workflow_dispatch:
55
workflow_call:
6-
6+
repository_dispatch: # Update on mobility-database-catalog repo dispatch
7+
types: [ catalog-sources-updated ]
78
jobs:
89
update:
910
uses: ./.github/workflows/db-update.yml

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 38 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
11
from datetime import datetime
22
from typing import List, Union, TypeVar
3-
3+
from sqlalchemy import select
44
from sqlalchemy.orm import joinedload
55
from sqlalchemy.orm.query import Query
6+
67
from database.database import Database
78
from database_gen.sqlacodegen_models import (
89
Feed,
@@ -11,8 +12,8 @@
1112
Gtfsrealtimefeed,
1213
Location,
1314
Validationreport,
14-
Entitytype,
1515
t_location_with_translations_en,
16+
Entitytype,
1617
)
1718
from feeds.filters.feed_filter import FeedFilter
1819
from feeds.filters.gtfs_dataset_filter import GtfsDatasetFilter
@@ -37,7 +38,11 @@
3738
from feeds_gen.models.gtfs_feed import GtfsFeed
3839
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
3940
from utils.date_utils import valid_iso_date
40-
from utils.location_translation import create_location_translation_object, LocationTranslation
41+
from utils.location_translation import (
42+
create_location_translation_object,
43+
LocationTranslation,
44+
get_feeds_location_translations,
45+
)
4146

4247
T = TypeVar("T", bound="BasicFeed")
4348

@@ -197,25 +202,28 @@ def get_gtfs_feeds(
197202
municipality__ilike=municipality,
198203
),
199204
)
200-
gtfs_feed_query = gtfs_feed_filter.filter(
201-
Database().get_session().query(Gtfsfeed, t_location_with_translations_en)
202-
)
203-
gtfs_feed_query = (
204-
gtfs_feed_query.outerjoin(Location, Feed.locations)
205-
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
205+
206+
subquery = gtfs_feed_filter.filter(select(Gtfsfeed.id).join(Location, Gtfsfeed.locations))
207+
subquery = DatasetsApiImpl.apply_bounding_filtering(
208+
subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method
209+
).subquery()
210+
211+
feed_query = (
212+
Database()
213+
.get_session()
214+
.query(Gtfsfeed)
215+
.filter(Gtfsfeed.id.in_(subquery))
206216
.options(
207217
joinedload(Gtfsfeed.gtfsdatasets)
208218
.joinedload(Gtfsdataset.validation_reports)
209219
.joinedload(Validationreport.notices),
210220
*BasicFeedImpl.get_joinedload_options(),
211221
)
212222
.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
223+
.limit(limit)
224+
.offset(offset)
213225
)
214-
gtfs_feed_query = gtfs_feed_query.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
215-
gtfs_feed_query = DatasetsApiImpl.apply_bounding_filtering(
216-
gtfs_feed_query, dataset_latitudes, dataset_longitudes, bounding_filter_method
217-
)
218-
return self._get_response(gtfs_feed_query, limit, offset, GtfsFeedImpl)
226+
return self._get_response(feed_query, GtfsFeedImpl)
219227

220228
def get_gtfs_rt_feed(
221229
self,
@@ -283,32 +291,33 @@ def get_gtfs_rt_feeds(
283291
municipality__ilike=municipality,
284292
),
285293
)
286-
gtfs_rt_feed_query = gtfs_rt_feed_filter.filter(
287-
Database().get_session().query(Gtfsrealtimefeed, t_location_with_translations_en)
288-
)
289-
gtfs_rt_feed_query = (
290-
gtfs_rt_feed_query.outerjoin(Location, Gtfsrealtimefeed.locations)
291-
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
292-
.outerjoin(Entitytype, Gtfsrealtimefeed.entitytypes)
294+
subquery = gtfs_rt_feed_filter.filter(
295+
select(Gtfsrealtimefeed.id)
296+
.join(Location, Gtfsrealtimefeed.locations)
297+
.join(Entitytype, Gtfsrealtimefeed.entitytypes)
298+
).subquery()
299+
feed_query = (
300+
Database()
301+
.get_session()
302+
.query(Gtfsrealtimefeed)
303+
.filter(Gtfsrealtimefeed.id.in_(subquery))
293304
.options(
294305
joinedload(Gtfsrealtimefeed.entitytypes),
295306
joinedload(Gtfsrealtimefeed.gtfs_feeds),
296307
*BasicFeedImpl.get_joinedload_options(),
297308
)
298309
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
310+
.limit(limit)
311+
.offset(offset)
299312
)
300-
return self._get_response(gtfs_rt_feed_query, limit, offset, GtfsRTFeedImpl)
313+
return self._get_response(feed_query, GtfsRTFeedImpl)
301314

302315
@staticmethod
303-
def _get_response(feed_query: Query, limit: int, offset: int, impl_cls: type[T]) -> List[T]:
316+
def _get_response(feed_query: Query, impl_cls: type[T]) -> List[T]:
304317
"""Get the response for the feed query."""
305-
if limit is not None:
306-
feed_query = feed_query.limit(limit)
307-
if offset is not None:
308-
feed_query = feed_query.offset(offset)
309318
results = feed_query.all()
310-
location_translations = {row[1]: create_location_translation_object(row) for row in results}
311-
response = [impl_cls.from_orm(feed[0], location_translations) for feed in results]
319+
location_translations = get_feeds_location_translations(results)
320+
response = [impl_cls.from_orm(feed, location_translations) for feed in results]
312321
return list({feed.id: feed for feed in response}.values())
313322

314323
def get_gtfs_feed_gtfs_rt_feeds(

api/src/scripts/populate_db_test_data.py

Lines changed: 114 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,21 @@
11
import json
2+
from uuid import uuid4
23

34
from geoalchemy2 import WKTElement
5+
from google.cloud.sql.connector.instance import logger
46
from sqlalchemy import text
57

68
from database.database import Database
7-
from database_gen.sqlacodegen_models import Gtfsdataset, Validationreport, Gtfsfeed, Notice, Feature, t_feedsearch
8-
from scripts.populate_db import set_up_configs
9+
from database_gen.sqlacodegen_models import (
10+
Gtfsdataset,
11+
Validationreport,
12+
Gtfsfeed,
13+
Notice,
14+
Feature,
15+
t_feedsearch,
16+
Location,
17+
)
18+
from scripts.populate_db import set_up_configs, DatabasePopulateHelper
919
from utils.logger import Logger
1020

1121

@@ -36,70 +46,78 @@ def populate_test_datasets(self, filepath):
3646
with open(filepath) as f:
3747
data = json.load(f)
3848

49+
# GTFS Feeds
50+
if "feeds" in data:
51+
self.populate_test_feeds(data["feeds"])
52+
3953
# GTFS Datasets
4054
dataset_dict = {}
41-
for dataset in data["datasets"]:
42-
# query the db using feed_id to get the feed object
43-
gtfsfeed = self.db.session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == dataset["feed_stable_id"]).all()
44-
if not gtfsfeed:
45-
self.logger.error(f"No feed found with stable_id: {dataset['feed_stable_id']}")
46-
continue
47-
48-
gtfs_dataset = Gtfsdataset(
49-
id=dataset["id"],
50-
feed_id=gtfsfeed[0].id,
51-
stable_id=dataset["id"],
52-
latest=dataset["latest"],
53-
hosted_url=dataset["hosted_url"],
54-
hash=dataset["hash"],
55-
downloaded_at=dataset["downloaded_at"],
56-
bounding_box=None
57-
if dataset.get("bounding_box") is None
58-
else WKTElement(dataset["bounding_box"], srid=4326),
59-
validation_reports=[],
60-
)
61-
dataset_dict[dataset["id"]] = gtfs_dataset
62-
self.db.session.add(gtfs_dataset)
55+
if "datasets" in data:
56+
for dataset in data["datasets"]:
57+
# query the db using feed_id to get the feed object
58+
gtfsfeed = self.db.session.query(Gtfsfeed).filter(Gtfsfeed.stable_id == dataset["feed_stable_id"]).all()
59+
if not gtfsfeed:
60+
self.logger.error(f"No feed found with stable_id: {dataset['feed_stable_id']}")
61+
continue
62+
63+
gtfs_dataset = Gtfsdataset(
64+
id=dataset["id"],
65+
feed_id=gtfsfeed[0].id,
66+
stable_id=dataset["id"],
67+
latest=dataset["latest"],
68+
hosted_url=dataset["hosted_url"],
69+
hash=dataset["hash"],
70+
downloaded_at=dataset["downloaded_at"],
71+
bounding_box=None
72+
if dataset.get("bounding_box") is None
73+
else WKTElement(dataset["bounding_box"], srid=4326),
74+
validation_reports=[],
75+
)
76+
dataset_dict[dataset["id"]] = gtfs_dataset
77+
self.db.session.add(gtfs_dataset)
6378

6479
# Validation reports
65-
validation_report_dict = {}
66-
for report in data["validation_reports"]:
67-
validation_report = Validationreport(
68-
id=report["id"],
69-
validator_version=report["validator_version"],
70-
validated_at=report["validated_at"],
71-
html_report=report["html_report"],
72-
json_report=report["json_report"],
73-
features=[],
74-
)
75-
dataset_dict[report["dataset_id"]].validation_reports.append(validation_report)
76-
validation_report_dict[report["id"]] = validation_report
77-
self.db.session.add(validation_report)
80+
if "validation_reports" in data:
81+
validation_report_dict = {}
82+
for report in data["validation_reports"]:
83+
validation_report = Validationreport(
84+
id=report["id"],
85+
validator_version=report["validator_version"],
86+
validated_at=report["validated_at"],
87+
html_report=report["html_report"],
88+
json_report=report["json_report"],
89+
features=[],
90+
)
91+
dataset_dict[report["dataset_id"]].validation_reports.append(validation_report)
92+
validation_report_dict[report["id"]] = validation_report
93+
self.db.session.add(validation_report)
7894

7995
# Notices
80-
for report_notice in data["notices"]:
81-
notice = Notice(
82-
dataset_id=report_notice["dataset_id"],
83-
validation_report_id=report_notice["validation_report_id"],
84-
severity=report_notice["severity"],
85-
notice_code=report_notice["notice_code"],
86-
total_notices=report_notice["total_notices"],
87-
)
88-
self.db.session.add(notice)
96+
if "notices" in data:
97+
for report_notice in data["notices"]:
98+
notice = Notice(
99+
dataset_id=report_notice["dataset_id"],
100+
validation_report_id=report_notice["validation_report_id"],
101+
severity=report_notice["severity"],
102+
notice_code=report_notice["notice_code"],
103+
total_notices=report_notice["total_notices"],
104+
)
105+
self.db.session.add(notice)
89106
# Features
90-
for featureName in data["features"]:
91-
feature = Feature(name=featureName)
92-
self.db.session.add(feature)
107+
if "features" in data:
108+
for featureName in data["features"]:
109+
feature = Feature(name=featureName)
110+
self.db.session.add(feature)
93111

94112
# Features in Validation Reports
95-
for report_features in data["validation_report_features"]:
96-
validation_report_dict[report_features["validation_report_id"]].features.append(
97-
self.db.session.query(Feature).filter(Feature.name == report_features["feature_name"]).first()
98-
)
99-
100-
self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
113+
if "validation_report_features" in data:
114+
for report_features in data["validation_report_features"]:
115+
validation_report_dict[report_features["validation_report_id"]].features.append(
116+
self.db.session.query(Feature).filter(Feature.name == report_features["feature_name"]).first()
117+
)
101118

102119
self.db.session.commit()
120+
self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
103121

104122
def populate(self):
105123
"""
@@ -116,6 +134,47 @@ def populate(self):
116134

117135
self.logger.info("Database populated with test data")
118136

137+
def populate_test_feeds(self, feeds_data):
138+
for feed_data in feeds_data:
139+
feed = Gtfsfeed(
140+
id=str(uuid4()),
141+
stable_id=feed_data["id"],
142+
data_type=feed_data["data_type"],
143+
status=feed_data["status"],
144+
created_at=feed_data["created_at"],
145+
provider=feed_data["provider"],
146+
feed_name=feed_data["feed_name"],
147+
note=feed_data["note"],
148+
authentication_info_url=None,
149+
api_key_parameter_name=None,
150+
license_url=None,
151+
feed_contact_email=feed_data["feed_contact_email"],
152+
producer_url=feed_data["source_info"]["producer_url"],
153+
)
154+
locations = []
155+
for location_data in feed_data["locations"]:
156+
location_id = DatabasePopulateHelper.get_location_id(
157+
location_data["country_code"],
158+
location_data["subdivision_name"],
159+
location_data["municipality"],
160+
)
161+
location = self.db.session.get(Location, location_id)
162+
location = (
163+
location
164+
if location
165+
else Location(
166+
id=location_id,
167+
country_code=location_data["country_code"],
168+
subdivision_name=location_data["subdivision_name"],
169+
municipality=location_data["municipality"],
170+
country=location_data["country"],
171+
)
172+
)
173+
locations.append(location)
174+
feed.locations = locations
175+
self.db.session.add(feed)
176+
logger.info(f"Added feed {feed.stable_id}")
177+
119178

120179
if __name__ == "__main__":
121180
db_helper = DatabasePopulateTestDataHelper(set_up_configs())

0 commit comments

Comments
 (0)