Skip to content

Commit ca3e5cb

Browse files
authored
feat: added WIP flag to feeds (#812)
1 parent bca0ca1 commit ca3e5cb

File tree

8 files changed

+312
-4
lines changed

8 files changed

+312
-4
lines changed

.github/workflows/integration-tests-pr.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ jobs:
126126
127127
- name: Start API
128128
run: |
129-
scripts/api-start.sh &
129+
scripts/api-start.sh > api_logs.txt 2>&1 & # Redirect stdout and stderr to api_logs.txt
130130
sleep 10 # Wait for the API to start
131131
132132
- name: Health Check
@@ -150,3 +150,4 @@ jobs:
150150
path: |
151151
integration-tests/src/integration_tests_log.html
152152
integration-tests/src/datasets_validation.csv
153+
api_logs.txt

api/src/feeds/impl/feeds_api_impl.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from datetime import datetime
22
from typing import List, Union, TypeVar
3+
34
from sqlalchemy import select
45
from sqlalchemy.orm import joinedload
56
from sqlalchemy.orm.query import Query
@@ -37,6 +38,8 @@
3738
from feeds_gen.models.gtfs_dataset import GtfsDataset
3839
from feeds_gen.models.gtfs_feed import GtfsFeed
3940
from feeds_gen.models.gtfs_rt_feed import GtfsRTFeed
41+
from middleware.request_context import is_user_email_restricted
42+
from sqlalchemy import or_
4043
from utils.date_utils import valid_iso_date
4144
from utils.location_translation import (
4245
create_location_translation_object,
@@ -65,6 +68,13 @@ def get_feed(
6568
FeedFilter(stable_id=id, provider__ilike=None, producer_url__ilike=None, status=None)
6669
.filter(Database().get_query_model(Feed))
6770
.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
71+
.filter(
72+
or_(
73+
Feed.operational_status == None, # noqa: E711
74+
Feed.operational_status != "wip",
75+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
76+
)
77+
)
6878
.first()
6979
)
7080
if feed:
@@ -86,6 +96,13 @@ def get_feeds(
8696
)
8797
feed_query = feed_filter.filter(Database().get_query_model(Feed))
8898
feed_query = feed_query.filter(Feed.data_type != "gbfs") # Filter out GBFS feeds
99+
feed_query = feed_query.filter(
100+
or_(
101+
Feed.operational_status == None, # noqa: E711
102+
Feed.operational_status != "wip",
103+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
104+
)
105+
)
89106
# Results are sorted by provider
90107
feed_query = feed_query.order_by(Feed.provider, Feed.stable_id)
91108
feed_query = feed_query.options(*BasicFeedImpl.get_joinedload_options())
@@ -118,6 +135,13 @@ def _get_gtfs_feed(stable_id: str) -> tuple[Gtfsfeed | None, dict[str, LocationT
118135
producer_url__ilike=None,
119136
)
120137
.filter(Database().get_session().query(Gtfsfeed, t_location_with_translations_en))
138+
.filter(
139+
or_(
140+
Gtfsfeed.operational_status == None, # noqa: E711
141+
Gtfsfeed.operational_status != "wip",
142+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
143+
)
144+
)
121145
.outerjoin(Location, Feed.locations)
122146
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
123147
.options(
@@ -156,6 +180,13 @@ def get_gtfs_feed_datasets(
156180
producer_url__ilike=None,
157181
)
158182
.filter(Database().get_query_model(Gtfsfeed))
183+
.filter(
184+
or_(
185+
Feed.operational_status == None, # noqa: E711
186+
Feed.operational_status != "wip",
187+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
188+
)
189+
)
159190
.first()
160191
)
161192

@@ -213,6 +244,13 @@ def get_gtfs_feeds(
213244
.get_session()
214245
.query(Gtfsfeed)
215246
.filter(Gtfsfeed.id.in_(subquery))
247+
.filter(
248+
or_(
249+
Gtfsfeed.operational_status == None, # noqa: E711
250+
Gtfsfeed.operational_status != "wip",
251+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
252+
)
253+
)
216254
.options(
217255
joinedload(Gtfsfeed.gtfsdatasets)
218256
.joinedload(Gtfsdataset.validation_reports)
@@ -241,6 +279,13 @@ def get_gtfs_rt_feed(
241279
Database()
242280
.get_session()
243281
.query(Gtfsrealtimefeed, t_location_with_translations_en)
282+
.filter(
283+
or_(
284+
Gtfsrealtimefeed.operational_status == None, # noqa: E711
285+
Gtfsrealtimefeed.operational_status != "wip",
286+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
287+
)
288+
)
244289
.outerjoin(Location, Gtfsrealtimefeed.locations)
245290
.outerjoin(t_location_with_translations_en, Location.id == t_location_with_translations_en.c.location_id)
246291
.options(
@@ -301,6 +346,13 @@ def get_gtfs_rt_feeds(
301346
.get_session()
302347
.query(Gtfsrealtimefeed)
303348
.filter(Gtfsrealtimefeed.id.in_(subquery))
349+
.filter(
350+
or_(
351+
Gtfsrealtimefeed.operational_status == None, # noqa: E711
352+
Gtfsrealtimefeed.operational_status != "wip",
353+
not is_user_email_restricted(), # Allow all feeds to be returned if the user is not restricted
354+
)
355+
)
304356
.options(
305357
joinedload(Gtfsrealtimefeed.entitytypes),
306358
joinedload(Gtfsrealtimefeed.gtfs_feeds),

api/src/feeds/impl/search_api_impl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from feeds.impl.models.search_feed_item_result_impl import SearchFeedItemResultImpl
1010
from feeds_gen.apis.search_api_base import BaseSearchApi
1111
from feeds_gen.models.search_feeds200_response import SearchFeeds200Response
12+
from middleware.request_context import is_user_email_restricted
13+
from sqlalchemy import or_
1214

1315
feed_search_columns = [column for column in t_feedsearch.columns if column.name != "document"]
1416

@@ -36,6 +38,13 @@ def add_search_query_filters(query, search_query, data_type, feed_id, status) ->
3638
The search query is also converted to its unaccented version.
3739
"""
3840
query = query.filter(t_feedsearch.c.data_type != "gbfs") # Filter out GBFS feeds
41+
query = query.filter(
42+
or_(
43+
t_feedsearch.c.operational_status == None, # noqa: E711
44+
t_feedsearch.c.operational_status != "wip",
45+
is_user_email_restricted(),
46+
)
47+
)
3948
if feed_id:
4049
query = query.where(t_feedsearch.c.feed_stable_id == feed_id.strip().lower())
4150
if data_type:

api/src/middleware/request_context.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,15 @@ def __repr__(self) -> str:
101101

102102
def get_request_context():
103103
return _request_context.get()
104+
105+
106+
def is_user_email_restricted() -> bool:
107+
"""
108+
Check if an email's domain is restricted (e.g., for WIP visibility).
109+
"""
110+
request_context = get_request_context()
111+
if not isinstance(request_context, RequestContext):
112+
return True # Default to restricted
113+
email = get_request_context().user_email
114+
unrestricted_domains = ["@mobilitydata.org"]
115+
return not email or not any(email.endswith(f"@{domain}") for domain in unrestricted_domains)

api/tests/unittest/middleware/test_request_context.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from starlette.datastructures import Headers
55

6-
from middleware.request_context import RequestContext, get_request_context, _request_context
6+
from middleware.request_context import RequestContext, get_request_context, _request_context, is_user_email_restricted
77

88

99
class TestRequestContext(unittest.TestCase):
@@ -54,3 +54,45 @@ def test_get_request_context(self):
5454
request_context = RequestContext(MagicMock())
5555
_request_context.set(request_context)
5656
self.assertEqual(request_context, get_request_context())
57+
58+
def test_is_user_email_restricted(self):
59+
self.assertTrue(is_user_email_restricted())
60+
scope_instance = {
61+
"type": "http",
62+
"asgi": {"version": "3.0"},
63+
"http_version": "1.1",
64+
"method": "GET",
65+
"headers": [
66+
(b"host", b"localhost"),
67+
(b"x-forwarded-proto", b"https"),
68+
(b"x-forwarded-for", b"client, proxy1"),
69+
(b"server", b"server"),
70+
(b"user-agent", b"user-agent"),
71+
(b"x-goog-iap-jwt-assertion", b"jwt"),
72+
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
73+
(b"x-goog-authenticated-user-id", b"user_id"),
74+
(b"x-goog-authenticated-user-email", b"email"),
75+
],
76+
"path": "/",
77+
"raw_path": b"/",
78+
"query_string": b"",
79+
"client": ("127.0.0.1", 32767),
80+
"server": ("127.0.0.1", 80),
81+
}
82+
request_context = RequestContext(scope=scope_instance)
83+
_request_context.set(request_context)
84+
self.assertTrue(is_user_email_restricted())
85+
scope_instance["headers"] = [
86+
(b"host", b"localhost"),
87+
(b"x-forwarded-proto", b"https"),
88+
(b"x-forwarded-for", b"client, proxy1"),
89+
(b"server", b"server"),
90+
(b"user-agent", b"user-agent"),
91+
(b"x-goog-iap-jwt-assertion", b"jwt"),
92+
(b"x-cloud-trace-context", b"TRACE_ID/SPAN_ID;o=1"),
93+
(b"x-goog-authenticated-user-id", b"user_id"),
94+
(b"x-goog-authenticated-user-email", b"[email protected]"),
95+
]
96+
request_context = RequestContext(scope=scope_instance)
97+
_request_context.set(request_context)
98+
self.assertTrue(is_user_email_restricted())

api/tests/unittest/test_feeds.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_feeds_get(client: TestClient, mocker):
8484
mock_filter_offset = Mock()
8585
mock_filter_order_by = Mock()
8686
mock_options = Mock()
87-
mock_filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
87+
mock_filter.return_value.filter.return_value.filter.return_value.order_by.return_value = mock_filter_order_by
8888
mock_filter_order_by.options.return_value = mock_options
8989
mock_options.offset.return_value = mock_filter_offset
9090
# Target is set to None as deep copy is failing for unknown reasons
@@ -119,7 +119,7 @@ def test_feed_get(client: TestClient, mocker):
119119
Unit test for get_feeds
120120
"""
121121
mock_filter = mocker.patch.object(FeedFilter, "filter")
122-
mock_filter.return_value.filter.return_value.first.return_value = mock_feed
122+
mock_filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_feed
123123

124124
response = client.request(
125125
"GET",

liquibase/changelog.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@
2929
<include file="changes/feat_622.sql" relativeToChangelogFile="true"/>
3030
<include file="changes/feat_565.sql" relativeToChangelogFile="true"/>
3131
<include file="changes/feat_566.sql" relativeToChangelogFile="true"/>
32+
<include file="changes/feat_780.sql" relativeToChangelogFile="true"/>
3233
</databaseChangeLog>

0 commit comments

Comments
 (0)