Skip to content

Commit 88940c8

Browse files
authored
Merge pull request #1148 from MobilityData/1105-implement-or-filter-for-feature-search
feat: implemented OR filter for feature search
2 parents 44cf418 + 667313e commit 88940c8

File tree

7 files changed

+134
-42
lines changed

7 files changed

+134
-42
lines changed

.github/workflows/build-test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,6 @@ jobs:
119119

120120
- name: Upload DB models
121121
uses: actions/upload-artifact@v4
122-
if: ${{ steps.set-should-run-tests.outputs.result == 'true' }}
123122
with:
124123
name: database_gen
125124
path: api/src/shared/database_gen/

api/src/feeds/impl/models/search_feed_item_result_impl.py

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -33,37 +33,41 @@ def from_orm_gtfs(cls, feed_search_row: t_feedsearch):
3333
feed_contact_email=feed_search_row.feed_contact_email,
3434
source_info=SourceInfo(
3535
producer_url=feed_search_row.producer_url,
36-
authentication_type=int(feed_search_row.authentication_type)
37-
if feed_search_row.authentication_type
38-
else None,
36+
authentication_type=(
37+
int(feed_search_row.authentication_type) if feed_search_row.authentication_type else None
38+
),
3939
authentication_info_url=feed_search_row.authentication_info_url,
4040
api_key_parameter_name=feed_search_row.api_key_parameter_name,
4141
license_url=feed_search_row.license_url,
4242
),
4343
redirects=feed_search_row.redirect_ids,
4444
locations=cls.resolve_locations(feed_search_row.locations),
45-
latest_dataset=LatestDataset(
46-
id=feed_search_row.latest_dataset_id,
47-
hosted_url=feed_search_row.latest_dataset_hosted_url,
48-
downloaded_at=feed_search_row.latest_dataset_downloaded_at,
49-
hash=feed_search_row.latest_dataset_hash,
50-
service_date_range_start=feed_search_row.latest_dataset_service_date_range_start,
51-
service_date_range_end=feed_search_row.latest_dataset_service_date_range_end,
52-
agency_timezone=feed_search_row.latest_dataset_agency_timezone,
53-
validation_report=LatestDatasetValidationReport(
54-
total_error=feed_search_row.latest_total_error,
55-
total_warning=feed_search_row.latest_total_warning,
56-
total_info=feed_search_row.latest_total_info,
57-
unique_error_count=feed_search_row.latest_unique_error_count,
58-
unique_warning_count=feed_search_row.latest_unique_warning_count,
59-
unique_info_count=feed_search_row.latest_unique_info_count,
60-
features=sorted([feature for feature in feed_search_row.latest_dataset_features])
61-
if feed_search_row.latest_dataset_features
62-
else [],
63-
),
64-
)
65-
if feed_search_row.latest_dataset_id
66-
else None,
45+
latest_dataset=(
46+
LatestDataset(
47+
id=feed_search_row.latest_dataset_id,
48+
hosted_url=feed_search_row.latest_dataset_hosted_url,
49+
downloaded_at=feed_search_row.latest_dataset_downloaded_at,
50+
hash=feed_search_row.latest_dataset_hash,
51+
service_date_range_start=feed_search_row.latest_dataset_service_date_range_start,
52+
service_date_range_end=feed_search_row.latest_dataset_service_date_range_end,
53+
agency_timezone=feed_search_row.latest_dataset_agency_timezone,
54+
validation_report=LatestDatasetValidationReport(
55+
total_error=feed_search_row.latest_total_error,
56+
total_warning=feed_search_row.latest_total_warning,
57+
total_info=feed_search_row.latest_total_info,
58+
unique_error_count=feed_search_row.latest_unique_error_count,
59+
unique_warning_count=feed_search_row.latest_unique_warning_count,
60+
unique_info_count=feed_search_row.latest_unique_info_count,
61+
features=(
62+
sorted([feature for feature in feed_search_row.latest_dataset_features])
63+
if feed_search_row.latest_dataset_features
64+
else []
65+
),
66+
),
67+
)
68+
if feed_search_row.latest_dataset_id
69+
else None
70+
),
6771
)
6872

