Skip to content

Commit 7deaacd

Browse files
committed
reuse common trigger_download_dataset function
1 parent f1ac75e commit 7deaacd

File tree

3 files changed

+37
-48
lines changed

3 files changed

+37
-48
lines changed

functions-python/helpers/pub_sub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from google.cloud.pubsub_v1 import PublisherClient
2424
from google.cloud.pubsub_v1.publisher.futures import Future
2525

26-
from shared.database_gen.sqlacodegen_models import Feed
26+
from shared.database_gen.sqlacodegen_models import Feed, Gtfsfeed
2727

2828
PROJECT_ID = os.getenv("PROJECT_ID")
2929
DATASET_BATCH_TOPIC = os.getenv("DATASET_PROCESSING_TOPIC_NAME")
@@ -73,7 +73,7 @@ def publish_messages(data: List[Dict], project_id, topic_name) -> None:
7373

7474

7575
def trigger_dataset_download(
76-
feed: Feed,
76+
feed: Feed | Gtfsfeed,
7777
execution_id: str,
7878
topic_name: str = DATASET_BATCH_TOPIC,
7979
) -> None:

functions-python/operations_api/src/feeds_operations/impl/feeds_operations_impl.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
import os
1716

1817
from fastapi.encoders import jsonable_encoder
1918
from fastapi.responses import JSONResponse
@@ -57,7 +56,7 @@
5756
Feed,
5857
Gtfsrealtimefeed,
5958
)
60-
from shared.helpers.pub_sub import get_execution_id, publish_messages
59+
from shared.helpers.pub_sub import get_execution_id, trigger_dataset_download
6160
from shared.helpers.query_helper import (
6261
query_feed_by_stable_id,
6362
get_feeds_query,
@@ -74,9 +73,6 @@
7473
from .models.operation_gtfs_rt_feed_impl import OperationGtfsRtFeedImpl
7574
from .request_validator import validate_request
7675

77-
pubsub_topic_name = os.getenv("DATASET_PROCESSING_TOPIC_NAME")
78-
project_id = os.getenv("PROJECT_ID")
79-
8076

8177
class OperationsApiImpl(BaseOperationsApi):
8278
"""Implementation of the operations API."""
@@ -112,24 +108,24 @@ def assert_no_existing_feed_url(producer_url: str, db_session: Session):
112108
detail=message,
113109
)
114110

115-
@staticmethod
116-
def send_feed_process_event(feed: type[Gtfsfeed] | None, request=None):
117-
"""Send a message to Pub/Sub to process the feed."""
118-
message_payload = {
119-
"execution_id": get_execution_id(
120-
get_request_context(), "feed-created-process"
121-
),
122-
"producer_url": feed.producer_url,
123-
"feed_stable_id": feed.stable_id,
124-
"feed_id": feed.id,
125-
"dataset_stable_id": None,
126-
"dataset_hash": None,
127-
"authentication_type": feed.authentication_type,
128-
"authentication_info_url": feed.authentication_info_url,
129-
"api_key_parameter_name": feed.api_key_parameter_name,
130-
}
131-
publish_messages([message_payload], project_id, pubsub_topic_name)
132-
logging.debug("Sent feed process event")
111+
# @staticmethod
112+
# def send_feed_process_event(feed: type[Gtfsfeed] | None, request=None):
113+
# """Send a message to Pub/Sub to process the feed."""
114+
# message_payload = {
115+
# "execution_id": get_execution_id(
116+
# get_request_context(), "feed-created-process"
117+
# ),
118+
# "producer_url": feed.producer_url,
119+
# "feed_stable_id": feed.stable_id,
120+
# "feed_id": feed.id,
121+
# "dataset_stable_id": None,
122+
# "dataset_hash": None,
123+
# "authentication_type": feed.authentication_type,
124+
# "authentication_info_url": feed.authentication_info_url,
125+
# "api_key_parameter_name": feed.api_key_parameter_name,
126+
# }
127+
# publish_messages([message_payload], project_id, pubsub_topic_name)
128+
# logging.debug("Sent feed process event")
133129

134130
@with_db_session
135131
async def get_feeds(
@@ -391,7 +387,10 @@ async def create_gtfs_feed(
391387
db_session.add(new_feed)
392388
db_session.commit()
393389
created_feed = db_session.get(Gtfsfeed, new_feed.id)
394-
self.send_feed_process_event(created_feed)
390+
trigger_dataset_download(
391+
created_feed,
392+
get_execution_id(get_request_context(), "feed-created-process"),
393+
)
395394
logging.info("Created new GTFS feed with ID: %s", new_feed.stable_id)
396395
payload = OperationGtfsFeedImpl.from_orm(created_feed).model_dump()
397396
return JSONResponse(status_code=201, content=jsonable_encoder(payload))

functions-python/operations_api/tests/feeds_operations/impl/test_create_feeds_operations_impl_gtfs.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
import uuid
2525
from unittest.mock import patch
2626

27-
# For asserting module-level topic/project used on publish
28-
from feeds_operations.impl import feeds_operations_impl as ops_module
29-
3027

3128
@pytest.fixture
3229
def update_request_gtfs_feed():
@@ -65,7 +62,7 @@ def db_session():
6562

6663

6764
@pytest.mark.asyncio
68-
@patch("feeds_operations.impl.feeds_operations_impl.publish_messages")
65+
@patch("feeds_operations.impl.feeds_operations_impl.trigger_dataset_download")
6966
async def test_create_gtfs_feed_success(mock_publish_messages, db_session):
7067
api = OperationsApiImpl()
7168
unique_url = f"https://new-feed.example.com/{uuid.uuid4()}"
@@ -114,27 +111,20 @@ async def test_create_gtfs_feed_success(mock_publish_messages, db_session):
114111
# Assert publish_messages was called exactly once with expected payload
115112
assert mock_publish_messages.call_count == 1
116113
args, kwargs = mock_publish_messages.call_args
117-
assert len(args) == 3 # data list, project_id, topic_name
118-
data_list, project_id, topic_name = args
119-
120-
# Project/topic should be whatever module-level variables resolve to (may be None in tests)
121-
assert project_id == ops_module.project_id
122-
assert topic_name == ops_module.pubsub_topic_name
114+
assert len(args) == 2 # data list, project_id, topic_name
115+
feed, execution_id = args
123116

124117
# Validate message payload shape and values
125-
assert isinstance(data_list, list) and len(data_list) == 1
126-
message = data_list[0]
127-
assert message["producer_url"] == unique_url
128-
assert message["feed_stable_id"] == payload["stable_id"]
129-
assert message["feed_id"] == payload["id"]
130-
assert message["dataset_stable_id"] is None
131-
assert message["dataset_hash"] is None
132-
assert message["authentication_type"] == "0"
133-
assert message["authentication_info_url"] is None
134-
assert message["api_key_parameter_name"] is None
118+
# assert isinstance(feed, list) and len(data_list) == 1
119+
# message = feed[0]
120+
assert feed.producer_url == unique_url
121+
assert feed.stable_id == payload["stable_id"]
122+
assert feed.id == payload["id"]
123+
assert feed.authentication_type == "0"
124+
assert feed.authentication_info_url is None
125+
assert feed.api_key_parameter_name is None
135126
# Non-deterministic but must start with expected prefix
136-
assert isinstance(message.get("execution_id"), str)
137-
assert message["execution_id"].startswith("feed-created-process-")
127+
assert execution_id.startswith("feed-created-process-")
138128
finally:
139129
# Cleanup to avoid impacting other tests
140130
stable_id = payload.get("stable_id") if isinstance(payload, dict) else None

0 commit comments

Comments
 (0)