diff --git a/.github/workflows/integration-tests-pr.yml b/.github/workflows/integration-tests-pr.yml index ffe6e570f..5a79bd6bd 100644 --- a/.github/workflows/integration-tests-pr.yml +++ b/.github/workflows/integration-tests-pr.yml @@ -126,7 +126,7 @@ jobs: - name: Start API run: | - scripts/api-start.sh & + scripts/api-start.sh > api_logs.txt 2>&1 & # Redirect stdout and stderr to api_logs.txt sleep 10 # Wait for the API to start - name: Health Check @@ -150,3 +150,4 @@ jobs: path: | integration-tests/src/integration_tests_log.html integration-tests/src/datasets_validation.csv + api_logs.txt diff --git a/api/src/feeds/impl/feeds_api_impl.py b/api/src/feeds/impl/feeds_api_impl.py index ca77c23ed..998090152 100644 --- a/api/src/feeds/impl/feeds_api_impl.py +++ b/api/src/feeds/impl/feeds_api_impl.py @@ -1,5 +1,6 @@ from datetime import datetime from typing import List, Union, TypeVar + from sqlalchemy import select from sqlalchemy.orm import joinedload from sqlalchemy.orm.query import Query @@ -37,6 +38,8 @@ from feeds_gen.models.gtfs_dataset import GtfsDataset from feeds_gen.models.gtfs_feed import GtfsFeed from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed +from middleware.request_context import is_user_email_restricted +from sqlalchemy import or_ from utils.date_utils import valid_iso_date from utils.location_translation import ( create_location_translation_object, @@ -65,6 +68,13 @@ def get_feed( FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None) .filter(Database().get_query_model(Feed)) .filter(Feed.data_type != "gbfs") # Filter out GBFS feeds + .filter( + or_( + Feed.operational_status == None, # noqa: E711 + Feed.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) .first() ) if feed: @@ -86,6 +96,13 @@ def get_feeds( ) feed_query = feed_filter.filter(Database().get_query_model(Feed)) feed_query = feed_query.filter(Feed.data_type != "gbfs") # 
Filter out GBFS feeds + feed_query = feed_query.filter( + or_( + Feed.operational_status == None, # noqa: E711 + Feed.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) # Results are sorted by provider feed_query = feed_query.order_by(Feed.provider, Feed.stable_id) feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options()) @@ -118,6 +135,13 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT producer_url__ilike=None, ) .filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en)) + .filter( + or_( + Gtfsfeed.operational_status == None, # noqa: E711 + Gtfsfeed.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) .outerjoin(Location, Feed.locations) .outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id) .options( @@ -156,6 +180,13 @@ def get_gtfs_feed_datasets( producer_url__ilike=None, ) .filter(Database().get_query_model(Gtfsfeed)) + .filter( + or_( + Feed.operational_status == None, # noqa: E711 + Feed.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) .first() ) @@ -213,6 +244,13 @@ def get_gtfs_feeds( .get_session() .query(Gtfsfeed) .filter(Gtfsfeed.id.in_(subquery)) + .filter( + or_( + Gtfsfeed.operational_status == None, # noqa: E711 + Gtfsfeed.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) .options( joinedload(Gtfsfeed.gtfsdatasets) .joinedload(Gtfsdataset.validation_reports) @@ -241,6 +279,13 @@ def get_gtfs_rt_feed( Database() .get_session() .query(Gtfsrealtimefeed, t_location_with_translations_en) + .filter( + or_( + Gtfsrealtimefeed.operational_status == None, # noqa: E711 + Gtfsrealtimefeed.operational_status 
!= "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) .outerjoin(Location, Gtfsrealtimefeed.locations) .outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id) .options( @@ -301,6 +346,13 @@ def get_gtfs_rt_feeds( .get_session() .query(Gtfsrealtimefeed) .filter(Gtfsrealtimefeed.id.in_(subquery)) + .filter( + or_( + Gtfsrealtimefeed.operational_status == None, # noqa: E711 + Gtfsrealtimefeed.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) .options( joinedload(Gtfsrealtimefeed.entitytypes), joinedload(Gtfsrealtimefeed.gtfs_feeds), diff --git a/api/src/feeds/impl/search_api_impl.py b/api/src/feeds/impl/search_api_impl.py index 652a67de8..e8906b13d 100644 --- a/api/src/feeds/impl/search_api_impl.py +++ b/api/src/feeds/impl/search_api_impl.py @@ -9,6 +9,8 @@ from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl from feeds_gen.apis.search_api_base import BaseSearchApi from feeds_gen.models.search_feeds200_response import SearchFeeds200Response +from middleware.request_context import is_user_email_restricted +from sqlalchemy import or_ feed_search_columns = [column for column in t_feedsearch.columns if column.name != "document"] @@ -36,6 +38,13 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) -> The search query is also converted to its unaccented version. 
""" query = query.filter(t_feedsearch.c.data_type != "gbfs") # Filter out GBFS feeds + query = query.filter( + or_( + t_feedsearch.c.operational_status == None, # noqa: E711 + t_feedsearch.c.operational_status != "wip", + not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted + ) + ) if feed_id: query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower()) if data_type: diff --git a/api/src/middleware/request_context.py b/api/src/middleware/request_context.py index 569f53844..e019bc633 100644 --- a/api/src/middleware/request_context.py +++ b/api/src/middleware/request_context.py @@ -101,3 +101,15 @@ def __repr__(self) -> str: def get_request_context(): return _request_context.get() + + +def is_user_email_restricted() -> bool: + """ + Check if the current user's email domain is restricted (e.g., for WIP visibility). + """ + request_context = get_request_context() + if not isinstance(request_context, RequestContext): + return True # Default to restricted + email = request_context.user_email + unrestricted_domains = ["mobilitydata.org"] # Bare domains; the "@" is added when matching + return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains) diff --git a/api/tests/unittest/middleware/test_request_context.py b/api/tests/unittest/middleware/test_request_context.py index 3cb32057d..23ef7c120 100644 --- a/api/tests/unittest/middleware/test_request_context.py +++ b/api/tests/unittest/middleware/test_request_context.py @@ -3,7 +3,7 @@ from starlette.datastructures import Headers -from middleware.request_context import RequestContext, get_request_context, _request_context +from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted class TestRequestContext(unittest.TestCase): @@ -54,3 +54,45 @@ def test_get_request_context(self): request_context = RequestContext(MagicMock()) _request_context.set(request_context) self.assertEqual(request_context, get_request_context()) + + def test_is_user_email_restricted(self): + 
self.assertTrue(is_user_email_restricted()) + scope_instance = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": "GET", + "headers": [ + (b"host", b"localhost"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-for", b"client, proxy1"), + (b"server", b"server"), + (b"user-agent", b"user-agent"), + (b"x-goog-iap-jwt-assertion", b"jwt"), + (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"), + (b"x-goog-authenticated-user-id", b"user_id"), + (b"x-goog-authenticated-user-email", b"email"), + ], + "path": "/", + "raw_path": b"/", + "query_string": b"", + "client": ("127.0.0.1", 32767), + "server": ("127.0.0.1", 80), + } + request_context = RequestContext(scope=scope_instance) + _request_context.set(request_context) + self.assertTrue(is_user_email_restricted()) + scope_instance["headers"] = [ + (b"host", b"localhost"), + (b"x-forwarded-proto", b"https"), + (b"x-forwarded-for", b"client, proxy1"), + (b"server", b"server"), + (b"user-agent", b"user-agent"), + (b"x-goog-iap-jwt-assertion", b"jwt"), + (b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"), + (b"x-goog-authenticated-user-id", b"user_id"), + (b"x-goog-authenticated-user-email", b"test@mobilitydata.org"), + ] + request_context = RequestContext(scope=scope_instance) + _request_context.set(request_context) + self.assertFalse(is_user_email_restricted()) diff --git a/api/tests/unittest/test_feeds.py b/api/tests/unittest/test_feeds.py index 3ee91f336..b419de7fc 100644 --- a/api/tests/unittest/test_feeds.py +++ b/api/tests/unittest/test_feeds.py @@ -84,7 +84,7 @@ def test_feeds_get(client: TestClient, mocker): mock_filter_offset = Mock() mock_filter_order_by = Mock() mock_options = Mock() - mock_filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by + mock_filter.return_value.filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by mock_filter_order_by.options.return_value = mock_options mock_options.offset.return_value = 
mock_filter_offset # Target is set to None as deep copy is failing for unknown reasons @@ -119,7 +119,7 @@ def test_feed_get(client: TestClient, mocker): Unit test for get_feeds """ mock_filter = mocker.patch.object(FeedFilter, "filter") - mock_filter.return_value.filter.return_value.first.return_value = mock_feed + mock_filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_feed response = client.request( "GET", diff --git a/liquibase/changelog.xml b/liquibase/changelog.xml index a70204ad4..66b7ba633 100644 --- a/liquibase/changelog.xml +++ b/liquibase/changelog.xml @@ -29,4 +29,5 @@ + \ No newline at end of file diff --git a/liquibase/changes/feat_780.sql b/liquibase/changes/feat_780.sql new file mode 100644 index 000000000..71752b177 --- /dev/null +++ b/liquibase/changes/feat_780.sql @@ -0,0 +1,191 @@ +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_type + WHERE typname = 'operationalstatus' + ) THEN + CREATE TYPE OperationalStatus AS ENUM ('wip'); + END IF; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_name = 'feed' + AND column_name = 'operational_status' + ) THEN + ALTER TABLE Feed ADD COLUMN operational_status OperationalStatus DEFAULT NULL; + END IF; +END $$; + + +-- Dropping the materialized view if it exists as we cannot update it +DROP MATERIALIZED VIEW IF EXISTS FeedSearch; + +CREATE MATERIALIZED VIEW FeedSearch AS +SELECT + -- feed + Feed.stable_id AS feed_stable_id, + Feed.id AS feed_id, + Feed.data_type, + Feed.status, + Feed.feed_name, + Feed.note, + Feed.feed_contact_email, + -- source + Feed.producer_url, + Feed.authentication_info_url, + Feed.authentication_type, + Feed.api_key_parameter_name, + Feed.license_url, + Feed.provider, + Feed.operational_status, + -- latest_dataset + Latest_dataset.id AS latest_dataset_id, + Latest_dataset.hosted_url AS latest_dataset_hosted_url, + Latest_dataset.downloaded_at AS latest_dataset_downloaded_at, + 
Latest_dataset.bounding_box AS latest_dataset_bounding_box, + Latest_dataset.hash AS latest_dataset_hash, + -- external_ids + ExternalIdJoin.external_ids, + -- redirect_ids + RedirectingIdJoin.redirect_ids, + -- feed gtfs_rt references + FeedReferenceJoin.feed_reference_ids, + -- feed gtfs_rt entities + EntityTypeFeedJoin.entities, + -- locations + FeedLocationJoin.locations, + -- translations + FeedCountryTranslationJoin.translations AS country_translations, + FeedSubdivisionNameTranslationJoin.translations AS subdivision_name_translations, + FeedMunicipalityTranslationJoin.translations AS municipality_translations, + -- full-text searchable document + setweight(to_tsvector('english', coalesce(unaccent(Feed.feed_name), '')), 'C') || + setweight(to_tsvector('english', coalesce(unaccent(Feed.provider), '')), 'C') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(location->>'country_code', '') || ' ' || + coalesce(location->>'country', '') || ' ' || + coalesce(location->>'subdivision_name', '') || ' ' || + coalesce(location->>'municipality', ''), + ' ' + ) + FROM json_array_elements(FeedLocationJoin.locations) AS location + )), '')), 'A') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(translation->>'value', ''), + ' ' + ) + FROM json_array_elements(FeedCountryTranslationJoin.translations) AS translation + )), '')), 'A') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(translation->>'value', ''), + ' ' + ) + FROM json_array_elements(FeedSubdivisionNameTranslationJoin.translations) AS translation + )), '')), 'A') || + setweight(to_tsvector('english', coalesce(unaccent(( + SELECT string_agg( + coalesce(translation->>'value', ''), + ' ' + ) + FROM json_array_elements(FeedMunicipalityTranslationJoin.translations) AS translation + )), '')), 'A') AS document +FROM Feed +LEFT JOIN ( + SELECT * + FROM gtfsdataset + WHERE latest = true +) AS 
Latest_dataset ON Latest_dataset.feed_id = Feed.id AND Feed.data_type = 'gtfs' +LEFT JOIN ( + SELECT + feed_id, + json_agg(json_build_object('external_id', associated_id, 'source', source)) AS external_ids + FROM externalid + GROUP BY feed_id +) AS ExternalIdJoin ON ExternalIdJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + gtfs_rt_feed_id, + array_agg(FeedReferenceJoinInnerQuery.stable_id) AS feed_reference_ids + FROM FeedReference + LEFT JOIN Feed AS FeedReferenceJoinInnerQuery ON FeedReferenceJoinInnerQuery.id = FeedReference.gtfs_feed_id + GROUP BY gtfs_rt_feed_id +) AS FeedReferenceJoin ON FeedReferenceJoin.gtfs_rt_feed_id = Feed.id AND Feed.data_type = 'gtfs_rt' +LEFT JOIN ( + SELECT + target_id, + json_agg(json_build_object('target_id', target_id, 'comment', redirect_comment)) AS redirect_ids + FROM RedirectingId + GROUP BY target_id +) AS RedirectingIdJoin ON RedirectingIdJoin.target_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('country', country, 'country_code', country_code, 'subdivision_name', + subdivision_name, 'municipality', municipality)) AS locations + FROM Location + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + GROUP BY LocationFeed.feed_id +) AS FeedLocationJoin ON FeedLocationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations + FROM Location + LEFT JOIN Translation ON Location.country = Translation.key + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + WHERE Translation.language_code = 'en' + AND Translation.type = 'country' + AND Location.country IS NOT NULL + GROUP BY LocationFeed.feed_id +) AS FeedCountryTranslationJoin ON FeedCountryTranslationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations + FROM Location + LEFT JOIN Translation 
ON Location.subdivision_name = Translation.key + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + WHERE Translation.language_code = 'en' + AND Translation.type = 'subdivision_name' + AND Location.subdivision_name IS NOT NULL + GROUP BY LocationFeed.feed_id +) AS FeedSubdivisionNameTranslationJoin ON FeedSubdivisionNameTranslationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + LocationFeed.feed_id, + json_agg(json_build_object('value', Translation.value, 'key', Translation.key)) AS translations + FROM Location + LEFT JOIN Translation ON Location.municipality = Translation.key + LEFT JOIN LocationFeed ON LocationFeed.location_id = Location.id + WHERE Translation.language_code = 'en' + AND Translation.type = 'municipality' + AND Location.municipality IS NOT NULL + GROUP BY LocationFeed.feed_id +) AS FeedMunicipalityTranslationJoin ON FeedMunicipalityTranslationJoin.feed_id = Feed.id +LEFT JOIN ( + SELECT + feed_id, + array_agg(entity_name) AS entities + FROM EntityTypeFeed + GROUP BY feed_id +) AS EntityTypeFeedJoin ON EntityTypeFeedJoin.feed_id = Feed.id AND Feed.data_type = 'gtfs_rt' +; + + +-- This index allows concurrent refresh on the materialized view avoiding table locks +CREATE UNIQUE INDEX idx_unique_feed_id ON FeedSearch(feed_id); + +-- Indices for feedsearch view optimization +CREATE INDEX feedsearch_document_idx ON FeedSearch USING GIN(document); +CREATE INDEX feedsearch_feed_stable_id ON FeedSearch(feed_stable_id); +CREATE INDEX feedsearch_data_type ON FeedSearch(data_type); +CREATE INDEX feedsearch_status ON FeedSearch(status); +