6973
@classmethod
@@ -78,9 +82,9 @@ def from_orm_gbfs(cls, feed_search_row):
7882
feed_contact_email=feed_search_row.feed_contact_email,
7983
source_info=SourceInfo(
8084
producer_url=feed_search_row.producer_url,
81-
authentication_type=int(feed_search_row.authentication_type)
82-
if feed_search_row.authentication_type
83-
else None,
85+
authentication_type=(
86+
int(feed_search_row.authentication_type) if feed_search_row.authentication_type else None
87+
),
8488
authentication_info_url=feed_search_row.authentication_info_url,
8589
api_key_parameter_name=feed_search_row.api_key_parameter_name,
8690
license_url=feed_search_row.license_url,
@@ -104,9 +108,9 @@ def from_orm_gtfs_rt(cls, feed_search_row):
104108
feed_contact_email=feed_search_row.feed_contact_email,
105109
source_info=SourceInfo(
106110
producer_url=feed_search_row.producer_url,
107-
authentication_type=int(feed_search_row.authentication_type)
108-
if feed_search_row.authentication_type
109-
else None,
111+
authentication_type=(
112+
int(feed_search_row.authentication_type) if feed_search_row.authentication_type else None
113+
),
110114
authentication_info_url=feed_search_row.authentication_info_url,
111115
api_key_parameter_name=feed_search_row.api_key_parameter_name,
112116
license_url=feed_search_row.license_url,
@@ -125,9 +129,9 @@ def resolve_locations(cls, locations):
125129
return [
126130
{
127131
**location,
128-
"country": location.get("country")
129-
if location.get("country")
130-
else cls.resolve_country_by_code(location),
132+
"country": (
133+
location.get("country") if location.get("country") else cls.resolve_country_by_code(location)
134+
),
131135
}
132136
for location in locations
133137
]

api/src/feeds/impl/search_api_impl.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_parsed_search_tsquery(search_query: str) -> str:
3131
return func.plainto_tsquery("english", unaccent(parsed_query))
3232

3333
@staticmethod
34-
def add_search_query_filters(query, search_query, data_type, feed_id, status, is_official) -> Query:
34+
def add_search_query_filters(query, search_query, data_type, feed_id, status, is_official, features) -> Query:
3535
"""
3636
Add filters to the search query.
3737
Filter values are trimmed and converted to lowercase.
@@ -62,6 +62,13 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status, is
6262
query = query.filter(
6363
t_feedsearch.c.document.op("@@")(SearchApiImpl.get_parsed_search_tsquery(search_query))
6464
)
65+
# Add feature filter with OR logic
66+
if features:
67+
features_list = [s.strip() for s in features[0].split(",") if s]
68+
if features_list:
69+
query = query.filter(
70+
t_feedsearch.c.latest_dataset_features.op("&&")(features_list)
71+
) # overlap: true when the feed's feature array shares at least one element with the requested features (OR semantics).
6572
return query
6673

6774
@staticmethod
@@ -71,16 +78,24 @@ def create_count_search_query(
7178
data_type: str,
7279
is_official: bool,
7380
search_query: str,
81+
features,
7482
) -> Query:
7583
"""
7684
Create a search query for the database.
7785
"""
7886
query = select(func.count(t_feedsearch.c.feed_id))
79-
return SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
87+
return SearchApiImpl.add_search_query_filters(
88+
query, search_query, data_type, feed_id, status, is_official, features
89+
)
8090

8191
@staticmethod
8292
def create_search_query(
83-
status: List[str], feed_id: str, data_type: str, is_official: bool, search_query: str
93+
status: List[str],
94+
feed_id: str,
95+
data_type: str,
96+
is_official: bool,
97+
search_query: str,
98+
features: List[str],
8499
) -> Query:
85100
"""
86101
Create a search query for the database.
@@ -93,7 +108,9 @@ def create_search_query(
93108
rank_expression,
94109
*feed_search_columns,
95110
)
96-
query = SearchApiImpl.add_search_query_filters(query, search_query, data_type, feed_id, status, is_official)
111+
query = SearchApiImpl.add_search_query_filters(
112+
query, search_query, data_type, feed_id, status, is_official, features
113+
)
97114
return query.order_by(rank_expression.desc())
98115

99116
@with_db_session
@@ -106,10 +123,11 @@ def search_feeds(
106123
data_type: str,
107124
is_official: bool,
108125
search_query: str,
126+
features: List[str],
109127
db_session: "Session",
110128
) -> SearchFeeds200Response:
111129
"""Search feeds using full-text search on feed, location and provider's information."""
112-
query = self.create_search_query(status, feed_id, data_type, is_official, search_query)
130+
query = self.create_search_query(status, feed_id, data_type, is_official, search_query, features)
113131
feed_rows = Database().select(
114132
session=db_session,
115133
query=query,
@@ -118,7 +136,7 @@ def search_feeds(
118136
)
119137
feed_total_count = Database().select(
120138
session=db_session,
121-
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query),
139+
query=self.create_count_search_query(status, feed_id, data_type, is_official, search_query, features),
122140
)
123141
if feed_rows is None or feed_total_count is None:
124142
return SearchFeeds200Response(

api/tests/integration/test_search_api.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,3 +424,56 @@ def test_search_filter_by_official_status(client: TestClient, values: dict):
424424
assert (
425425
response_body.total == expected_count
426426
), f"There should be {expected_count} feeds for official={values['official']}"
427+
428+
429+
@pytest.mark.parametrize(
430+
"values",
431+
[
432+
{"feature": "", "expected_count": 16},
433+
{"feature": "Bike Allowed", "expected_count": 2},
434+
{"feature": "Stops Wheelchair Accessibility", "expected_count": 0},
435+
],
436+
ids=[
437+
"All",
438+
"Bike Allowed",
439+
"Stops Wheelchair Accessibility",
440+
],
441+
)
442+
def test_search_filter_by_feature(client: TestClient, values: dict):
443+
"""
444+
Retrieve feeds that contain specific features.
445+
"""
446+
params = None
447+
if values["feature"] is not None:
448+
params = [
449+
("feature", values["feature"]),
450+
]
451+
452+
headers = {
453+
"Authentication": "special-key",
454+
}
455+
response = client.request(
456+
"GET",
457+
"/v1/search",
458+
headers=headers,
459+
params=params,
460+
)
461+
# Assert the status code of the HTTP response
462+
assert response.status_code == 200
463+
464+
# Parse the response body into a Python object
465+
response_body = SearchFeeds200Response.model_validate(response.json())
466+
expected_count = values["expected_count"]
467+
assert (
468+
response_body.total == expected_count
469+
), f"There should be {expected_count} feeds with feature={values['feature']}"
470+
471+
# Verify all returned feeds have at least one of the requested features
472+
if values["feature"] and expected_count > 0:
473+
requested_features = set(values["feature"].split(","))
474+
for result in response_body.results:
475+
features = result.latest_dataset.validation_report.features
476+
# Check that at least one of the feed's features is in the requested features
477+
assert requested_features.intersection(features), (
478+
f"Feed {result.id} with features {features} does not match " f"requested features {requested_features}"
479+
)

docs/DatabaseCatalogAPI.yaml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ paths:
335335
- $ref: "#/components/parameters/data_type_query_param"
336336
- $ref: "#/components/parameters/is_official_query_param"
337337
- $ref: "#/components/parameters/search_text_query_param"
338+
- $ref: "#/components/parameters/feature"
338339
security:
339340
- Authentication: []
340341
responses:
@@ -1120,6 +1121,17 @@ components:
11201121
- inactive
11211122
- development
11221123
- future
1124+
feature:
1125+
name: feature
1126+
in: query
1127+
description: Filter feeds by their GTFS features. Feeds matching at least one of the listed features are returned. [GTFS feature definitions are available here](https://gtfs.org/getting-started/features/overview)
1128+
required: false
1129+
style: form
1130+
explode: false
1131+
schema:
1132+
type: array
1133+
items:
1134+
type: string
11231135
provider:
11241136
name: provider
11251137
in: query

web-app/src/app/screens/Feeds/index.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,13 @@ export default function Feed(): React.ReactElement {
121121
selectedFeedTypes,
122122
config.enableGbfsInSearchPage,
123123
),
124-
is_official: isOfficialFeedSearch || undefined,
124+
is_official: isOfficialTagFilterEnabled
125+
? isOfficialFeedSearch || undefined
126+
: undefined,
125127
// Fixed status values for now, until a status filter is implemented
126128
// Filtering out deprecated feeds
127129
status: ['active', 'inactive', 'development', 'future'],
130+
feature: areFeatureFiltersEnabled ? selectedFeatures : undefined,
128131
},
129132
},
130133
}),

web-app/src/app/services/feeds/types.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,8 @@ export interface components {
643643
statuses?: Array<
644644
'active' | 'deprecated' | 'inactive' | 'development' | 'future'
645645
>;
646+
/** @description Filter feeds by their GTFS features. Feeds matching at least one of the listed features are returned. [GTFS feature definitions are available here](https://gtfs.org/getting-started/features/overview) */
647+
feature?: string[];
646648
/** @description List only feeds with the specified value. Can be a partial match. Case insensitive. */
647649
provider?: string;
648650
/** @description List only feeds with the specified value. Can be a partial match. Case insensitive. */
@@ -991,6 +993,7 @@ export interface operations {
991993
data_type?: components['parameters']['data_type_query_param'];
992994
is_official?: components['parameters']['is_official_query_param'];
993995
search_query?: components['parameters']['search_text_query_param'];
996+
feature?: components['parameters']['feature'];
994997
};
995998
};
996999
responses: {

0 commit comments

Comments
 (0)