Skip to content

Commit 9840926

Browse files
authored
Fix: Changed the API to return feeds where official is false or null (#1043)
1 parent 98e2765 commit 9840926

File tree

9 files changed

+120
-48
lines changed

9 files changed

+120
-48
lines changed

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
from sqlalchemy.orm import joinedload, Session
77
from sqlalchemy.orm.query import Query
88

9-
from shared.common.db_utils import get_gtfs_feeds_query, get_gtfs_rt_feeds_query, get_joinedload_options
9+
from shared.common.db_utils import (
10+
get_gtfs_feeds_query,
11+
get_gtfs_rt_feeds_query,
12+
get_joinedload_options,
13+
add_official_filter,
14+
)
1015
from shared.database.database import Database, with_db_session
1116
from shared.database_gen.sqlacodegen_models import (
1217
Feed,
@@ -98,8 +103,7 @@ def get_feeds(
98103
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
99104
)
100105
feed_query = feed_filter.filter(Database().get_query_model(db_session, Feed))
101-
if is_official:
102-
feed_query = feed_query.filter(Feed.official)
106+
feed_query = add_official_filter(feed_query, is_official)
103107
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
104108
feed_query = feed_query.filter(
105109
or_(
@@ -194,7 +198,7 @@ def get_gtfs_feeds(
194198
db_session: Session,
195199
) -> List[GtfsFeed]:
196200
try:
197-
include_wip = not is_user_email_restricted()
201+
published_only = is_user_email_restricted()
198202
feed_query = get_gtfs_feeds_query(
199203
limit=limit,
200204
offset=offset,
@@ -207,7 +211,7 @@ def get_gtfs_feeds(
207211
dataset_longitudes=dataset_longitudes,
208212
bounding_filter_method=bounding_filter_method,
209213
is_official=is_official,
210-
include_wip=include_wip,
214+
published_only=published_only,
211215
db_session=db_session,
212216
)
213217
except InternalHTTPException as e:
@@ -265,7 +269,7 @@ def get_gtfs_rt_feeds(
265269
) -> List[GtfsRTFeed]:
266270
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
267271
try:
268-
include_wip = not is_user_email_restricted()
272+
published_only = is_user_email_restricted()
269273
feed_query = get_gtfs_rt_feeds_query(
270274
limit=limit,
271275
offset=offset,
@@ -276,7 +280,7 @@ def get_gtfs_rt_feeds(
276280
subdivision_name=subdivision_name,
277281
municipality=municipality,
278282
is_official=is_official,
279-
include_wip=include_wip,
283+
published_only=published_only,
280284
db_session=db_session,
281285
)
282286
except InternalHTTPException as e:
@@ -328,8 +332,8 @@ def get_gtfs_rt_feeds(
328332
)
329333
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
330334
)
331-
if is_official:
332-
feed_query = feed_query.filter(Feed.official)
335+
feed_query = add_official_filter(feed_query, is_official)
336+
333337
feed_query = feed_query.limit(limit).offset(offset)
334338
return self._get_response(feed_query, GtfsRTFeedImpl)
335339

api/src/feeds/impl/search_api_impl.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,11 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status, is
5252
status_list = [s.strip().lower() for s in status[0].split(",") if s]
5353
if status_list:
5454
query = query.where(t_feedsearch.c.status.in_([s.strip().lower() for s in status_list]))
55-
if is_official is not None and is_official:
56-
query = query.where(t_feedsearch.c.official == is_official)
55+
if is_official is not None:
56+
if is_official:
57+
query = query.where(t_feedsearch.c.official.is_(True))
58+
else:
59+
query = query.where(or_(t_feedsearch.c.official.is_(False), t_feedsearch.c.official.is_(None)))
5760
if search_query and len(search_query.strip()) > 0:
5861
query = query.filter(
5962
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))

api/src/shared/common/db_utils.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def get_gtfs_feeds_query(
3838
dataset_longitudes: str | None = None,
3939
bounding_filter_method: str | None = None,
4040
is_official: bool = False,
41-
include_wip: bool = False,
41+
published_only: bool = True,
4242
include_options_for_joinedload: bool = True,
4343
) -> Query[any]:
4444
"""Get the DB query to use to retrieve the GTFS feeds.."""
@@ -64,7 +64,7 @@ def get_gtfs_feeds_query(
6464
.filter(Gtfsfeed.id.in_(subquery))
6565
.filter((Gtfsdataset.latest) | (Gtfsdataset.id == None)) # noqa: E711
6666
)
67-
if not include_wip:
67+
if published_only:
6868
feed_query = feed_query.filter(Gtfsfeed.operational_status == "published")
6969

7070
if include_options_for_joinedload:
@@ -74,29 +74,28 @@ def get_gtfs_feeds_query(
7474
.joinedload(Validationreport.notices),
7575
*get_joinedload_options(),
7676
).order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
77-
if is_official:
78-
feed_query = feed_query.filter(Feed.official)
77+
feed_query = add_official_filter(feed_query, is_official)
7978
feed_query = feed_query.limit(limit).offset(offset)
8079
return feed_query
8180

8281

8382
def get_all_gtfs_feeds(
8483
db_session: Session,
85-
include_wip: bool = False,
84+
published_only: bool = True,
8685
batch_size: int = 250,
8786
) -> Iterator[Gtfsfeed]:
8887
"""
8988
Fetch all GTFS feeds.
9089
9190
@param db_session: The database session.
92-
@param include_wip: Whether to include or exclude WIP feeds.
91+
@param published_only: Include only the published feeds.
9392
@param batch_size: The number of feeds to fetch from the database at a time.
9493
A lower value means less memory but more queries.
9594
9695
@return: The GTFS feeds in an iterator.
9796
"""
9897
feed_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
99-
if not include_wip:
98+
if published_only:
10099
feed_query = feed_query.filter(Gtfsfeed.operational_status == "published")
101100

102101
for batch in batched(feed_query, batch_size):
@@ -126,7 +125,7 @@ def get_gtfs_rt_feeds_query(
126125
subdivision_name: str | None,
127126
municipality: str | None,
128127
is_official: bool | None,
129-
include_wip: bool = False,
128+
published_only: bool = True,
130129
db_session: Session = None,
131130
) -> Query:
132131
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
@@ -160,37 +159,49 @@ def get_gtfs_rt_feeds_query(
160159
).subquery()
161160
feed_query = db_session.query(Gtfsrealtimefeed).filter(Gtfsrealtimefeed.id.in_(subquery))
162161

163-
if not include_wip:
162+
if published_only:
164163
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published")
165164

166165
feed_query = feed_query.options(
167166
joinedload(Gtfsrealtimefeed.entitytypes),
168167
joinedload(Gtfsrealtimefeed.gtfs_feeds),
169168
*get_joinedload_options(),
170169
)
171-
if is_official:
172-
feed_query = feed_query.filter(Feed.official)
170+
feed_query = add_official_filter(feed_query, is_official)
171+
173172
feed_query = feed_query.limit(limit).offset(offset)
174173
return feed_query
175174

176175

176+
def add_official_filter(query: Query, is_official: bool | None) -> Query:
177+
"""
178+
Add the is_official filter to the query if necessary
179+
"""
180+
if is_official is not None:
181+
if is_official:
182+
query = query.filter(Feed.official.is_(True))
183+
else:
184+
query = query.filter(or_(Feed.official.is_(False), Feed.official.is_(None)))
185+
return query
186+
187+
177188
def get_all_gtfs_rt_feeds(
178189
db_session: Session,
179-
include_wip: bool = False,
190+
published_only: bool = True,
180191
batch_size: int = 250,
181192
) -> Iterator[Gtfsrealtimefeed]:
182193
"""
183194
Fetch all GTFS realtime feeds.
184195
185196
@param db_session: The database session.
186-
@param include_wip: Whether to include or exclude WIP feeds.
197+
@param published_only: Include only the published feeds.
187198
@param batch_size: The number of feeds to fetch from the database at a time.
188199
A lower value means less memory but more queries.
189200
190201
@return: The GTFS realtime feeds in an iterator.
191202
"""
192203
feed_query = db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
193-
if not include_wip:
204+
if published_only:
194205
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status == "published")
195206

196207
for batch in batched(feed_query, batch_size):
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
mdb_source_id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_offical,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
1+
mdb_source_id,data_type,entity_type,location.country_code,location.subdivision_name,location.municipality,provider,is_official,name,note,feed_contact_email,static_reference,urls.direct_download,urls.authentication_type,urls.authentication_info,urls.api_key_parameter_name,urls.latest,urls.license,location.bounding_box.minimum_latitude,location.bounding_box.maximum_latitude,location.bounding_box.minimum_longitude,location.bounding_box.maximum_longitude,location.bounding_box.extracted_on,status,features,redirect.id,redirect.comment
22
40,gtfs,,CA,Ontario,London,London Transit Commission,TRUE,,,[email protected],,http://www.londontransit.ca/gtfsfeed/google_transit.zip,0,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-ontario-london-transit-commission-gtfs-2.zip?alt=media,https://www.londontransit.ca/open-data/ltcs-open-data-terms-of-use/,42.905244,43.051188,-81.36311,-81.137591,2022-02-22T19:51:34+00:00,inactive,,,
33
50,gtfs,,CA,Ontario,Barrie,ZBarrie Transit,FALSE,,,,,http://www.myridebarrie.ca/gtfs/Google_transit.zip,,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-ontario-barrie-transit-gtfs-3.zip?alt=media,https://www.barrie.ca/services-payments/transportation-parking/barrie-transit/barrie-gtfs,44.3218044,44.42020676,-79.74063237,-79.61089569,2022-03-01T22:43:25+00:00,deprecated,,40|mdb-702,Some|Comment
44
702,gtfs,,CA,[British Columbia,Whistler],BC Transit (Whistler Transit System),,,,,,http://whistler.mapstrat.com/current/google_transit.zip,,,,https://storage.googleapis.com/storage/v1/b/mdb-latest/o/ca-british-columbia-bc-transit-whistler-transit-system-gtfs-702.zip?alt=media,https://www.bctransit.com/open-data/terms-of-use,50.077122,50.159071,-123.043635,-122.926836,2022-03-16T22:05:05+00:00,development,,,
5-
1562,gtfs-rt,sa,CA,BC,Vancouver,Vancouver-Transit(éèàçíóúČ),,Realtime(ŘŤÜÎ),,,40,http://foo.org/google_transit.zip,0,,,,,,,,,,active,,10,
6-
1563,gtfs-rt,tu,US,SomeState,SomeCity,SomeCity Bus,,RT,,,mdb-50,http://bar.com,0,,,,,,,,,,inactive,,10,
5+
1562,gtfs-rt,sa,CA,BC,Vancouver,Vancouver-Transit(éèàçíóúČ),True,Realtime(ŘŤÜÎ),,,40,http://foo.org/google_transit.zip,0,,,,,,,,,,active,,10,
6+
1563,gtfs-rt,tu,US,SomeState,SomeCity,SomeCity Bus,False,RT,,,mdb-50,http://bar.com,0,,,,,,,,,,inactive,,10,

api/tests/integration/test_feeds_api.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -247,22 +247,52 @@ def test_feeds_gtfs_rt_id_get(client: TestClient):
247247

248248

249249
@pytest.mark.parametrize(
250-
"endpoint",
250+
"values",
251251
[
252-
"/v1/gtfs_feeds",
253-
"/v1/gtfs_rt_feeds",
254-
"/v1/feeds",
252+
{
253+
"endpoint": "/v1/gtfs_feeds",
254+
"all_feeds_count": 10,
255+
"false_feeds_count": 9,
256+
"true_feeds_count": 1,
257+
"official_feeds": ["mdb-40"],
258+
},
259+
{
260+
"endpoint": "/v1/gtfs_rt_feeds",
261+
"all_feeds_count": 3,
262+
"false_feeds_count": 2,
263+
"true_feeds_count": 1,
264+
"official_feeds": ["mdb-1562"],
265+
},
266+
{
267+
"endpoint": "/v1/feeds",
268+
"all_feeds_count": 13,
269+
"false_feeds_count": 11,
270+
"true_feeds_count": 2,
271+
"official_feeds": ["mdb-1562", "mdb-40"],
272+
},
255273
],
274+
ids=["gtfs_feeds", "gtfs_rt_feeds", "feeds"],
256275
)
257-
def test_feeds_filter_by_official(client: TestClient, endpoint):
258-
# 1 - Test with official=false should return all feeds
276+
def test_feeds_filter_by_official(client: TestClient, values):
277+
endpoint = values["endpoint"]
278+
official_feeds = values["official_feeds"]
279+
all_feeds_count = values["all_feeds_count"]
280+
false_feeds_count = values["false_feeds_count"]
281+
true_feeds_count = values["true_feeds_count"]
282+
283+
# 1 - Test with official not specified should return all feeds
259284
response_no_filter = client.request(
260285
"GET",
261286
endpoint,
262287
headers=authHeaders,
263288
)
264289
assert response_no_filter.status_code == 200
265290
response_no_filter_json = response_no_filter.json()
291+
assert (
292+
len(response_no_filter_json) == all_feeds_count
293+
), f"official not specified should return {all_feeds_count} feeds but got {len(response_no_filter_json)}"
294+
295+
# 2 - Test with official=false
266296
response_official_false = client.request(
267297
"GET",
268298
endpoint,
@@ -271,8 +301,14 @@ def test_feeds_filter_by_official(client: TestClient, endpoint):
271301
)
272302
assert response_official_false.status_code == 200
273303
response_official_false_json = response_official_false.json()
274-
assert response_no_filter_json == response_official_false_json, "official=false parameter should return all feeds"
275-
# 2 - Test with official=true should return at least one feed
304+
assert (
305+
len(response_official_false_json) == false_feeds_count
306+
), f"official=false should return {false_feeds_count} feeds but got {len(response_official_false_json)}"
307+
assert not any(
308+
response["id"] in official_feeds for response in response_official_false_json
309+
), f"official=false expected no feed with stable_id {official_feeds} since it is official"
310+
311+
# 3 - Test with official=true
276312
response = client.request(
277313
"GET",
278314
endpoint,
@@ -281,7 +317,8 @@ def test_feeds_filter_by_official(client: TestClient, endpoint):
281317
)
282318
assert response.status_code == 200
283319
json_response = response.json()
284-
assert len(json_response) < len(response_no_filter_json), "Not all feeds are official"
320+
assert len(json_response) == true_feeds_count, f"official=true should return {true_feeds_count} feeds"
321+
assert json_response[0]["id"] in official_feeds, f"official=true should return {official_feeds}"
285322

286323

287324
def test_non_existent_gtfs_rt_feed_get(client: TestClient):

api/tests/integration/test_search_api.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -390,15 +390,29 @@ def test_search_feeds_filter_accents(client: TestClient, values: dict):
390390
assert all(result.id in values["expected_ids"] for result in response_body.results)
391391

392392

393-
def test_search_filter_by_official_status(client: TestClient):
393+
@pytest.mark.parametrize(
394+
"values",
395+
[
396+
{"official": True, "expected_count": 2},
397+
{"official": False, "expected_count": 11},
398+
{"official": None, "expected_count": 13},
399+
],
400+
ids=[
401+
"Official",
402+
"Not official",
403+
"Not specified",
404+
],
405+
)
406+
def test_search_filter_by_official_status(client: TestClient, values: dict):
394407
"""
395408
Retrieve feeds with the official status.
396409
"""
397-
params = [
398-
("limit", 100),
399-
("offset", 0),
400-
("is_official", "true"),
401-
]
410+
params = None
411+
if values["official"] is not None:
412+
params = [
413+
("is_official", str(values["official"]).lower()),
414+
]
415+
402416
headers = {
403417
"Authentication": "special-key",
404418
}
@@ -410,4 +424,7 @@ def test_search_filter_by_official_status(client: TestClient):
410424
)
411425
# Parse the response body into a Python object
412426
response_body = SearchFeeds200Response.parse_obj(response.json())
413-
assert response_body.total == 2, "There should be 2 official feeds in extra_test_data.json"
427+
expected_count = values["expected_count"]
428+
assert (
429+
response_body.total == expected_count
430+
), f"There should be {expected_count} feeds for official={values['official']}"

docs/DatabaseCatalogAPI.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,7 @@ components:
10581058
required: False
10591059
schema:
10601060
type: boolean
1061-
default: false
1061+
default: null
10621062

10631063
limit_query_param:
10641064
name: limit

functions-python/export_csv/src/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ def fetch_feeds() -> Iterator[Dict]:
145145
try:
146146
with db.start_db_session() as session:
147147
feed_count = 0
148-
for feed in get_all_gtfs_feeds(session, include_wip=False):
148+
for feed in get_all_gtfs_feeds(session, published_only=True):
149149
yield get_gtfs_feed_csv_data(feed)
150150
feed_count += 1
151151

152152
logging.info(f"Processed {feed_count} GTFS feeds.")
153153

154154
rt_feed_count = 0
155-
for feed in get_all_gtfs_rt_feeds(session, include_wip=False):
155+
for feed in get_all_gtfs_rt_feeds(session, published_only=True):
156156
yield get_gtfs_rt_feed_csv_data(feed)
157157
rt_feed_count += 1
158158

web-app/src/app/screens/Feeds/index.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ export default function Feed(): React.ReactElement {
9797
offset: paginationOffset,
9898
search_query: activeSearch,
9999
data_type: getDataTypeParamFromSelectedFeedTypes(selectedFeedTypes),
100-
is_official: isOfficialFeedSearch,
100+
is_official: isOfficialFeedSearch || undefined,
101101
// Fixed status values for now, until a status filter is implemented
102102
// Filtering out deprecated feeds
103103
status: ['active', 'inactive', 'development'],

0 commit comments

Comments
 (0)