Skip to content

Commit 39744c8

Browse files
authored
Merge branch 'main' into 1034-change-how-often-were-pulling-the-data-to-daily
2 parents d19d0ef + 380430d commit 39744c8

File tree

11 files changed

+164
-61
lines changed

11 files changed

+164
-61
lines changed

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from sqlalchemy.orm import joinedload, Session
77
from sqlalchemy.orm.query import Query
88

9-
from shared.common.db_utils import get_gtfs_feeds_query, get_gtfs_rt_feeds_query, get_joinedload_options
9+
from shared.common.db_utils import (
10+
get_gtfs_feeds_query,
11+
get_gtfs_rt_feeds_query,
12+
get_joinedload_options,
13+
add_official_filter,
14+
)
1015
from shared.database.database import Database, with_db_session
1116
from shared.database_gen.sqlacodegen_models import (
1217
Feed,
@@ -98,8 +103,7 @@ def get_feeds(
98103
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
99104
)
100105
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
101-
if is_official:
102-
feed_query = feed_query.filter(Feed.official)
106+
feed_query = add_official_filter(feed_query, is_official)
103107
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
104108
feed_query = feed_query.filter(
105109
or_(
@@ -127,13 +131,14 @@ def get_gtfs_feed(self, id: str, db_session: Session) -> GtfsFeed:
127131
else:
128132
raise_http_error(404, gtfs_feed_not_found.format(id))
129133

130-
@staticmethod
131134
def _get_gtfs_feed(
132-
stable_id: str, db_session: Session, include_options_for_joinedload: bool = True
135+
self, stable_id: str, db_session: Session, include_options_for_joinedload: bool = True
133136
) -> Optional[Gtfsfeed]:
134-
results = get_gtfs_feeds_query(
137+
query = get_gtfs_feeds_query(
135138
db_session=db_session, stable_id=stable_id, include_options_for_joinedload=include_options_for_joinedload
136-
).all()
139+
)
140+
self.logger.debug("Query: %s", str(query.statement.compile(compile_kwargs={"literal_binds": True})))
141+
results = query.all()
137142
if len(results) == 0:
138143
return None
139144
return results[0]
@@ -194,7 +199,7 @@ def get_gtfs_feeds(
194199
db_session: Session,
195200
) -> List[GtfsFeed]:
196201
try:
197-
include_wip = not is_user_email_restricted()
202+
published_only = is_user_email_restricted()
198203
feed_query = get_gtfs_feeds_query(
199204
limit=limit,
200205
offset=offset,
@@ -207,7 +212,7 @@ def get_gtfs_feeds(
207212
dataset_longitudes=dataset_longitudes,
208213
bounding_filter_method=bounding_filter_method,
209214
is_official=is_official,
210-
include_wip=include_wip,
215+
published_only=published_only,
211216
db_session=db_session,
212217
)
213218
except InternalHTTPException as e:
@@ -265,7 +270,7 @@ def get_gtfs_rt_feeds(
265270
) -> List[GtfsRTFeed]:
266271
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
267272
try:
268-
include_wip = not is_user_email_restricted()
273+
published_only = is_user_email_restricted()
269274
feed_query = get_gtfs_rt_feeds_query(
270275
limit=limit,
271276
offset=offset,
@@ -276,7 +281,7 @@ def get_gtfs_rt_feeds(
276281
subdivision_name=subdivision_name,
277282
municipality=municipality,
278283
is_official=is_official,
279-
include_wip=include_wip,
284+
published_only=published_only,
280285
db_session=db_session,
281286
)
282287
except InternalHTTPException as e:
@@ -328,8 +333,8 @@ def get_gtfs_rt_feeds(
328333
)
329334
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
330335
)
331-
if is_official:
332-
feed_query = feed_query.filter(Feed.official)
336+
feed_query = add_official_filter(feed_query, is_official)
337+
333338
feed_query = feed_query.limit(limit).offset(offset)
334339
return self._get_response(feed_query, GtfsRTFeedImpl)
335340

api/src/feeds/impl/search_api_impl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,11 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status, is
5252
status_list = [s.strip().lower() for s in status[0].split(",") if s]
5353
if status_list:
5454
query = query.where(t_feedsearch.c.status.in_([s.strip().lower() for s in status_list]))
55-
if is_official is not None and is_official:
56-
query = query.where(t_feedsearch.c.official == is_official)
55+
if is_official is not None:
56+
if is_official:
57+
query = query.where(t_feedsearch.c.official.is_(True))
58+
else:
59+
query = query.where(or_(t_feedsearch.c.official.is_(False), t_feedsearch.c.official.is_(None)))
5760
if search_query and len(search_query.strip()) > 0:
5861
query = query.filter(
5962
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))

api/src/shared/common/db_utils.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,66 +37,72 @@ def get_gtfs_feeds_query(
3737
dataset_latitudes: str | None = None,
3838
dataset_longitudes: str | None = None,
3939
bounding_filter_method: str | None = None,
40-
is_official: bool = False,
41-
include_wip: bool = False,
40+
is_official: bool | None = None,
41+
published_only: bool = True,
4242
include_options_for_joinedload: bool = True,
4343
) -> Query[any]:
4444
"""Get the DB query to use to retrieve the GTFS feeds.."""
4545
gtfs_feed_filter = GtfsFeedFilter(
4646
stable_id=stable_id,
4747
provider__ilike=provider,
4848
producer_url__ilike=producer_url,
49-
location=LocationFilter(
50-
country_code=country_code,
51-
subdivision_name__ilike=subdivision_name,
52-
municipality__ilike=municipality,
53-
),
49+
location=None,
5450
)
5551

56-
subquery = gtfs_feed_filter.filter(select(Gtfsfeed.id).join(Location, Gtfsfeed.locations))
52+
subquery = gtfs_feed_filter.filter(select(Gtfsfeed.id))
5753
subquery = apply_bounding_filtering(
5854
subquery, dataset_latitudes, dataset_longitudes, bounding_filter_method
5955
).subquery()
60-
6156
feed_query = (
6257
db_session.query(Gtfsfeed)
6358
.outerjoin(Gtfsfeed.gtfsdatasets)
6459
.filter(Gtfsfeed.id.in_(subquery))
65-
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
60+
.filter(or_(Gtfsdataset.latest, Gtfsdataset.id == None)) # noqa: E711
6661
)
67-
if not include_wip:
62+
63+
if country_code or subdivision_name or municipality:
64+
location_filter = LocationFilter(
65+
country_code=country_code,
66+
subdivision_name__ilike=subdivision_name,
67+
municipality__ilike=municipality,
68+
)
69+
location_subquery = location_filter.filter(select(Location.id))
70+
feed_query = feed_query.filter(Gtfsfeed.locations.any(Location.id.in_(location_subquery)))
71+
72+
if published_only:
6873
feed_query = feed_query.filter(Gtfsfeed.operational_status == "published")
6974

75+
feed_query = add_official_filter(feed_query, is_official)
76+
7077
if include_options_for_joinedload:
7178
feed_query = feed_query.options(
7279
contains_eager(Gtfsfeed.gtfsdatasets)
7380
.joinedload(Gtfsdataset.validation_reports)
7481
.joinedload(Validationreport.notices),
7582
*get_joinedload_options(),
7683
).order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
77-
if is_official:
78-
feed_query = feed_query.filter(Feed.official)
84+
7985
feed_query = feed_query.limit(limit).offset(offset)
8086
return feed_query
8187

8288

8389
def get_all_gtfs_feeds(
8490
db_session: Session,
85-
include_wip: bool = False,
91+
published_only: bool = True,
8692
batch_size: int = 250,
8793
) -> Iterator[Gtfsfeed]:
8894
"""
8995
Fetch all GTFS feeds.
9096
9197
@param db_session: The database session.
92-
@param include_wip: Whether to include or exclude WIP feeds.
98+
@param published_only: Include only the published feeds.
9399
@param batch_size: The number of feeds to fetch from the database at a time.
94100
A lower value means less memory but more queries.
95101
96102
@return: The GTFS feeds in an iterator.
97103
"""
98104
feed_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
99-
if not include_wip:
105+
if published_only:
100106
feed_query = feed_query.filter(Gtfsfeed.operational_status == "published")
101107

102108
for batch in batched(feed_query, batch_size):
@@ -126,7 +132,7 @@ def get_gtfs_rt_feeds_query(
126132
subdivision_name: str | None,
127133
municipality: str | None,
128134
is_official: bool | None,
129-
include_wip: bool = False,
135+
published_only: bool = True,
130136
db_session: Session = None,
131137
) -> Query:
132138
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
@@ -160,37 +166,49 @@ def get_gtfs_rt_feeds_query(
160166
).subquery()
161167
feed_query = db_session.query(Gtfsrealtimefeed).filter(Gtfsrealtimefeed.id.in_(subquery))
162168

163-
if not include_wip:
169+
if published_only:
164170
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published")
165171

166172
feed_query = feed_query.options(
167173
joinedload(Gtfsrealtimefeed.entitytypes),
168174
joinedload(Gtfsrealtimefeed.gtfs_feeds),
169175
*get_joinedload_options(),
170176
)
171-
if is_official:
172-
feed_query = feed_query.filter(Feed.official)
177+
feed_query = add_official_filter(feed_query, is_official)
178+
173179
feed_query = feed_query.limit(limit).offset(offset)
174180
return feed_query
175181

176182

183+
def add_official_filter(query: Query, is_official: bool | None) -> Query:
184+
"""
185+
Add the is_official filter to the query if necessary
186+
"""
187+
if is_official is not None:
188+
if is_official:
189+
query = query.filter(Feed.official.is_(True))
190+
else:
191+
query = query.filter(or_(Feed.official.is_(False), Feed.official.is_(None)))
192+
return query
193+
194+
177195
def get_all_gtfs_rt_feeds(
178196
db_session: Session,
179-
include_wip: bool = False,
197+
published_only: bool = True,
180198
batch_size: int = 250,
181199
) -> Iterator[Gtfsrealtimefeed]:
182200
"""
183201
Fetch all GTFS realtime feeds.
184202
185203
@param db_session: The database session.
186-
@param include_wip: Whether to include or exclude WIP feeds.
204+
@param published_only: Include only the published feeds.
187205
@param batch_size: The number of feeds to fetch from the database at a time.
188206
A lower value means less memory but more queries.
189207
190208
@return: The GTFS realtime feeds in an iterator.
191209
"""
192210
feed_query = db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
193-
if not include_wip:
211+
if published_only:
194212
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published")
195213

196214
for batch in batched(feed_query, batch_size):
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
mdb_source_id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_offical,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
1+
mdb_source_id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_official,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
22
40,gtfs,,CA,Ontario,London,London Transit Commission,TRUE,,,[email protected],,http://www.londontransit.ca/gtfsfeed/google_transit.zip,0,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-ontario-london-transit-commission-gtfs-2.zip?alt=media,https://www.londontransit.ca/open-data/ltcs-open-data-terms-of-use/,42.905244,43.051188,-81.36311,-81.137591,2022-02-22T19:51:34+00:00,inactive,,,
33
50,gtfs,,CA,Ontario,Barrie,ZBarrie Transit,FALSE,,,,,http://www.myridebarrie.ca/gtfs/Google_transit.zip,,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-ontario-barrie-transit-gtfs-3.zip?alt=media,https://www.barrie.ca/services-payments/transportation-parking/barrie-transit/barrie-gtfs,44.3218044,44.42020676,-79.74063237,-79.61089569,2022-03-01T22:43:25+00:00,deprecated,,40|mdb-702,Some|Comment
44
702,gtfs,,CA,[British Columbia,Whistler],BC Transit (Whistler Transit System),,,,,,http://whistler.mapstrat.com/current/google_transit.zip,,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-british-columbia-bc-transit-whistler-transit-system-gtfs-702.zip?alt=media,https://www.bctransit.com/open-data/terms-of-use,50.077122,50.159071,-123.043635,-122.926836,2022-03-16T22:05:05+00:00,development,,,
5-
1562,gtfs-rt,sa,CA,BC,Vancouver,Vancouver-Transit(éèàçíóúČ),,Realtime(ŘŤÜÎ),,,40,http://foo.org/google_transit.zip,0,,,,,,,,,,active,,10,
6-
1563,gtfs-rt,tu,US,SomeState,SomeCity,SomeCity Bus,,RT,,,mdb-50,http://bar.com,0,,,,,,,,,,inactive,,10,
5+
1562,gtfs-rt,sa,CA,BC,Vancouver,Vancouver-Transit(éèàçíóúČ),True,Realtime(ŘŤÜÎ),,,40,http://foo.org/google_transit.zip,0,,,,,,,,,,active,,10,
6+
1563,gtfs-rt,tu,US,SomeState,SomeCity,SomeCity Bus,False,RT,,,mdb-50,http://bar.com,0,,,,,,,,,,inactive,,10,

api/tests/integration/test_feeds_api.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -247,22 +247,52 @@ def test_feeds_gtfs_rt_id_get(client: TestClient):
247247

248248

249249
@pytest.mark.parametrize(
250-
"endpoint",
250+
"values",
251251
[
252-
"/v1/gtfs_feeds",
253-
"/v1/gtfs_rt_feeds",
254-
"/v1/feeds",
252+
{
253+
"endpoint": "/v1/gtfs_feeds",
254+
"all_feeds_count": 10,
255+
"false_feeds_count": 9,
256+
"true_feeds_count": 1,
257+
"official_feeds": ["mdb-40"],
258+
},
259+
{
260+
"endpoint": "/v1/gtfs_rt_feeds",
261+
"all_feeds_count": 3,
262+
"false_feeds_count": 2,
263+
"true_feeds_count": 1,
264+
"official_feeds": ["mdb-1562"],
265+
},
266+
{
267+
"endpoint": "/v1/feeds",
268+
"all_feeds_count": 13,
269+
"false_feeds_count": 11,
270+
"true_feeds_count": 2,
271+
"official_feeds": ["mdb-1562", "mdb-40"],
272+
},
255273
],
274+
ids=["gtfs_feeds", "gtfs_rt_feeds", "feeds"],
256275
)
257-
def test_feeds_filter_by_official(client: TestClient, endpoint):
258-
# 1 - Test with official=false should return all feeds
276+
def test_feeds_filter_by_official(client: TestClient, values):
277+
endpoint = values["endpoint"]
278+
official_feeds = values["official_feeds"]
279+
all_feeds_count = values["all_feeds_count"]
280+
false_feeds_count = values["false_feeds_count"]
281+
true_feeds_count = values["true_feeds_count"]
282+
283+
# 1 - Test with official not specified should return all feeds
259284
response_no_filter = client.request(
260285
"GET",
261286
endpoint,
262287
headers=authHeaders,
263288
)
264289
assert response_no_filter.status_code == 200
265290
response_no_filter_json = response_no_filter.json()
291+
assert (
292+
len(response_no_filter_json) == all_feeds_count
293+
), f"official not specified should return {all_feeds_count} feeds but got {len(response_no_filter_json)}"
294+
295+
# 2 - Test with official=false
266296
response_official_false = client.request(
267297
"GET",
268298
endpoint,
@@ -271,8 +301,14 @@ def test_feeds_filter_by_official(client: TestClient, endpoint):
271301
)
272302
assert response_official_false.status_code == 200
273303
response_official_false_json = response_official_false.json()
274-
assert response_no_filter_json == response_official_false_json, "official=false parameter should return all feeds"
275-
# 2 - Test with official=true should return at least one feed
304+
assert (
305+
len(response_official_false_json) == false_feeds_count
306+
), f"official=false should return {false_feeds_count} feeds but got {len(response_official_false_json)}"
307+
assert not any(
308+
response["id"] in official_feeds for response in response_official_false_json
309+
), f"official=false expected no feed with stable_id {official_feeds} since it is official"
310+
311+
# 3 - Test with official=true
276312
response = client.request(
277313
"GET",
278314
endpoint,
@@ -281,7 +317,8 @@ def test_feeds_filter_by_official(client: TestClient, endpoint):
281317
)
282318
assert response.status_code == 200
283319
json_response = response.json()
284-
assert len(json_response) < len(response_no_filter_json), "Not all feeds are official"
320+
assert len(json_response) == true_feeds_count, f"official=true should return {true_feeds_count} feeds"
321+
assert json_response[0]["id"] in official_feeds, f"official=true should return {official_feeds}"
285322

286323

287324
def test_non_existent_gtfs_rt_feed_get(client: TestClient):

0 commit comments

Comments
 (0)