17 changes: 12 additions & 5 deletions api/src/feeds/impl/feeds_api_impl.py
@@ -1,6 +1,7 @@
from datetime import datetime
from typing import List, Union, TypeVar

from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.query import Query
@@ -39,7 +40,6 @@
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from middleware.request_context import is_user_email_restricted
from sqlalchemy import or_
from utils.date_utils import valid_iso_date
from utils.location_translation import (
create_location_translation_object,
@@ -89,12 +89,15 @@ def get_feeds(
status: str,
provider: str,
producer_url: str,
is_official: bool,
) -> List[BasicFeed]:
"""Get some (or all) feeds from the Mobility Database."""
feed_filter = FeedFilter(
status=status, provider__ilike=provider, producer_url__ilike=producer_url, stable_id=None
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
feed_query = feed_query.filter(
or_(
@@ -221,6 +224,7 @@ def get_gtfs_feeds(
dataset_latitudes: str,
dataset_longitudes: str,
bounding_filter_method: str,
is_official: bool,
) -> List[GtfsFeed]:
"""Get some (or all) GTFS feeds from the Mobility Database."""
gtfs_feed_filter = GtfsFeedFilter(
@@ -258,9 +262,10 @@ def get_gtfs_feeds(
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsfeed.provider, Gtfsfeed.stable_id)
.limit(limit)
.offset(offset)
)
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsFeedImpl)

def get_gtfs_rt_feed(
@@ -311,6 +316,7 @@ def get_gtfs_rt_feeds(
country_code: str,
subdivision_name: str,
municipality: str,
is_official: bool,
) -> List[GtfsRTFeed]:
"""Get some (or all) GTFS Realtime feeds from the Mobility Database."""
entity_types_list = entity_types.split(",") if entity_types else None
@@ -359,9 +365,10 @@
*BasicFeedImpl.get_joinedload_options(),
)
.order_by(Gtfsrealtimefeed.provider, Gtfsrealtimefeed.stable_id)
.limit(limit)
.offset(offset)
)
if is_official:
feed_query = feed_query.filter(Feed.official)
feed_query = feed_query.limit(limit).offset(offset)
return self._get_response(feed_query, GtfsRTFeedImpl)

@staticmethod
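Note (not part of the diff): a minimal client-side sketch of the new filter. The base URL and the auth header are placeholders for illustration; the endpoints and the is_official query parameter are the ones exercised by the integration tests below.

import requests

BASE_URL = "https://example.org"  # placeholder deployment URL
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder credentials

# Only feeds whose latest official-status record is True are returned.
official = requests.get(
    f"{BASE_URL}/v1/gtfs_feeds",
    params={"is_official": "true", "limit": 10, "offset": 0},
    headers=HEADERS,
    timeout=30,
)
official.raise_for_status()

# Omitting the parameter (or sending is_official=false) leaves the result set
# unfiltered, which is what test_feeds_filter_by_official asserts.
everything = requests.get(f"{BASE_URL}/v1/gtfs_feeds", headers=HEADERS, timeout=30)
print(len(official.json()), len(everything.json()))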
11 changes: 10 additions & 1 deletion api/src/feeds/impl/models/basic_feed_impl.py
@@ -23,10 +23,14 @@ class Config:
def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
if not feed:
return None
latest_official_status = None
if len(feed.officialstatushistories) > 0:
latest_official_status = max(feed.officialstatushistories, key=lambda x: x.timestamp).is_official
return cls(
id=feed.stable_id,
data_type=feed.data_type,
status=feed.status,
official=latest_official_status,
created_at=feed.created_at,
external_ids=sorted(
[ExternalIdImpl.from_orm(item) for item in feed.externalids], key=lambda x: x.external_id
@@ -48,7 +52,12 @@ def from_orm(cls, feed: Feed | None, _=None) -> BasicFeed | None:
@staticmethod
def get_joinedload_options() -> [_AbstractLoad]:
"""Returns common joinedload options for feeds queries."""
return [joinedload(Feed.locations), joinedload(Feed.externalids), joinedload(Feed.redirectingids)]
return [
joinedload(Feed.locations),
joinedload(Feed.externalids),
joinedload(Feed.redirectingids),
joinedload(Feed.officialstatushistories),
]


class BasicFeedImpl(BaseFeedImpl, BasicFeed):
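Note (not part of the diff): from_orm now derives the feed's official field from the most recent Officialstatushistory row. A standalone sketch of the same selection rule, using a hypothetical stand-in class instead of the ORM model:

from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional


@dataclass
class StatusRecord:  # stand-in for an Officialstatushistory row
    is_official: bool
    timestamp: datetime


def latest_official_status(history: List[StatusRecord]) -> Optional[bool]:
    # No history means the feed was never reviewed, so the field stays None.
    if not history:
        return None
    # The most recent record wins, regardless of earlier values.
    return max(history, key=lambda record: record.timestamp).is_official


records = [
    StatusRecord(is_official=False, timestamp=datetime(2024, 1, 1)),
    StatusRecord(is_official=True, timestamp=datetime(2024, 3, 1)),
]
assert latest_official_status(records) is True
assert latest_official_status([]) is None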
2 changes: 2 additions & 0 deletions api/src/feeds/impl/models/search_feed_item_result_impl.py
@@ -25,6 +25,7 @@ def from_orm_gtfs(cls, feed_search_row):
external_ids=feed_search_row.external_ids,
provider=feed_search_row.provider,
feed_name=feed_search_row.feed_name,
official=feed_search_row.official,
note=feed_search_row.note,
feed_contact_email=feed_search_row.feed_contact_email,
source_info=SourceInfo(
@@ -58,6 +59,7 @@ def from_orm_gtfs_rt(cls, feed_search_row):
external_ids=feed_search_row.external_ids,
provider=feed_search_row.provider,
feed_name=feed_search_row.feed_name,
official=feed_search_row.official,
note=feed_search_row.note,
feed_contact_email=feed_search_row.feed_contact_email,
source_info=SourceInfo(
25 changes: 18 additions & 7 deletions api/src/feeds/impl/search_api_impl.py
@@ -31,7 +31,7 @@ def get_parsed_search_tsquery(search_query: str) -> str:
return func.plainto_tsquery("english", unaccent(parsed_query))

@staticmethod
def add_search_query_filters(query, search_query, data_type, feed_id, status) -> Query:
def add_search_query_filters(query, search_query, data_type, feed_id, status, is_official) -> Query:
"""
Add filters to the search query.
Filter values are trimmed and converted to lowercase.
@@ -53,22 +53,32 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
status_list = [s.strip().lower() for s in status[0].split(",") if s]
if status_list:
query = query.where(t_feedsearch.c.status.in_([s.strip().lower() for s in status_list]))
if is_official is not None and is_official:
query = query.where(t_feedsearch.c.official == is_official)
if search_query and len(search_query.strip()) > 0:
query = query.filter(
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))
)
return query

@staticmethod
def create_count_search_query(status: List[str], feed_id: str, data_type: str, search_query: str) -> Query:
def create_count_search_query(
status: List[str],
feed_id: str,
data_type: str,
is_official: bool,
search_query: str,
) -> Query:
"""
Create a search query for the database.
"""
query = select(func.count(t_feedsearch.c.feed_id))
return SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status)
return SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)

@staticmethod
def create_search_query(status: List[str], feed_id: str, data_type: str, search_query: str) -> Query:
def create_search_query(
status: List[str], feed_id: str, data_type: str, is_official: bool, search_query: str
) -> Query:
"""
Create a search query for the database.
"""
@@ -80,7 +90,7 @@ def create_search_query(status: List[str], feed_id: str, data_type: str, search_
rank_expression,
*feed_search_columns,
)
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status)
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
return query.order_by(rank_expression.desc())

def search_feeds(
@@ -90,17 +100,18 @@
status: List[str],
feed_id: str,
data_type: str,
is_official: bool,
search_query: str,
) -> SearchFeeds200Response:
"""Search feeds using full-text search on feed, location and provider's information."""
query = self.create_search_query(status, feed_id, data_type, search_query)
query = self.create_search_query(status, feed_id, data_type, is_official, search_query)
feed_rows = Database().select(
query=query,
limit=limit,
offset=offset,
)
feed_total_count = Database().select(
query=self.create_count_search_query(status, feed_id, data_type, search_query),
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query),
)
if feed_rows is None or feed_total_count is None:
return SearchFeeds200Response(
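Note (not part of the diff): the is_official argument only narrows the search when it is truthy; None and false both leave the query untouched. A self-contained SQLAlchemy sketch of that conditional filter (the table and columns here are invented for illustration, not the real t_feedsearch definition):

from typing import Optional

from sqlalchemy import Boolean, Column, MetaData, String, Table, select

metadata = MetaData()
# Hypothetical stand-in for the t_feedsearch materialized view.
feeds = Table(
    "feeds",
    metadata,
    Column("feed_id", String),
    Column("official", Boolean),
)


def build_query(is_official: Optional[bool]):
    query = select(feeds.c.feed_id)
    # Mirrors add_search_query_filters: None and False both skip the clause,
    # so only is_official=True actually restricts the results.
    if is_official is not None and is_official:
        query = query.where(feeds.c.official == is_official)
    return query


print(build_query(True))   # SELECT ... WHERE feeds.official = :official_1
print(build_query(False))  # SELECT ... with no official filter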
9 changes: 9 additions & 0 deletions api/src/scripts/populate_db_test_data.py
@@ -14,6 +14,7 @@
Feature,
t_feedsearch,
Location,
Officialstatushistory,
)
from scripts.populate_db import set_up_configs, DatabasePopulateHelper
from utils.logger import Logger
@@ -172,6 +173,14 @@ def populate_test_feeds(self, feeds_data):
)
locations.append(location)
feed.locations = locations
if "official" in feed_data:
official_status_history = Officialstatushistory(
feed_id=feed.id,
is_official=feed_data["official"],
reviewer_email="[email protected]",
timestamp=feed_data["created_at"],
)
feed.officialstatushistories.append(official_status_history)
self.db.session.add(feed)
logger.info(f"Added feed {feed.stable_id}")

3 changes: 3 additions & 0 deletions api/tests/integration/test_data/extra_test_data.json
@@ -6,6 +6,7 @@
"status": "active",
"created_at": "2024-02-08T00:00:00Z",
"provider": "BlaBlaCar Bus",
"official": true,
"feed_name": "",
"note": "",
"feed_contact_email": "",
@@ -43,6 +44,7 @@
"status": "active",
"created_at": "2024-02-08T00:00:00Z",
"provider": "BlaBlaCar Bus",
"official": false,
"feed_name": "",
"note": "",
"feed_contact_email": "",
@@ -80,6 +82,7 @@
"status": "active",
"created_at": "2024-02-08T00:00:00Z",
"provider": "BlaBlaCar Bus",
"official": true,
"feed_name": "",
"note": "",
"feed_contact_email": "",
2 changes: 1 addition & 1 deletion api/tests/integration/test_database.py
@@ -100,7 +100,7 @@ def test_bounding_box_disjoint(latitudes, longitudes, method, expected_found, te
def test_merge_gtfs_feed(test_database):
results = {
feed.id: feed
for feed in FeedsApiImpl().get_gtfs_feeds(None, None, None, None, None, None, None, None, None, None)
for feed in FeedsApiImpl().get_gtfs_feeds(None, None, None, None, None, None, None, None, None, None, None)
if feed.id in TEST_GTFS_FEED_STABLE_IDS
}
assert len(results) == len(TEST_GTFS_FEED_STABLE_IDS)
38 changes: 38 additions & 0 deletions api/tests/integration/test_feeds_api.py
@@ -246,6 +246,44 @@ def test_feeds_gtfs_rt_id_get(client: TestClient):
assert response.status_code == 200


@pytest.mark.parametrize(
"endpoint",
[
"/v1/gtfs_feeds",
"/v1/gtfs_rt_feeds",
"/v1/feeds",
],
)
def test_feeds_filter_by_official(client: TestClient, endpoint):
# 1 - Test with official=false should return all feeds
response_no_filter = client.request(
"GET",
endpoint,
headers=authHeaders,
)
assert response_no_filter.status_code == 200
response_no_filter_json = response_no_filter.json()
response_official_false = client.request(
"GET",
endpoint,
headers=authHeaders,
params=[("is_official", "false")],
)
assert response_official_false.status_code == 200
response_official_false_json = response_official_false.json()
assert response_no_filter_json == response_official_false_json, "official=false parameter should return all feeds"
# 2 - Test with official=true should return at least one feed
response = client.request(
"GET",
endpoint,
headers=authHeaders,
params=[("is_official", "true")],
)
assert response.status_code == 200
json_response = response.json()
assert len(json_response) < len(response_no_filter_json), "Not all feeds are official"


def test_non_existent_gtfs_rt_feed_get(client: TestClient):
"""Test case for feeds_gtfs_rt_id_get with a non-existent feed"""
response = client.request(
23 changes: 23 additions & 0 deletions api/tests/integration/test_search_api.py
@@ -388,3 +388,26 @@ def test_search_feeds_filter_accents(client: TestClient, values: dict):
assert len(response_body.results) == len(values["expected_ids"])
assert response_body.total == len(values["expected_ids"])
assert all(result.id in values["expected_ids"] for result in response_body.results)


def test_search_filter_by_official_status(client: TestClient):
"""
Retrieve feeds with the official status.
"""
params = [
("limit", 100),
("offset", 0),
("is_official", "true"),
]
headers = {
"Authentication": "special-key",
}
response = client.request(
"GET",
"/v1/search",
headers=headers,
params=params,
)
# Parse the response body into a Python object
response_body = SearchFeeds200Response.parse_obj(response.json())
assert response_body.total == 2, "There should be 2 official feeds in extra_test_data.json"
@@ -23,6 +23,7 @@ def __init__(self, **kwargs):
data_type="gtfs",
status="active",
feed_name="feed_name",
official=None,
note="note",
feed_contact_email="feed_contact_email",
producer_url="producer_url",
Expand Down