Skip to content

Commit 814887b

Browse files
authored
feat: added official value in endpoint params and responses (#833)
1 parent 73ea4bc commit 814887b

File tree

15 files changed

+336
-14
lines changed

15 files changed

+336
-14
lines changed

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime
22
from typing import List, Union, TypeVar
33

4+
from sqlalchemy import or_
45
from sqlalchemy import select
56
from sqlalchemy.orm import joinedload
67
from sqlalchemy.orm.query import Query
@@ -39,7 +40,6 @@
3940
from feeds_gen.models.gtfs_feed import GtfsFeed
4041
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
4142
from middleware.request_context import is_user_email_restricted
42-
from sqlalchemy import or_
4343
from utils.date_utils import valid_iso_date
4444
from utils.location_translation import (
4545
create_location_translation_object,
@@ -96,6 +96,7 @@ def get_feeds(
9696
status: str,
9797
provider: str,
9898
producer_url: str,
99+
is_official: bool,
99100
) -> List[BasicFeed]:
100101
"""Get some (or all) feeds from the Mobility Database."""
101102
is_email_restricted = is_user_email_restricted()
@@ -104,6 +105,8 @@ def get_feeds(
104105
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
105106
)
106107
feed_query = feed_filter.filter(Database().get_query_model(Feed))
108+
if is_official:
109+
feed_query = feed_query.filter(Feed.official)
107110
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
108111
feed_query = feed_query.filter(
109112
or_(
@@ -230,6 +233,7 @@ def get_gtfs_feeds(
230233
dataset_latitudes: str,
231234
dataset_longitudes: str,
232235
bounding_filter_method: str,
236+
is_official: bool,
233237
) -> List[GtfsFeed]:
234238
"""Get some (or all) GTFS feeds from the Mobility Database."""
235239
gtfs_feed_filter = GtfsFeedFilter(
@@ -269,9 +273,10 @@ def get_gtfs_feeds(
269273
*BasicFeedImpl.get_joinedload_options(),
270274
)
271275
.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
272-
.limit(limit)
273-
.offset(offset)
274276
)
277+
if is_official:
278+
feed_query = feed_query.filter(Feed.official)
279+
feed_query = feed_query.limit(limit).offset(offset)
275280
return self._get_response(feed_query, GtfsFeedImpl)
276281

277282
def get_gtfs_rt_feed(
@@ -322,6 +327,7 @@ def get_gtfs_rt_feeds(
322327
country_code: str,
323328
subdivision_name: str,
324329
municipality: str,
330+
is_official: bool,
325331
) -> List[GtfsRTFeed]:
326332
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
327333
entity_types_list = entity_types.split(",") if entity_types else None
@@ -370,9 +376,10 @@ def get_gtfs_rt_feeds(
370376
*BasicFeedImpl.get_joinedload_options(),
371377
)
372378
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
373-
.limit(limit)
374-
.offset(offset)
375379
)
380+
if is_official:
381+
feed_query = feed_query.filter(Feed.official)
382+
feed_query = feed_query.limit(limit).offset(offset)
376383
return self._get_response(feed_query, GtfsRTFeedImpl)
377384

378385
@staticmethod

api/src/feeds/impl/models/basic_feed_impl.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@ class Config:
2323
def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
2424
if not feed:
2525
return None
26+
latest_official_status = None
27+
if len(feed.officialstatushistories) > 0:
28+
latest_official_status = max(feed.officialstatushistories, key=lambda x: x.timestamp).is_official
2629
return cls(
2730
id=feed.stable_id,
2831
data_type=feed.data_type,
2932
status=feed.status,
33+
official=latest_official_status,
3034
created_at=feed.created_at,
3135
external_ids=sorted(
3236
[ExternalIdImpl.from_orm(item) for item in feed.externalids], key=lambda x: x.external_id
@@ -48,7 +52,12 @@ def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
4852
@staticmethod
4953
def get_joinedload_options() -> [_AbstractLoad]:
5054
"""Returns common joinedload options for feeds queries."""
51-
return [joinedload(Feed.locations), joinedload(Feed.externalids), joinedload(Feed.redirectingids)]
55+
return [
56+
joinedload(Feed.locations),
57+
joinedload(Feed.externalids),
58+
joinedload(Feed.redirectingids),
59+
joinedload(Feed.officialstatushistories),
60+
]
5261

5362

5463
class BasicFeedImpl(BaseFeedImpl, BasicFeed):

api/src/feeds/impl/models/search_feed_item_result_impl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def from_orm_gtfs(cls, feed_search_row):
2525
external_ids=feed_search_row.external_ids,
2626
provider=feed_search_row.provider,
2727
feed_name=feed_search_row.feed_name,
28+
official=feed_search_row.official,
2829
note=feed_search_row.note,
2930
feed_contact_email=feed_search_row.feed_contact_email,
3031
source_info=SourceInfo(
@@ -58,6 +59,7 @@ def from_orm_gtfs_rt(cls, feed_search_row):
5859
external_ids=feed_search_row.external_ids,
5960
provider=feed_search_row.provider,
6061
feed_name=feed_search_row.feed_name,
62+
official=feed_search_row.official,
6163
note=feed_search_row.note,
6264
feed_contact_email=feed_search_row.feed_contact_email,
6365
source_info=SourceInfo(

api/src/feeds/impl/search_api_impl.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_parsed_search_tsquery(search_query: str) -> str:
3131
return func.plainto_tsquery("english", unaccent(parsed_query))
3232

3333
@staticmethod
34-
def add_search_query_filters(query, search_query, data_type, feed_id, status) -> Query:
34+
def add_search_query_filters(query, search_query, data_type, feed_id, status, is_official) -> Query:
3535
"""
3636
Add filters to the search query.
3737
Filter values are trimmed and converted to lowercase.
@@ -53,22 +53,32 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
5353
status_list = [s.strip().lower() for s in status[0].split(",") if s]
5454
if status_list:
5555
query = query.where(t_feedsearch.c.status.in_([s.strip().lower() for s in status_list]))
56+
if is_official is not None and is_official:
57+
query = query.where(t_feedsearch.c.official == is_official)
5658
if search_query and len(search_query.strip()) > 0:
5759
query = query.filter(
5860
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))
5961
)
6062
return query
6163

6264
@staticmethod
63-
def create_count_search_query(status: List[str], feed_id: str, data_type: str, search_query: str) -> Query:
65+
def create_count_search_query(
66+
status: List[str],
67+
feed_id: str,
68+
data_type: str,
69+
is_official: bool,
70+
search_query: str,
71+
) -> Query:
6472
"""
6573
Create a search query for the database.
6674
"""
6775
query = select(func.count(t_feedsearch.c.feed_id))
68-
return SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status)
76+
return SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
6977

7078
@staticmethod
71-
def create_search_query(status: List[str], feed_id: str, data_type: str, search_query: str) -> Query:
79+
def create_search_query(
80+
status: List[str], feed_id: str, data_type: str, is_official: bool, search_query: str
81+
) -> Query:
7282
"""
7383
Create a search query for the database.
7484
"""
@@ -80,7 +90,7 @@ def create_search_query(status: List[str], feed_id: str, data_type: str, search_
8090
rank_expression,
8191
*feed_search_columns,
8292
)
83-
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status)
93+
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
8494
return query.order_by(rank_expression.desc())
8595

8696
def search_feeds(
@@ -90,17 +100,18 @@ def search_feeds(
90100
status: List[str],
91101
feed_id: str,
92102
data_type: str,
103+
is_official: bool,
93104
search_query: str,
94105
) -> SearchFeeds200Response:
95106
"""Search feeds using full-text search on feed, location and provider's information."""
96-
query = self.create_search_query(status, feed_id, data_type, search_query)
107+
query = self.create_search_query(status, feed_id, data_type, is_official, search_query)
97108
feed_rows = Database().select(
98109
query=query,
99110
limit=limit,
100111
offset=offset,
101112
)
102113
feed_total_count = Database().select(
103-
query=self.create_count_search_query(status, feed_id, data_type, search_query),
114+
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query),
104115
)
105116
if feed_rows is None or feed_total_count is None:
106117
return SearchFeeds200Response(

api/src/scripts/populate_db_test_data.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Feature,
1515
t_feedsearch,
1616
Location,
17+
Officialstatushistory,
1718
)
1819
from scripts.populate_db import set_up_configs, DatabasePopulateHelper
1920
from utils.logger import Logger
@@ -172,6 +173,14 @@ def populate_test_feeds(self, feeds_data):
172173
)
173174
locations.append(location)
174175
feed.locations = locations
176+
if "official" in feed_data:
177+
official_status_history = Officialstatushistory(
178+
feed_id=feed.id,
179+
is_official=feed_data["official"],
180+
reviewer_email="[email protected]",
181+
timestamp=feed_data["created_at"],
182+
)
183+
feed.officialstatushistories.append(official_status_history)
175184
self.db.session.add(feed)
176185
logger.info(f"Added feed {feed.stable_id}")
177186

api/tests/integration/test_data/extra_test_data.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"status": "active",
77
"created_at": "2024-02-08T00:00:00Z",
88
"provider": "BlaBlaCar Bus",
9+
"official": true,
910
"feed_name": "",
1011
"note": "",
1112
"feed_contact_email": "",
@@ -43,6 +44,7 @@
4344
"status": "active",
4445
"created_at": "2024-02-08T00:00:00Z",
4546
"provider": "BlaBlaCar Bus",
47+
"official": false,
4648
"feed_name": "",
4749
"note": "",
4850
"feed_contact_email": "",
@@ -80,6 +82,7 @@
8082
"status": "active",
8183
"created_at": "2024-02-08T00:00:00Z",
8284
"provider": "BlaBlaCar Bus",
85+
"official": true,
8386
"feed_name": "",
8487
"note": "",
8588
"feed_contact_email": "",

api/tests/integration/test_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def test_bounding_box_disjoint(latitudes, longitudes, method, expected_found, te
100100
def test_merge_gtfs_feed(test_database):
101101
results = {
102102
feed.id: feed
103-
for feed in FeedsApiImpl().get_gtfs_feeds(None, None, None, None, None, None, None, None, None, None)
103+
for feed in FeedsApiImpl().get_gtfs_feeds(None, None, None, None, None, None, None, None, None, None, None)
104104
if feed.id in TEST_GTFS_FEED_STABLE_IDS
105105
}
106106
assert len(results) == len(TEST_GTFS_FEED_STABLE_IDS)

api/tests/integration/test_feeds_api.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,44 @@ def test_feeds_gtfs_rt_id_get(client: TestClient):
246246
assert response.status_code == 200
247247

248248

249+
@pytest.mark.parametrize(
250+
"endpoint",
251+
[
252+
"/v1/gtfs_feeds",
253+
"/v1/gtfs_rt_feeds",
254+
"/v1/feeds",
255+
],
256+
)
257+
def test_feeds_filter_by_official(client: TestClient, endpoint):
258+
# 1 - Test with official=false should return all feeds
259+
response_no_filter = client.request(
260+
"GET",
261+
endpoint,
262+
headers=authHeaders,
263+
)
264+
assert response_no_filter.status_code == 200
265+
response_no_filter_json = response_no_filter.json()
266+
response_official_false = client.request(
267+
"GET",
268+
endpoint,
269+
headers=authHeaders,
270+
params=[("is_official", "false")],
271+
)
272+
assert response_official_false.status_code == 200
273+
response_official_false_json = response_official_false.json()
274+
assert response_no_filter_json == response_official_false_json, "official=false parameter should return all feeds"
275+
# 2 - Test with official=true should return at least one feed
276+
response = client.request(
277+
"GET",
278+
endpoint,
279+
headers=authHeaders,
280+
params=[("is_official", "true")],
281+
)
282+
assert response.status_code == 200
283+
json_response = response.json()
284+
assert len(json_response) < len(response_no_filter_json), "Not all feeds are official"
285+
286+
249287
def test_non_existent_gtfs_rt_feed_get(client: TestClient):
250288
"""Test case for feeds_gtfs_rt_id_get with a non-existent feed"""
251289
response = client.request(

api/tests/integration/test_search_api.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,3 +388,26 @@ def test_search_feeds_filter_accents(client: TestClient, values: dict):
388388
assert len(response_body.results) == len(values["expected_ids"])
389389
assert response_body.total == len(values["expected_ids"])
390390
assert all(result.id in values["expected_ids"] for result in response_body.results)
391+
392+
393+
def test_search_filter_by_official_status(client: TestClient):
394+
"""
395+
Retrieve feeds with the official status.
396+
"""
397+
params = [
398+
("limit", 100),
399+
("offset", 0),
400+
("is_official", "true"),
401+
]
402+
headers = {
403+
"Authentication": "special-key",
404+
}
405+
response = client.request(
406+
"GET",
407+
"/v1/search",
408+
headers=headers,
409+
params=params,
410+
)
411+
# Parse the response body into a Python object
412+
response_body = SearchFeeds200Response.parse_obj(response.json())
413+
assert response_body.total == 2, "There should be 2 official feeds in extra_test_data.json"

api/tests/unittest/models/test_search_feed_item_result_impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(self, **kwargs):
2323
data_type="gtfs",
2424
status="active",
2525
feed_name="feed_name",
26+
official=None,
2627
note="note",
2728
feed_contact_email="feed_contact_email",
2829
producer_url="producer_url",

0 commit comments

Comments
 (0)