Skip to content

Commit a66d94a

Browse files
authored
feat: added versions filtering to search (#1162)
1 parent f3e9508 commit a66d94a

File tree

7 files changed

+336
-9
lines changed

7 files changed

+336
-9
lines changed

api/src/feeds/impl/models/search_feed_item_result_impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def from_orm_gbfs(cls, feed_search_row):
8080
external_ids=feed_search_row.external_ids,
8181
provider=feed_search_row.provider,
8282
feed_contact_email=feed_search_row.feed_contact_email,
83+
versions=feed_search_row.versions,
8384
source_info=SourceInfo(
8485
producer_url=feed_search_row.producer_url,
8586
authentication_type=(

api/src/feeds/impl/search_api_impl.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from sqlalchemy import func, select
44
from sqlalchemy.orm import Query, Session
5-
5+
from sqlalchemy.dialects.postgresql import array
66
from shared.database.database import Database, with_db_session
77
from shared.database.sql_functions.unaccent import unaccent
88
from shared.database_gen.sqlacodegen_models import t_feedsearch
@@ -31,7 +31,9 @@ def get_parsed_search_tsquery(search_query: str) -> str:
3131
return func.plainto_tsquery("english", unaccent(parsed_query))
3232

3333
@staticmethod
34-
def add_search_query_filters(query, search_query, data_type, feed_id, status, is_official, features) -> Query:
34+
def add_search_query_filters(
35+
query, search_query, data_type, feed_id, status, is_official, features, version
36+
) -> Query:
3537
"""
3638
Add filters to the search query.
3739
Filter values are trimmed and converted to lowercase.
@@ -58,6 +60,10 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status, is
5860
query = query.where(t_feedsearch.c.official.is_(True))
5961
else:
6062
query = query.where(or_(t_feedsearch.c.official.is_(False), t_feedsearch.c.official.is_(None)))
63+
if version:
64+
versions_list = [v.strip().lower() for v in version.split(",") if v]
65+
if versions_list:
66+
query = query.where(t_feedsearch.c.versions.op("?|")(array(versions_list)))
6167
if search_query and len(search_query.strip()) > 0:
6268
query = query.filter(
6369
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))
@@ -77,15 +83,16 @@ def create_count_search_query(
7783
feed_id: str,
7884
data_type: str,
7985
is_official: bool,
80-
search_query: str,
8186
features,
87+
version: str,
88+
search_query: str,
8289
) -> Query:
8390
"""
8491
Create a search query for the database.
8592
"""
8693
query = select(func.count(t_feedsearch.c.feed_id))
8794
return SearchApiImpl.add_search_query_filters(
88-
query, search_query, data_type, feed_id, status, is_official, features
95+
query, search_query, data_type, feed_id, status, is_official, features, version
8996
)
9097

9198
@staticmethod
@@ -96,6 +103,7 @@ def create_search_query(
96103
is_official: bool,
97104
search_query: str,
98105
features: List[str],
106+
version: str,
99107
) -> Query:
100108
"""
101109
Create a search query for the database.
@@ -109,7 +117,7 @@ def create_search_query(
109117
*feed_search_columns,
110118
)
111119
query = SearchApiImpl.add_search_query_filters(
112-
query, search_query, data_type, feed_id, status, is_official, features
120+
query, search_query, data_type, feed_id, status, is_official, features, version
113121
)
114122
return query.order_by(rank_expression.desc())
115123

@@ -122,12 +130,13 @@ def search_feeds(
122130
feed_id: str,
123131
data_type: str,
124132
is_official: bool,
133+
version: str,
125134
search_query: str,
126-
features: List[str],
135+
feature: List[str],
127136
db_session: "Session",
128137
) -> SearchFeeds200Response:
129138
"""Search feeds using full-text search on feed, location and provider's information."""
130-
query = self.create_search_query(status, feed_id, data_type, is_official, search_query, features)
139+
query = self.create_search_query(status, feed_id, data_type, is_official, search_query, feature, version)
131140
feed_rows = Database().select(
132141
session=db_session,
133142
query=query,
@@ -136,7 +145,9 @@ def search_feeds(
136145
)
137146
feed_total_count = Database().select(
138147
session=db_session,
139-
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query, features),
148+
query=self.create_count_search_query(
149+
status, feed_id, data_type, is_official, feature, version, search_query
150+
),
140151
)
141152
if feed_rows is None or feed_total_count is None:
142153
return SearchFeeds200Response(

api/tests/integration/test_search_api.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,50 @@ def test_search_filter_by_official_status(client: TestClient, values: dict):
426426
), f"There should be {expected_count} feeds for official={values['official']}"
427427

428428

429+
@pytest.mark.parametrize(
430+
"values",
431+
[
432+
{"versions": "1.0", "expected_count": 0},
433+
{"versions": "2.3,3.0", "expected_count": 2},
434+
{"versions": "3.0", "expected_count": 1},
435+
{"versions": "2.3", "expected_count": 2},
436+
{"versions": None, "expected_count": 16},
437+
],
438+
ids=[
439+
"Version 1.0",
440+
"Versions 2.3 and 3.0",
441+
"Version 3.0",
442+
"Version 2.3",
443+
"No version specified",
444+
],
445+
)
446+
def test_search_filter_by_versions(client: TestClient, values: dict):
447+
"""
448+
Retrieve feeds with the version.
449+
"""
450+
params = None
451+
if values["versions"] is not None:
452+
params = [
453+
("version", values["versions"]),
454+
]
455+
456+
headers = {
457+
"Authentication": "special-key",
458+
}
459+
response = client.request(
460+
"GET",
461+
"/v1/search",
462+
headers=headers,
463+
params=params,
464+
)
465+
# Parse the response body into a Python object
466+
response_body = SearchFeeds200Response.parse_obj(response.json())
467+
expected_count = values["expected_count"]
468+
assert (
469+
response_body.total == expected_count
470+
), f"There should be {expected_count} feeds for versions={values['versions']}"
471+
472+
429473
@pytest.mark.parametrize(
430474
"values",
431475
[

docs/DatabaseCatalogAPI.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ paths:
334334
- $ref: "#/components/parameters/feed_id_query_param"
335335
- $ref: "#/components/parameters/data_type_query_param"
336336
- $ref: "#/components/parameters/is_official_query_param"
337+
- $ref: "#/components/parameters/version_query_param"
337338
- $ref: "#/components/parameters/search_text_query_param"
338339
- $ref: "#/components/parameters/feature"
339340
security:
@@ -713,6 +714,12 @@ components:
713714
* sa - service alerts
714715
# Have to put the enum inline because of a bug in openapi-generator
715716
# $ref: "#/components/schemas/EntityTypes"
717+
versions:
718+
type: array
719+
items:
720+
type: string
721+
example: 2.3
722+
description: The supported versions of the GBFS feed.
716723
feed_references:
717724
description:
718725
A list of the GTFS feeds that the real time source is associated with, represented by their MDB source IDs.
@@ -1335,6 +1342,15 @@ components:
13351342
schema:
13361343
type: string
13371344

1345+
version_query_param:
1346+
name: version
1347+
in: query
1348+
description: Comma separated list of GBFS versions to filter by.
1349+
required: False
1350+
schema:
1351+
type: string
1352+
example: 2.0,2.1
1353+
13381354
data_type_query_param:
13391355
name: data_type
13401356
in: query

liquibase/changelog.xml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@
4848
<include file="changes/feat_1055.sql" relativeToChangelogFile="true"/>
4949
<include file="changes/feat_1041.sql" relativeToChangelogFile="true"/>
5050
<include file="changes/feat_997.sql" relativeToChangelogFile="true"/>
51-
<!-- Materialized view updated. Added features and totals. -->
51+
<!-- Materialized view updated. Added features and totals. -->
5252
<include file="changes/feat_993.sql" relativeToChangelogFile="true"/>
5353
<!-- Materialized view updated. Used Feed.official field as official status. -->
5454
<include file="changes/feat_1083.sql" relativeToChangelogFile="true"/>
5555
<include file="changes/feat_1132.sql" relativeToChangelogFile="true"/>
5656
<include file="changes/feat_1124.sql" relativeToChangelogFile="true"/>
57+
<!-- Materialized view updated. Added versions of GBFS feeds-->
58+
<include file="changes/feat_1118.sql" relativeToChangelogFile="true"/>
5759
</databaseChangeLog>

0 commit comments

Comments
 (0)