
Commit 15745bd

feat: download feed after creation (#1437)

1 parent adc3b6c
File tree

8 files changed: +107 −70 lines

functions-python/helpers/pub_sub.py

Lines changed: 51 additions & 1 deletion

@@ -15,13 +15,19 @@
 #
 import json
 import logging
+import os
 import uuid
 from typing import Dict, List
 
 from google.cloud import pubsub_v1
 from google.cloud.pubsub_v1 import PublisherClient
 from google.cloud.pubsub_v1.publisher.futures import Future
 
+from shared.database_gen.sqlacodegen_models import Feed, Gtfsfeed
+
+PROJECT_ID = os.getenv("PROJECT_ID")
+DATASET_BATCH_TOPIC = os.getenv("DATASET_PROCESSING_TOPIC_NAME")
+
 
 def get_pubsub_client():
     """
@@ -43,7 +49,13 @@ def get_execution_id(request, prefix: str) -> str:
     @param request: HTTP request object
     @param prefix: prefix for the execution ID. Example: "batch-datasets"
     """
-    trace_id = request.headers.get("X-Cloud-Trace-Context")
+    trace_id = (
+        request.headers.get("X-Cloud-Trace-Context")
+        if hasattr(request, "headers")
+        else None
+    )
+    if not trace_id:
+        trace_id = request.trace_id if hasattr(request, "trace_id") else None
     execution_id = f"{prefix}-{trace_id}" if trace_id else f"{prefix}-{uuid.uuid4()}"
     return execution_id
@@ -58,3 +70,41 @@ def publish_messages(data: List[Dict], project_id, topic_name) -> None:
         message_data = json.dumps(element).encode("utf-8")
         future = publish(publisher, topic_path, message_data)
         logging.info(f"Published message to Pub/Sub with ID: {future.result()}")
+
+
+def trigger_dataset_download(
+    feed: Feed | Gtfsfeed,
+    execution_id: str,
+    topic_name: str = DATASET_BATCH_TOPIC,
+) -> None:
+    """Publishes the feed to the configured Pub/Sub topic."""
+    publisher = get_pubsub_client()
+    topic_path = publisher.topic_path(PROJECT_ID, topic_name)
+    logging.debug("Publishing to Pub/Sub topic: %s", topic_path)
+
+    message_data = {
+        "execution_id": execution_id,
+        "producer_url": feed.producer_url,
+        "feed_stable_id": feed.stable_id,
+        "feed_id": feed.id,
+        "dataset_id": None,
+        "dataset_hash": None,
+        "authentication_type": feed.authentication_type,
+        "authentication_info_url": feed.authentication_info_url,
+        "api_key_parameter_name": feed.api_key_parameter_name,
+    }
+
+    try:
+        # Convert to JSON string
+        json_message = json.dumps(message_data)
+        future = publisher.publish(topic_path, data=json_message.encode("utf-8"))
+        future.add_done_callback(
+            lambda _: logging.info(
+                "Published feed %s to dataset batch topic", feed.stable_id
+            )
+        )
+        future.result()
+        logging.info("Message published for feed %s", feed.stable_id)
+    except Exception as e:
+        logging.error("Error publishing to dataset batch topic: %s", str(e))
+        raise
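
For context, a minimal sketch of how a Pub/Sub-triggered consumer might decode this message. The function name here is hypothetical (the actual downstream processor is not part of this diff), but the field names match the payload built by trigger_dataset_download above, and Google Cloud background functions receive the published bytes base64-encoded in the event's "data" field:

import base64
import json
import logging


def process_dataset_batch(event, context):
    # Pub/Sub wraps the published bytes in a base64-encoded "data" field
    payload = json.loads(base64.b64decode(event["data"]).decode("utf-8"))
    logging.info(
        "Processing feed %s (execution %s) from %s",
        payload["feed_stable_id"],
        payload["execution_id"],
        payload["producer_url"],
    )
    # dataset_id / dataset_hash arrive as None for a freshly created feed,
    # which signals that there is no prior dataset to compare against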

functions-python/operations_api/requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ uvloop==0.19.0
 
 # Additional packages
 google-cloud-logging==3.10.0
+google-cloud-pubsub
 functions-framework==3.*
 SQLAlchemy==2.0.23
 geoalchemy2==0.14.7

functions-python/operations_api/src/feeds_operations/impl/feeds_operations_impl.py

Lines changed: 7 additions & 0 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import JSONResponse
 
@@ -47,13 +48,15 @@
 from feeds_gen.models.update_request_gtfs_rt_feed import (
     UpdateRequestGtfsRtFeed,
 )
+from middleware.request_context_oauth2 import get_request_context
 from shared.database.database import with_db_session, refresh_materialized_view
 from shared.database_gen.sqlacodegen_models import (
     Gtfsfeed,
     t_feedsearch,
     Feed,
     Gtfsrealtimefeed,
 )
+from shared.helpers.pub_sub import get_execution_id, trigger_dataset_download
 from shared.helpers.query_helper import (
     query_feed_by_stable_id,
     get_feeds_query,
@@ -365,6 +368,10 @@ async def create_gtfs_feed(
         db_session.add(new_feed)
         db_session.commit()
         created_feed = db_session.get(Gtfsfeed, new_feed.id)
+        trigger_dataset_download(
+            created_feed,
+            get_execution_id(get_request_context(), "feed-created-process"),
+        )
         logging.info("Created new GTFS feed with ID: %s", new_feed.stable_id)
         payload = OperationGtfsFeedImpl.from_orm(created_feed).model_dump()
         return JSONResponse(status_code=201, content=jsonable_encoder(payload))
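
Since create_gtfs_feed now derives its execution ID from the request context, it is worth seeing how get_execution_id (diffed above) resolves it: an X-Cloud-Trace-Context header wins, then a trace_id attribute, then a random UUID. A hypothetical illustration (FakeRequest is not part of the codebase):

class FakeRequest:
    headers = {"X-Cloud-Trace-Context": "105445aa7843bc8bf206b12000100000/1"}


get_execution_id(FakeRequest(), "feed-created-process")
# -> "feed-created-process-105445aa7843bc8bf206b12000100000/1"

get_execution_id(object(), "feed-created-process")
# no headers and no trace_id attribute -> "feed-created-process-<random uuid4>"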

functions-python/operations_api/tests/feeds_operations/impl/test_create_feeds_operations_impl_gtfs.py

Lines changed: 21 additions & 1 deletion

@@ -22,6 +22,7 @@
 )
 import json
 import uuid
+from unittest.mock import patch
 
 
 @pytest.fixture
@@ -61,7 +62,8 @@ def db_session():
 
 
 @pytest.mark.asyncio
-async def test_create_gtfs_feed_success(db_session):
+@patch("feeds_operations.impl.feeds_operations_impl.trigger_dataset_download")
+async def test_create_gtfs_feed_success(mock_trigger_download, db_session):
     api = OperationsApiImpl()
     unique_url = f"https://new-feed.example.com/{uuid.uuid4()}"
     request = OperationCreateRequestGtfsFeed(
@@ -105,6 +107,24 @@ async def test_create_gtfs_feed_success(db_session):
         assert created.data_type == "gtfs"
         assert created.provider == "New Provider"
         assert created.operational_status == "wip"
+
+        # Assert trigger_dataset_download was called exactly once with the expected arguments
+        assert mock_trigger_download.call_count == 1
+        args, kwargs = mock_trigger_download.call_args
+        assert len(args) == 2  # positional args are (feed, execution_id)
+        feed, execution_id = args
+
+        # Validate the feed handed to the publisher
+        # (the helper receives the ORM object directly, not a wrapped message list,
+        #  so assert on its attributes)
+        assert feed.producer_url == unique_url
+        assert feed.stable_id == payload["stable_id"]
+        assert feed.id == payload["id"]
+        assert feed.authentication_type == "0"
+        assert feed.authentication_info_url is None
+        assert feed.api_key_parameter_name is None
+        # The execution ID is non-deterministic but must start with the expected prefix
+        assert execution_id.startswith("feed-created-process-")
     finally:
        # Cleanup to avoid impacting other tests
        stable_id = payload.get("stable_id") if isinstance(payload, dict) else None
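
One detail worth calling out in the signature above: when unittest.mock.patch is applied as a decorator, the mock is injected as the first positional argument, ahead of pytest fixtures such as db_session, which is why the parameters read (mock_trigger_download, db_session) rather than the other way around.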

functions-python/tasks_executor/src/tasks/data_import/data_import_utils.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

functions-python/tasks_executor/src/tasks/data_import/import_jbda_feeds.py

Lines changed: 3 additions & 4 deletions

@@ -39,8 +39,8 @@
     Externalid,
     Officialstatushistory,
 )
-from tasks.data_import.data_import_utils import trigger_dataset_download
-from google.cloud import pubsub_v1
+
+from shared.helpers.pub_sub import trigger_dataset_download
 
 T = TypeVar("T", bound="Feed")
 
@@ -648,9 +648,8 @@ def commit_changes(
         logger.info("Commit after processing items (count=%d)", total_processed)
         db_session.commit()
         execution_id = str(uuid.uuid4())
-        publisher = pubsub_v1.PublisherClient()
         for feed in feeds_to_publish:
-            trigger_dataset_download(feed, execution_id, publisher)
+            trigger_dataset_download(feed, execution_id)
     except IntegrityError:
         db_session.rollback()
         logger.exception("Commit failed with IntegrityError; rolled back")
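
A design trade-off in this refactor: the old loop reused one PublisherClient across all feeds, whereas the shared trigger_dataset_download obtains a client via get_pubsub_client() on every call. If client construction ever shows up in profiles, one possible refinement (not part of this commit, names hypothetical beyond the helper itself) would be to memoize the factory:

import functools

from google.cloud import pubsub_v1


@functools.lru_cache(maxsize=1)
def get_pubsub_client() -> pubsub_v1.PublisherClient:
    # Hypothetical cached variant of the helper in shared/helpers/pub_sub.py;
    # every caller then shares a single client and its batching settings.
    return pubsub_v1.PublisherClient()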

functions-python/tasks_executor/tests/tasks/data_import/test_jbda_import.py

Lines changed: 23 additions & 21 deletions

@@ -1,8 +1,7 @@
 import os
-import json
 import unittest
 from typing import Any, Dict, List
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 
 from sqlalchemy.orm import Session
 
@@ -205,8 +204,6 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
 class TestImportJBDA(unittest.TestCase):
     @with_db_session(db_url=default_db_url)
     def test_import_creates_gtfs_rt_and_related_links(self, db_session: Session):
-        fake_pub = _FakePublisher()
-
         # The importer will call HEAD on these URLs for org1/feed1
         base = (
             "https://api.gtfs-data.jp/v2/organizations/org1/feeds/feed1/files/feed.zip"
@@ -221,6 +218,8 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
             # fail for anything else (e.g., feed3 current)
             return _FakeResponse(status=404)
 
+        # Patch requests.Session and HEAD; the old pubsub mocks are replaced by a trigger_dataset_download mock
+        mock_trigger = MagicMock()
         with patch(
             "tasks.data_import.import_jbda_feeds.requests.Session",
             return_value=_FakeSessionOK(),
@@ -230,12 +229,8 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
         ), patch(
             "tasks.data_import.import_jbda_feeds.REQUEST_TIMEOUT_S", 0.01
         ), patch(
-            "tasks.data_import.import_jbda_feeds.pubsub_v1.PublisherClient",
-            return_value=fake_pub,
-        ), patch(
-            "tasks.data_import.data_import_utils.PROJECT_ID", "test-project"
-        ), patch(
-            "tasks.data_import.data_import_utils.DATASET_BATCH_TOPIC", "dataset-batch"
+            "tasks.data_import.import_jbda_feeds.trigger_dataset_download",
+            mock_trigger,
         ), patch.dict(
             os.environ, {"COMMIT_BATCH_SIZE": "1"}, clear=False
         ):
@@ -262,6 +257,9 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
             .first()
         )
         self.assertIsNotNone(sched)
+        # ensure the instance is attached to this test session before accessing relationships
+        sched = db_session.merge(sched)
+
         self.assertEqual(sched.feed_name, "Feed One")
         # producer_url now points to the verified JBDA URL (HEAD-checked)
         self.assertEqual(sched.producer_url, url_current)
@@ -273,7 +271,7 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
         next1 = next(link for link in links if link.code == "jbda-next_1")
         self.assertEqual(next1.url, url_next1)
 
-        # RT feeds + entity types + back-links
+        # RT feeds + entity types & back-links
         tu = (
             db_session.query(Gtfsrealtimefeed)
             .filter(Gtfsrealtimefeed.stable_id == "jbda-org1-feed1-tu")
@@ -284,6 +282,11 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
             .filter(Gtfsrealtimefeed.stable_id == "jbda-org1-feed1-vp")
             .first()
         )
+
+        # merge to ensure attached before accessing lazy attrs
+        tu = db_session.merge(tu)
+        vp = db_session.merge(vp)
+
         self.assertIsNotNone(tu)
         self.assertIsNotNone(vp)
         self.assertEqual(len(tu.entitytypes), 1)
@@ -295,16 +298,15 @@ def _head_side_effect(url, allow_redirects=True, timeout=15):
         self.assertEqual(tu.producer_url, "https://rt.example/one/tu.pb")
         self.assertEqual(vp.producer_url, "https://rt.example/one/vp.pb")
 
-        # Pub/Sub was called exactly once (only 1 new GTFS feed)
-        self.assertEqual(len(fake_pub.published), 1)
-        topic_path, data_bytes = fake_pub.published[0]
-        self.assertEqual(topic_path, "projects/test-project/topics/dataset-batch")
-
-        payload = json.loads(data_bytes.decode("utf-8"))
-        self.assertEqual(payload["feed_stable_id"], "jbda-org1-feed1")
-        self.assertEqual(payload["producer_url"], url_current)
-        self.assertIsNone(payload["dataset_id"])
-        self.assertIsNone(payload["dataset_hash"])
+        # trigger_dataset_download should have been called exactly once (only 1 new GTFS feed)
+        mock_trigger.assert_called_once()
+        called_args = mock_trigger.call_args[0]
+        detached_feed = called_args[0]
+        merged_feed = db_session.merge(detached_feed)
+        # first positional arg is the feed (merged above because the instance is detached)
+        self.assertEqual(getattr(merged_feed, "stable_id", None), "jbda-org1-feed1")
+        # second arg should be an execution id (string)
+        self.assertIsInstance(called_args[1], str)
 
     @with_db_session(db_url=default_db_url)
     def test_import_http_failure_graceful(self, db_session: Session):
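
The db_session.merge() calls added throughout this test address SQLAlchemy's detached-instance problem: objects loaded in a different session (here, inside the importer) cannot lazy-load relationships such as entitytypes once that session is gone. A minimal sketch of the pattern, assuming the relationship is lazy-loaded:

# `feed` was loaded in another session (or captured by a mock) and is now detached
feed = db_session.merge(feed)  # copy its state onto an instance attached to this session
feed.entitytypes               # lazy loading works again; on the detached instance this
                               # would typically raise DetachedInstanceError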

infra/functions-python/main.tf

Lines changed: 1 addition & 0 deletions

@@ -723,6 +723,7 @@ resource "google_cloudfunctions2_function" "operations_api" {
       PROJECT_ID          = var.project_id
       PYTHONNODEBUGRANGES = 0
       GOOGLE_CLIENT_ID    = var.operations_oauth2_client_id
+      DATASET_PROCESSING_TOPIC_NAME = "datasets-batch-topic-${var.environment}"
     }
     available_memory = local.function_operations_api_config.memory
     timeout_seconds  = local.function_operations_api_config.timeout
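
One caveat worth noting: pub_sub.py reads DATASET_PROCESSING_TOPIC_NAME at import time, and the value is baked into the default of trigger_dataset_download's topic_name parameter, so the variable must exist in the function's environment before the module is imported. A sketch of the pitfall for local runs (the topic value is illustrative):

import os

# Must be set before shared.helpers.pub_sub is imported; setting it afterwards
# has no effect on the already-bound default for topic_name.
os.environ["DATASET_PROCESSING_TOPIC_NAME"] = "datasets-batch-topic-dev"

from shared.helpers import pub_sub  # DATASET_BATCH_TOPIC is resolved here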
