diff --git a/.github/workflows/integration-tests-pr.yml b/.github/workflows/integration-tests-pr.yml
index ffe6e570f..5a79bd6bd 100644
--- a/.github/workflows/integration-tests-pr.yml
+++ b/.github/workflows/integration-tests-pr.yml
@@ -126,7 +126,7 @@ jobs:
- name: Start API
run: |
- scripts/api-start.sh &
+ scripts/api-start.sh > api_logs.txt 2>&1 & # Redirect stdout and stderr to api_logs.txt
sleep 10 # Wait for the API to start
- name: Health Check
@@ -150,3 +150,4 @@ jobs:
path: |
integration-tests/src/integration_tests_log.html
integration-tests/src/datasets_validation.csv
+ api_logs.txt
diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py
index ca77c23ed..998090152 100644
--- a/api/src/feeds/impl/feeds_api_impl.py
+++ b/api/src/feeds/impl/feeds_api_impl.py
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import List, Union, TypeVar
+
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.query import Query
@@ -37,6 +38,8 @@
from feeds_gen.models.gtfs_dataset import GtfsDataset
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
+from middleware.request_context import is_user_email_restricted
+from sqlalchemy import or_
from utils.date_utils import valid_iso_date
from utils.location_translation import (
create_location_translation_object,
@@ -65,6 +68,13 @@ def get_feed(
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
+ .filter(
+ or_(
+ Feed.operational_status == None, # noqa: E711
+ Feed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
.first()
)
if feed:
@@ -86,6 +96,13 @@ def get_feeds(
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
+ feed_query = feed_query.filter(
+ or_(
+ Feed.operational_status == None, # noqa: E711
+ Feed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
# Results are sorted by provider
feed_query = feed_query.order_by(Feed.provider, Feed.stable_id)
feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options())
@@ -118,6 +135,13 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
producer_url__ilike=None,
)
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
+ .filter(
+ or_(
+ Gtfsfeed.operational_status == None, # noqa: E711
+ Gtfsfeed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
.outerjoin(Location, Feed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
@@ -156,6 +180,13 @@ def get_gtfs_feed_datasets(
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
+ .filter(
+ or_(
+ Feed.operational_status == None, # noqa: E711
+ Feed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
.first()
)
@@ -213,6 +244,13 @@ def get_gtfs_feeds(
.get_session()
.query(Gtfsfeed)
.filter(Gtfsfeed.id.in_(subquery))
+ .filter(
+ or_(
+ Gtfsfeed.operational_status == None, # noqa: E711
+ Gtfsfeed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
@@ -241,6 +279,13 @@ def get_gtfs_rt_feed(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
+ .filter(
+ or_(
+ Gtfsrealtimefeed.operational_status == None, # noqa: E711
+ Gtfsrealtimefeed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
@@ -301,6 +346,13 @@ def get_gtfs_rt_feeds(
.get_session()
.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.id.in_(subquery))
+ .filter(
+ or_(
+ Gtfsrealtimefeed.operational_status == None, # noqa: E711
+ Gtfsrealtimefeed.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
+ )
+ )
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py
index 652a67de8..e8906b13d 100644
--- a/api/src/feeds/impl/search_api_impl.py
+++ b/api/src/feeds/impl/search_api_impl.py
@@ -9,6 +9,8 @@
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
from feeds_gen.apis.search_api_base import BaseSearchApi
from feeds_gen.models.search_feeds200_response import SearchFeeds200Response
+from middleware.request_context import is_user_email_restricted
+from sqlalchemy import or_
feed_search_columns = [column for column in t_feedsearch.columns if column.name != "document"]
@@ -36,6 +38,13 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
The search query is also converted to its unaccented version.
"""
query = query.filter(t_feedsearch.c.data_type != "gbfs") # Filter out GBFS feeds
+ query = query.filter(
+ or_(
+ t_feedsearch.c.operational_status == None, # noqa: E711
+ t_feedsearch.c.operational_status != "wip",
+ not is_user_email_restricted(), # Allow all feeds (including WIP) to be returned only if the user is NOT restricted
+ )
+ )
if feed_id:
query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower())
if data_type:
diff --git a/api/src/middleware/request_context.py b/api/src/middleware/request_context.py
index 569f53844..e019bc633 100644
--- a/api/src/middleware/request_context.py
+++ b/api/src/middleware/request_context.py
@@ -101,3 +101,15 @@ def __repr__(self) -> str:
def get_request_context():
return _request_context.get()
+
+
+def is_user_email_restricted() -> bool:
+    """Return True unless the current request's user email belongs to an unrestricted domain.
+
+    With no RequestContext set, or an empty email, the user is treated as restricted.
+    """
+    request_context = get_request_context()
+    if not isinstance(request_context, RequestContext):
+        return True  # Fail closed: without a request context the user cannot be identified
+    unrestricted_suffixes = ("@mobilitydata.org",)  # the leading "@" anchors the match to the whole domain
+    return not (request_context.user_email or "").endswith(unrestricted_suffixes)
diff --git a/api/tests/unittest/middleware/test_request_context.py b/api/tests/unittest/middleware/test_request_context.py
index 3cb32057d..23ef7c120 100644
--- a/api/tests/unittest/middleware/test_request_context.py
+++ b/api/tests/unittest/middleware/test_request_context.py
@@ -3,7 +3,7 @@
from starlette.datastructures import Headers
-from middleware.request_context import RequestContext, get_request_context, _request_context
+from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted
class TestRequestContext(unittest.TestCase):
@@ -54,3 +54,45 @@ def test_get_request_context(self):
request_context = RequestContext(MagicMock())
_request_context.set(request_context)
self.assertEqual(request_context, get_request_context())
+
+    def test_is_user_email_restricted(self):
+        """Only a request whose user email ends with @mobilitydata.org is unrestricted."""
+        self.assertTrue(is_user_email_restricted())
+        scope_instance = {
+            "type": "http",
+            "asgi": {"version": "3.0"},
+            "http_version": "1.1",
+            "method": "GET",
+            "headers": [
+                (b"host", b"localhost"),
+                (b"x-forwarded-proto", b"https"),
+                (b"x-forwarded-for", b"client, proxy1"),
+                (b"server", b"server"),
+                (b"user-agent", b"user-agent"),
+                (b"x-goog-iap-jwt-assertion", b"jwt"),
+                (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
+                (b"x-goog-authenticated-user-id", b"user_id"),
+                (b"x-goog-authenticated-user-email", b"email"),
+            ],
+            "path": "/",
+            "raw_path": b"/",
+            "query_string": b"",
+            "client": ("127.0.0.1", 32767),
+            "server": ("127.0.0.1", 80),
+        }
+        _request_context.set(RequestContext(scope=scope_instance))
+        self.assertTrue(is_user_email_restricted())  # "email" carries no unrestricted domain suffix
+        scope_instance["headers"] = [
+            (b"host", b"localhost"),
+            (b"x-forwarded-proto", b"https"),
+            (b"x-forwarded-for", b"client, proxy1"),
+            (b"server", b"server"),
+            (b"user-agent", b"user-agent"),
+            (b"x-goog-iap-jwt-assertion", b"jwt"),
+            (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
+            (b"x-goog-authenticated-user-id", b"user_id"),
+            (b"x-goog-authenticated-user-email", b"test@mobilitydata.org"),
+        ]
+        _request_context.set(RequestContext(scope=scope_instance))
+        # An @mobilitydata.org address is the one case that must lift the restriction
+        self.assertFalse(is_user_email_restricted())
diff --git a/api/tests/unittest/test_feeds.py b/api/tests/unittest/test_feeds.py
index 3ee91f336..b419de7fc 100644
--- a/api/tests/unittest/test_feeds.py
+++ b/api/tests/unittest/test_feeds.py
@@ -84,7 +84,7 @@ def test_feeds_get(client: TestClient, mocker):
mock_filter_offset = Mock()
mock_filter_order_by = Mock()
mock_options = Mock()
- mock_filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
+ mock_filter.return_value.filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
mock_filter_order_by.options.return_value = mock_options
mock_options.offset.return_value = mock_filter_offset
# Target is set to None as deep copy is failing for unknown reasons
@@ -119,7 +119,7 @@ def test_feed_get(client: TestClient, mocker):
Unit test for get_feeds
"""
mock_filter = mocker.patch.object(FeedFilter, "filter")
- mock_filter.return_value.filter.return_value.first.return_value = mock_feed
+ mock_filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_feed
response = client.request(
"GET",
diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml
index a70204ad4..66b7ba633 100644
--- a/liquibase/changelog.xml
+++ b/liquibase/changelog.xml
@@ -29,4 +29,5 @@
+
\ No newline at end of file
diff --git a/liquibase/changes/feat_780.sql b/liquibase/changes/feat_780.sql
new file mode 100644
index 000000000..71752b177
--- /dev/null
+++ b/liquibase/changes/feat_780.sql
@@ -0,0 +1,191 @@
+DO $$
+BEGIN
+ IF NOT EXISTS ( -- create the enum only when it is missing, so re-running the script does not fail
+ SELECT 1
+ FROM pg_type
+ WHERE typname = 'operationalstatus' -- lower-case: PostgreSQL folds the unquoted name OperationalStatus
+ ) THEN
+ CREATE TYPE OperationalStatus AS ENUM ('wip'); -- 'wip' is the only value defined here
+ END IF;
+END $$;
+
+DO $$
+BEGIN
+ IF NOT EXISTS ( -- add the column only when it is not already present
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_name = 'feed' -- NOTE(review): no table_schema filter, so a 'feed' table in any schema satisfies this check
+ AND column_name = 'operational_status'
+ ) THEN
+ ALTER TABLE Feed ADD COLUMN operational_status OperationalStatus DEFAULT NULL; -- existing rows get NULL
+ END IF;
+END $$;
+
+
+-- Drop the materialized view if it exists: its defining query cannot be updated in place, so it is recreated below
+DROP MATERIALIZED VIEW IF EXISTS FeedSearch;
+
+CREATE MATERIALIZED VIEW FeedSearch AS
+SELECT
+ -- feed
+ Feed.stable_id AS feed_stable_id,
+ Feed.id AS feed_id,
+ Feed.data_type,
+ Feed.status,
+ Feed.feed_name,
+ Feed.note,
+ Feed.feed_contact_email,
+ -- source
+ Feed.producer_url,
+ Feed.authentication_info_url,
+ Feed.authentication_type,
+ Feed.api_key_parameter_name,
+ Feed.license_url,
+ Feed.provider,
+ Feed.operational_status, -- exposed so the search query can filter out 'wip' feeds
+ -- latest_dataset
+ Latest_dataset.id AS latest_dataset_id,
+ Latest_dataset.hosted_url AS latest_dataset_hosted_url,
+ Latest_dataset.downloaded_at AS latest_dataset_downloaded_at,
+ Latest_dataset.bounding_box AS latest_dataset_bounding_box,
+ Latest_dataset.hash AS latest_dataset_hash,
+ -- external_ids
+ ExternalIdJoin.external_ids,
+ -- redirect_ids
+ RedirectingIdJoin.redirect_ids,
+ -- feed gtfs_rt references
+ FeedReferenceJoin.feed_reference_ids,
+ -- feed gtfs_rt entities
+ EntityTypeFeedJoin.entities,
+ -- locations
+ FeedLocationJoin.locations,
+ -- translations
+ FeedCountryTranslationJoin.translations AS country_translations,
+ FeedSubdivisionNameTranslationJoin.translations AS subdivision_name_translations,
+ FeedMunicipalityTranslationJoin.translations AS municipality_translations,
+ -- full-text searchable document: feed name and provider get weight 'C', location fields and their translations weight 'A'
+ setweight(to_tsvector('english', coalesce(unaccent(Feed.feed_name), '')), 'C') ||
+ setweight(to_tsvector('english', coalesce(unaccent(Feed.provider), '')), 'C') ||
+ setweight(to_tsvector('english', coalesce(unaccent((
+ SELECT string_agg(
+ coalesce(location->>'country_code', '') || ' ' ||
+ coalesce(location->>'country', '') || ' ' ||
+ coalesce(location->>'subdivision_name', '') || ' ' ||
+ coalesce(location->>'municipality', ''),
+ ' '
+ )
+ FROM json_array_elements(FeedLocationJoin.locations) AS location
+ )), '')), 'A') ||
+ setweight(to_tsvector('english', coalesce(unaccent((
+ SELECT string_agg(
+ coalesce(translation->>'value', ''),
+ ' '
+ )
+ FROM json_array_elements(FeedCountryTranslationJoin.translations) AS translation
+ )), '')), 'A') ||
+ setweight(to_tsvector('english', coalesce(unaccent((
+ SELECT string_agg(
+ coalesce(translation->>'value', ''),
+ ' '
+ )
+ FROM json_array_elements(FeedSubdivisionNameTranslationJoin.translations) AS translation
+ )), '')), 'A') ||
+ setweight(to_tsvector('english', coalesce(unaccent((
+ SELECT string_agg(
+ coalesce(translation->>'value', ''),
+ ' '
+ )
+ FROM json_array_elements(FeedMunicipalityTranslationJoin.translations) AS translation
+ )), '')), 'A') AS document
+FROM Feed
+LEFT JOIN (
+ SELECT *
+ FROM gtfsdataset
+ WHERE latest = true
+) AS Latest_dataset ON Latest_dataset.feed_id = Feed.id AND Feed.data_type = 'gtfs' -- latest dataset is joined for 'gtfs' feeds only
+LEFT JOIN (
+ SELECT
+ feed_id,
+ json_agg(json_build_object('external_id', associated_id, 'source', source)) AS external_ids
+ FROM externalid
+ GROUP BY feed_id
+) AS ExternalIdJoin ON ExternalIdJoin.feed_id = Feed.id
+LEFT JOIN (
+ SELECT
+ gtfs_rt_feed_id,
+ array_agg(FeedReferenceJoinInnerQuery.stable_id) AS feed_reference_ids
+ FROM FeedReference
+ LEFT JOIN Feed AS FeedReferenceJoinInnerQuery ON FeedReferenceJoinInnerQuery.id = FeedReference.gtfs_feed_id
+ GROUP BY gtfs_rt_feed_id
+) AS FeedReferenceJoin ON FeedReferenceJoin.gtfs_rt_feed_id = Feed.id AND Feed.data_type = 'gtfs_rt' -- references are joined for 'gtfs_rt' feeds only
+LEFT JOIN (
+ SELECT
+ target_id,
+ json_agg(json_build_object('target_id', target_id, 'comment', redirect_comment)) AS redirect_ids
+ FROM RedirectingId
+ GROUP BY target_id
+) AS RedirectingIdJoin ON RedirectingIdJoin.target_id = Feed.id
+LEFT JOIN (
+ SELECT
+ LocationFeed.feed_id,
+ json_agg(json_build_object('country', country, 'country_code', country_code, 'subdivision_name',
+ subdivision_name, 'municipality', municipality)) AS locations
+ FROM Location
+ LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id
+ GROUP BY LocationFeed.feed_id
+) AS FeedLocationJoin ON FeedLocationJoin.feed_id = Feed.id
+LEFT JOIN (
+ SELECT
+ LocationFeed.feed_id,
+ json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations
+ FROM Location
+ LEFT JOIN Translation ON Location.country = Translation.key
+ LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id
+ WHERE Translation.language_code = 'en' -- NOTE(review): filtering on Translation here drops locations with no translation row, so the LEFT JOIN acts as an inner join
+ AND Translation.type = 'country'
+ AND Location.country IS NOT NULL
+ GROUP BY LocationFeed.feed_id
+) AS FeedCountryTranslationJoin ON FeedCountryTranslationJoin.feed_id = Feed.id
+LEFT JOIN (
+ SELECT
+ LocationFeed.feed_id,
+ json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations
+ FROM Location
+ LEFT JOIN Translation ON Location.subdivision_name = Translation.key
+ LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id
+ WHERE Translation.language_code = 'en'
+ AND Translation.type = 'subdivision_name'
+ AND Location.subdivision_name IS NOT NULL
+ GROUP BY LocationFeed.feed_id
+) AS FeedSubdivisionNameTranslationJoin ON FeedSubdivisionNameTranslationJoin.feed_id = Feed.id
+LEFT JOIN (
+ SELECT
+ LocationFeed.feed_id,
+ json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations
+ FROM Location
+ LEFT JOIN Translation ON Location.municipality = Translation.key
+ LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id
+ WHERE Translation.language_code = 'en'
+ AND Translation.type = 'municipality'
+ AND Location.municipality IS NOT NULL
+ GROUP BY LocationFeed.feed_id
+) AS FeedMunicipalityTranslationJoin ON FeedMunicipalityTranslationJoin.feed_id = Feed.id
+LEFT JOIN (
+ SELECT
+ feed_id,
+ array_agg(entity_name) AS entities
+ FROM EntityTypeFeed
+ GROUP BY feed_id
+) AS EntityTypeFeedJoin ON EntityTypeFeedJoin.feed_id = Feed.id AND Feed.data_type = 'gtfs_rt' -- entity types are joined for 'gtfs_rt' feeds only
+;
+
+
+-- A unique index is required to refresh the materialized view CONCURRENTLY, which avoids locking out readers during a refresh
+CREATE UNIQUE INDEX idx_unique_feed_id ON FeedSearch(feed_id);
+
+-- Indices for feedsearch view optimization (created again here because the view was dropped and recreated above)
+CREATE INDEX feedsearch_document_idx ON FeedSearch USING GIN(document); -- GIN index over the tsvector search document
+CREATE INDEX feedsearch_feed_stable_id ON FeedSearch(feed_stable_id);
+CREATE INDEX feedsearch_data_type ON FeedSearch(data_type);
+CREATE INDEX feedsearch_status ON FeedSearch(status);
+