Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/integration-tests-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ jobs:

- name: Start API
run: |
scripts/api-start.sh &
scripts/api-start.sh > api_logs.txt 2>&1 & # Redirect stdout and stderr to api_logs.txt
sleep 10 # Wait for the API to start

- name: Health Check
Expand All @@ -150,3 +150,4 @@ jobs:
path: |
integration-tests/src/integration_tests_log.html
integration-tests/src/datasets_validation.csv
api_logs.txt
52 changes: 52 additions & 0 deletions api/src/feeds/impl/feeds_api_impl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import List, Union, TypeVar

from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.query import Query
Expand Down Expand Up @@ -37,6 +38,8 @@
from feeds_gen.models.gtfs_dataset import GtfsDataset
from feeds_gen.models.gtfs_feed import GtfsFeed
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
from middleware.request_context import is_user_email_restricted
from sqlalchemy import or_
from utils.date_utils import valid_iso_date
from utils.location_translation import (
create_location_translation_object,
Expand Down Expand Up @@ -65,6 +68,13 @@ def get_feed(
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
.filter(Database().get_query_model(Feed))
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
.filter(
or_(
Feed.operational_status == None, # noqa: E711
Feed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.first()
)
if feed:
Expand All @@ -86,6 +96,13 @@ def get_feeds(
)
feed_query = feed_filter.filter(Database().get_query_model(Feed))
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
feed_query = feed_query.filter(
or_(
Feed.operational_status == None, # noqa: E711
Feed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
# Results are sorted by provider
feed_query = feed_query.order_by(Feed.provider, Feed.stable_id)
feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options())
Expand Down Expand Up @@ -118,6 +135,13 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
producer_url__ilike=None,
)
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
.filter(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Gtfsfeed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.outerjoin(Location, Feed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
Expand Down Expand Up @@ -156,6 +180,13 @@ def get_gtfs_feed_datasets(
producer_url__ilike=None,
)
.filter(Database().get_query_model(Gtfsfeed))
.filter(
or_(
Feed.operational_status == None, # noqa: E711
Feed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.first()
)

Expand Down Expand Up @@ -213,6 +244,13 @@ def get_gtfs_feeds(
.get_session()
.query(Gtfsfeed)
.filter(Gtfsfeed.id.in_(subquery))
.filter(
or_(
Gtfsfeed.operational_status == None, # noqa: E711
Gtfsfeed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
Expand Down Expand Up @@ -241,6 +279,13 @@ def get_gtfs_rt_feed(
Database()
.get_session()
.query(Gtfsrealtimefeed, t_location_with_translations_en)
.filter(
or_(
Gtfsrealtimefeed.operational_status == None, # noqa: E711
Gtfsrealtimefeed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.outerjoin(Location, Gtfsrealtimefeed.locations)
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
.options(
Expand Down Expand Up @@ -301,6 +346,13 @@ def get_gtfs_rt_feeds(
.get_session()
.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.id.in_(subquery))
.filter(
or_(
Gtfsrealtimefeed.operational_status == None, # noqa: E711
Gtfsrealtimefeed.operational_status != "wip",
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
)
)
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
Expand Down
9 changes: 9 additions & 0 deletions api/src/feeds/impl/search_api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
from feeds_gen.apis.search_api_base import BaseSearchApi
from feeds_gen.models.search_feeds200_response import SearchFeeds200Response
from middleware.request_context import is_user_email_restricted
from sqlalchemy import or_

feed_search_columns = [column for column in t_feedsearch.columns if column.name != "document"]

Expand Down Expand Up @@ -36,6 +38,13 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
The search query is also converted to its unaccented version.
"""
query = query.filter(t_feedsearch.c.data_type != "gbfs") # Filter out GBFS feeds
query = query.filter(
or_(
t_feedsearch.c.operational_status == None, # noqa: E711
t_feedsearch.c.operational_status != "wip",
is_user_email_restricted(),
)
)
if feed_id:
query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower())
if data_type:
Expand Down
12 changes: 12 additions & 0 deletions api/src/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,15 @@ def __repr__(self) -> str:

def get_request_context():
return _request_context.get()


def is_user_email_restricted() -> bool:
"""
Check if an email's domain is restricted (e.g., for WIP visibility).
"""
request_context = get_request_context()
if not isinstance(request_context, RequestContext):
return True # Default to restricted
email = get_request_context().user_email
unrestricted_domains = ["@mobilitydata.org"]
return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains)
44 changes: 43 additions & 1 deletion api/tests/unittest/middleware/test_request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from starlette.datastructures import Headers

from middleware.request_context import RequestContext, get_request_context, _request_context
from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted


class TestRequestContext(unittest.TestCase):
Expand Down Expand Up @@ -54,3 +54,45 @@ def test_get_request_context(self):
request_context = RequestContext(MagicMock())
_request_context.set(request_context)
self.assertEqual(request_context, get_request_context())

def test_is_user_email_restricted(self):
self.assertTrue(is_user_email_restricted())
scope_instance = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": "GET",
"headers": [
(b"host", b"localhost"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-for", b"client, proxy1"),
(b"server", b"server"),
(b"user-agent", b"user-agent"),
(b"x-goog-iap-jwt-assertion", b"jwt"),
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
(b"x-goog-authenticated-user-id", b"user_id"),
(b"x-goog-authenticated-user-email", b"email"),
],
"path": "/",
"raw_path": b"/",
"query_string": b"",
"client": ("127.0.0.1", 32767),
"server": ("127.0.0.1", 80),
}
request_context = RequestContext(scope=scope_instance)
_request_context.set(request_context)
self.assertTrue(is_user_email_restricted())
scope_instance["headers"] = [
(b"host", b"localhost"),
(b"x-forwarded-proto", b"https"),
(b"x-forwarded-for", b"client, proxy1"),
(b"server", b"server"),
(b"user-agent", b"user-agent"),
(b"x-goog-iap-jwt-assertion", b"jwt"),
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
(b"x-goog-authenticated-user-id", b"user_id"),
(b"x-goog-authenticated-user-email", b"[email protected]"),
]
request_context = RequestContext(scope=scope_instance)
_request_context.set(request_context)
self.assertTrue(is_user_email_restricted())
4 changes: 2 additions & 2 deletions api/tests/unittest/test_feeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_feeds_get(client: TestClient, mocker):
mock_filter_offset = Mock()
mock_filter_order_by = Mock()
mock_options = Mock()
mock_filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
mock_filter.return_value.filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
mock_filter_order_by.options.return_value = mock_options
mock_options.offset.return_value = mock_filter_offset
# Target is set to None as deep copy is failing for unknown reasons
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_feed_get(client: TestClient, mocker):
Unit test for get_feeds
"""
mock_filter = mocker.patch.object(FeedFilter, "filter")
mock_filter.return_value.filter.return_value.first.return_value = mock_feed
mock_filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_feed

response = client.request(
"GET",
Expand Down
1 change: 1 addition & 0 deletions liquibase/changelog.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@
<include file="changes/feat_622.sql" relativeToChangelogFile="true"/>
<include file="changes/feat_565.sql" relativeToChangelogFile="true"/>
<include file="changes/feat_566.sql" relativeToChangelogFile="true"/>
<include file="changes/feat_780.sql" relativeToChangelogFile="true"/>
</databaseChangeLog>
Loading