diff --git a/functions-python/batch_process_dataset/src/main.py b/functions-python/batch_process_dataset/src/main.py
index 68ac63f00..67b1d5f2c 100644
--- a/functions-python/batch_process_dataset/src/main.py
+++ b/functions-python/batch_process_dataset/src/main.py
@@ -323,7 +323,8 @@ def process_dataset(cloud_event: CloudEvent):
     dataset_file: DatasetFile = None
     error_message = None
     try:
-        # Extract data from message
+        # Extract data from message
+        logging.info(f"Cloud Event: {cloud_event}")
         data = base64.b64decode(cloud_event.data["message"]["data"]).decode()
         json_payload = json.loads(data)
         logging.info(
@@ -331,6 +332,12 @@
         )
         stable_id = json_payload["feed_stable_id"]
         execution_id = json_payload["execution_id"]
+    except Exception as e:
+        error_message = f"[{stable_id}] Error parsing message: [{e}]"
+        logging.error(error_message)
+        logging.error(f"Function completed with error: {error_message}")
+        return
+    try:
         trace_service = DatasetTraceService()
         trace = trace_service.get_by_execution_and_stable_ids(execution_id, stable_id)

diff --git a/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py b/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py
index 13f6e6fc1..9ceb23f8c 100644
--- a/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py
+++ b/functions-python/batch_process_dataset/tests/test_batch_process_dataset_main.py
@@ -394,7 +394,7 @@ def test_process_dataset_normal_execution(
     @patch("batch_process_dataset.src.main.Logger")
     @patch("batch_process_dataset.src.main.DatasetTraceService")
     @patch("batch_process_dataset.src.main.DatasetProcessor")
-    def test_process_dataset_exception(
+    def test_process_dataset_exception_caught(
         self, mock_dataset_processor, mock_dataset_trace, _
     ):
         db_url = os.getenv("TEST_FEEDS_DATABASE_URL", default=default_db_url)
@@ -413,11 +413,7 @@ def test_process_dataset_exception(
         mock_dataset_trace.get_by_execution_and_stable_ids.return_value = 0

         # Call the function
-        try:
-            process_dataset(cloud_event)
-            assert False
-        except AttributeError:
-            assert True
+        process_dataset(cloud_event)

     @patch("batch_process_dataset.src.main.Logger")
     @patch("batch_process_dataset.src.main.DatasetTraceService")
diff --git a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py
index 14298b48b..797d46446 100644
--- a/functions-python/extract_location/src/reverse_geolocation/location_extractor.py
+++ b/functions-python/extract_location/src/reverse_geolocation/location_extractor.py
@@ -207,6 +207,9 @@ def update_location(
         .filter(Gtfsfeed.stable_id == dataset.feed.stable_id)
         .one_or_none()
     )
+    if gtfs_feed is None:
+        logging.error(f"Feed {dataset.feed.stable_id} not found as a GTFS feed.")
+        raise Exception(f"Feed {dataset.feed.stable_id} not found as a GTFS feed.")

     for gtfs_rt_feed in gtfs_feed.gtfs_rt_feeds:
         logging.info(f"Updating GTFS-RT feed with stable ID {gtfs_rt_feed.stable_id}")
diff --git a/functions-python/feed_sync_dispatcher_transitland/src/main.py b/functions-python/feed_sync_dispatcher_transitland/src/main.py
index 9f718182c..1c69738c2 100644
--- a/functions-python/feed_sync_dispatcher_transitland/src/main.py
+++ b/functions-python/feed_sync_dispatcher_transitland/src/main.py
@@ -14,13 +14,11 @@
 # limitations under the License.
# -import json import logging import os import random import time -from dataclasses import dataclass, asdict -from typing import Optional, List +from typing import Optional import functions_framework import pandas as pd @@ -29,14 +27,15 @@ from requests.exceptions import RequestException, HTTPError from sqlalchemy.orm import Session -from database_gen.sqlacodegen_models import Gtfsfeed +from database_gen.sqlacodegen_models import Feed from helpers.feed_sync.feed_sync_common import FeedSyncProcessor, FeedSyncPayload from helpers.feed_sync.feed_sync_dispatcher import feed_sync_dispatcher +from helpers.feed_sync.models import TransitFeedSyncPayload from helpers.logger import Logger from helpers.pub_sub import get_pubsub_client, get_execution_id +from typing import Tuple, List +from collections import defaultdict -# Logging configuration -logging.basicConfig(level=logging.INFO) # Environment variables PUBSUB_TOPIC_NAME = os.getenv("PUBSUB_TOPIC_NAME") @@ -45,68 +44,66 @@ TRANSITLAND_API_KEY = os.getenv("TRANSITLAND_API_KEY") TRANSITLAND_OPERATOR_URL = os.getenv("TRANSITLAND_OPERATOR_URL") TRANSITLAND_FEED_URL = os.getenv("TRANSITLAND_FEED_URL") -spec = ["gtfs", "gtfs-rt"] # session instance to reuse connections session = requests.Session() -@dataclass -class TransitFeedSyncPayload: +def process_feed_urls(feed: dict, urls_in_db: List[str]) -> Tuple[List[str], List[str]]: """ - Data class for transit feed sync payloads. + Extracts the valid feed URLs and their corresponding entity types from the feed dictionary. If the same URL + corresponds to multiple entity types, the types are concatenated with a comma. """ + url_keys_to_types = { + "static_current": "", + "realtime_alerts": "sa", + "realtime_trip_updates": "tu", + "realtime_vehicle_positions": "vp", + } - external_id: str - feed_id: str - feed_url: Optional[str] = None - execution_id: Optional[str] = None - spec: Optional[str] = None - auth_info_url: Optional[str] = None - auth_param_name: Optional[str] = None - type: Optional[str] = None - operator_name: Optional[str] = None - country: Optional[str] = None - state_province: Optional[str] = None - city_name: Optional[str] = None - source: Optional[str] = None - payload_type: Optional[str] = None + urls = feed.get("urls", {}) + url_to_entity_types = defaultdict(list) - def to_dict(self): - return asdict(self) + for key, entity_type in url_keys_to_types.items(): + if (url := urls.get(key)) and (url not in urls_in_db): + if entity_type: + logging.info(f"Found URL for entity type: {entity_type}") + url_to_entity_types[url].append(entity_type) - def to_json(self): - return json.dumps(self.to_dict()) + valid_urls = [] + entity_types = [] + for url, types in url_to_entity_types.items(): + valid_urls.append(url) + logging.info(f"URL = {url}, Entity types = {types}") + entity_types.append(",".join(types)) -class TransitFeedSyncProcessor(FeedSyncProcessor): - def check_url_status(self, url: str) -> bool: - """ - Checks if a URL returns a valid response status code. - """ - try: - logging.info(f"Checking URL: {url}") - if url is None or len(url) == 0: - logging.warning("URL is empty. 
Skipping check.") - return False - response = requests.head(url, timeout=25) - logging.info(f"URL status code: {response.status_code}") - return response.status_code < 400 - except requests.RequestException as e: - logging.warning(f"Failed to reach {url}: {e}") - return False + return valid_urls, entity_types + +class TransitFeedSyncProcessor(FeedSyncProcessor): def process_sync( - self, db_session: Optional[Session] = None, execution_id: Optional[str] = None + self, db_session: Session, execution_id: Optional[str] = None ) -> List[FeedSyncPayload]: """ Process data synchronously to fetch, extract, combine, filter and prepare payloads for publishing to a queue based on conditions related to the data retrieved from TransitLand API. """ - feeds_data = self.get_data( - TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, spec, session + feeds_data_gtfs_rt = self.get_data( + TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, "gtfs_rt", session + ) + logging.info( + "Fetched %s GTFS-RT feeds from TransitLand API", + len(feeds_data_gtfs_rt["feeds"]), + ) + + feeds_data_gtfs = self.get_data( + TRANSITLAND_FEED_URL, TRANSITLAND_API_KEY, "gtfs", session + ) + logging.info( + "Fetched %s GTFS feeds from TransitLand API", len(feeds_data_gtfs["feeds"]) ) - logging.info("Fetched %s feeds from TransitLand API", len(feeds_data["feeds"])) + feeds_data = feeds_data_gtfs["feeds"] + feeds_data_gtfs_rt["feeds"] operators_data = self.get_data( TRANSITLAND_OPERATOR_URL, TRANSITLAND_API_KEY, session=session @@ -115,8 +112,10 @@ def process_sync( "Fetched %s operators from TransitLand API", len(operators_data["operators"]), ) - - feeds = self.extract_feeds_data(feeds_data) + all_urls = set( + [element[0] for element in db_session.query(Feed.producer_url).all()] + ) + feeds = self.extract_feeds_data(feeds_data, all_urls) operators = self.extract_operators_data(operators_data) # Converts operators and feeds to pandas DataFrames @@ -135,16 +134,18 @@ def process_sync( # Filtered out rows where 'feed_url' is missing combined_df = combined_df[combined_df["feed_url"].notna()] - # Group by 'feed_id' and concatenate 'operator_name' while keeping first values of other columns + # Group by 'stable_id' and concatenate 'operator_name' while keeping first values of other columns df_grouped = ( - combined_df.groupby("feed_id") + combined_df.groupby("stable_id") .agg( { "operator_name": lambda x: ", ".join(x), "feeds_onestop_id": "first", + "feed_id": "first", "feed_url": "first", "operator_feed_id": "first", "spec": "first", + "entity_types": "first", "country": "first", "state_province": "first", "city_name": "first", @@ -173,11 +174,6 @@ def process_sync( filtered_df = filtered_df.drop_duplicates( subset=["feed_url"] ) # Drop duplicates - filtered_df = filtered_df[filtered_df["feed_url"].apply(self.check_url_status)] - logging.info( - "Filtered out %s feeds with invalid URLs", - len(df_grouped) - len(filtered_df), - ) # Convert filtered DataFrame to dictionary format combined_data = filtered_df.to_dict(orient="records") @@ -187,7 +183,7 @@ def process_sync( for data in combined_data: external_id = data["feeds_onestop_id"] feed_url = data["feed_url"] - source = "TLD" + source = "tld" if not self.check_external_id(db_session, external_id, source): payload_type = "new" @@ -201,6 +197,8 @@ def process_sync( # prepare payload payload = TransitFeedSyncPayload( external_id=external_id, + stable_id=data["stable_id"], + entity_types=data["entity_types"], feed_id=data["feed_id"], execution_id=execution_id, feed_url=feed_url, @@ -212,7 +210,7 @@ def 
process_sync( country=data["country"], state_province=data["state_province"], city_name=data["city_name"], - source="TLD", + source="tld", payload_type=payload_type, ) payloads.append(FeedSyncPayload(external_id=external_id, payload=payload)) @@ -277,25 +275,39 @@ def get_data( logging.info("Finished fetching data.") return all_data - def extract_feeds_data(self, feeds_data: dict) -> List[dict]: + def extract_feeds_data(self, feeds_data: dict, urls_in_db: List[str]) -> List[dict]: """ This function extracts relevant data from the Transitland feeds endpoint containing feeds information. Returns a list of dictionaries representing each feed. """ feeds = [] - for feed in feeds_data["feeds"]: - feed_url = feed["urls"].get("static_current") - feeds.append( - { - "feed_id": feed["id"], - "feed_url": feed_url, - "spec": feed["spec"].lower(), - "feeds_onestop_id": feed["onestop_id"], - "auth_info_url": feed["authorization"].get("info_url"), - "auth_param_name": feed["authorization"].get("param_name"), - "type": feed["authorization"].get("type"), - } - ) + for feed in feeds_data: + feed_urls, entity_types = process_feed_urls(feed, urls_in_db) + logging.info("Feed %s has %s valid URL(s)", feed["id"], len(feed_urls)) + logging.info("Feed %s entity types: %s", feed["id"], entity_types) + if len(feed_urls) == 0: + logging.warning("Feed URL not found for feed %s", feed["id"]) + continue + + for feed_url, entity_types in zip(feed_urls, entity_types): + if entity_types is not None and len(entity_types) > 0: + stable_id = f"{feed['id']}-{entity_types.replace(',', '_')}" + else: + stable_id = feed["id"] + logging.info("Stable ID: %s", stable_id) + feeds.append( + { + "feed_id": feed["id"], + "stable_id": stable_id, + "feed_url": feed_url, + "entity_types": entity_types if len(entity_types) > 0 else None, + "spec": feed["spec"].lower(), + "feeds_onestop_id": feed["onestop_id"], + "auth_info_url": feed["authorization"].get("info_url"), + "auth_param_name": feed["authorization"].get("param_name"), + "type": feed["authorization"].get("type"), + } + ) return feeds def extract_operators_data(self, operators_data: dict) -> List[dict]: @@ -309,16 +321,15 @@ def extract_operators_data(self, operators_data: dict) -> List[dict]: places = operator["agencies"][0]["places"] place = places[1] if len(places) > 1 else places[0] - operator_data = { - "operator_name": operator.get("name"), - "operator_feed_id": operator["feeds"][0]["id"] - if operator.get("feeds") - else None, - "country": place.get("adm0_name") if place else None, - "state_province": place.get("adm1_name") if place else None, - "city_name": place.get("city_name") if place else None, - } - operators.append(operator_data) + for related_feed in operator.get("feeds", []): + operator_data = { + "operator_name": operator.get("name"), + "operator_feed_id": related_feed["id"], + "country": place.get("adm0_name") if place else None, + "state_province": place.get("adm1_name") if place else None, + "city_name": place.get("city_name") if place else None, + } + operators.append(operator_data) return operators def check_external_id( @@ -328,12 +339,12 @@ def check_external_id( Checks if the external_id exists in the public.externalid table for the given source. 
:param db_session: SQLAlchemy session :param external_id: The external_id (feeds_onestop_id) to check - :param source: The source to filter by (e.g., 'TLD' for TransitLand) + :param source: The source to filter by (e.g., 'tld' for TransitLand) :return: True if the feed exists, False otherwise """ results = ( - db_session.query(Gtfsfeed) - .filter(Gtfsfeed.externalids.any(associated_id=external_id)) + db_session.query(Feed) + .filter(Feed.externalids.any(associated_id=external_id)) .all() ) return results is not None and len(results) > 0 @@ -345,12 +356,12 @@ def get_mbd_feed_url( Retrieves the feed_url from the public.feed table in the mbd for the given external_id. :param db_session: SQLAlchemy session :param external_id: The external_id (feeds_onestop_id) from TransitLand - :param source: The source to filter by (e.g., 'TLD' for TransitLand) + :param source: The source to filter by (e.g., 'tld' for TransitLand) :return: feed_url in mbd if exists, otherwise None """ results = ( - db_session.query(Gtfsfeed) - .filter(Gtfsfeed.externalids.any(associated_id=external_id)) + db_session.query(Feed) + .filter(Feed.externalids.any(associated_id=external_id)) .all() ) return results[0].producer_url if results else None diff --git a/functions-python/feed_sync_dispatcher_transitland/tests/test_feed_sync.py b/functions-python/feed_sync_dispatcher_transitland/tests/test_feed_sync.py index 04ec418aa..a7de96e2d 100644 --- a/functions-python/feed_sync_dispatcher_transitland/tests/test_feed_sync.py +++ b/functions-python/feed_sync_dispatcher_transitland/tests/test_feed_sync.py @@ -6,11 +6,12 @@ from database_gen.sqlacodegen_models import Gtfsfeed from feed_sync_dispatcher_transitland.src.main import ( TransitFeedSyncProcessor, - FeedSyncPayload, ) import pandas as pd from requests.exceptions import HTTPError +from helpers.feed_sync.feed_sync_common import FeedSyncPayload + @pytest.fixture def processor(): @@ -59,18 +60,16 @@ def test_get_data_rate_limit(mock_get, processor): def test_extract_feeds_data(processor): - feeds_data = { - "feeds": [ - { - "id": "feed1", - "urls": {"static_current": "http://example.com/feed1"}, - "spec": "gtfs", - "onestop_id": "onestop1", - "authorization": {}, - } - ] - } - result = processor.extract_feeds_data(feeds_data) + feeds_data = [ + { + "id": "feed1", + "urls": {"static_current": "http://example.com"}, + "spec": "gtfs", + "onestop_id": "onestop1", + "authorization": {}, + } + ] + result = processor.extract_feeds_data(feeds_data, []) assert len(result) == 1 assert result[0]["feed_id"] == "feed1" @@ -116,135 +115,87 @@ def test_get_mbd_feed_url(processor): def test_process_sync_new_feed(processor): mock_db_session = Mock(spec=DBSession) - feeds_data = { - "feeds": [ - { - "id": "feed1", - "urls": {"static_current": "http://example.com/feed1"}, - "spec": "gtfs", - "onestop_id": "onestop1", - "authorization": {}, - } - ], - "operators": [], - } - operators_data = { - "operators": [ - { - "name": "Operator 1", - "feeds": [{"id": "feed1"}], - "agencies": [{"places": [{"adm0_name": "USA"}]}], - } - ], - "feeds": [], - } - - processor.get_data = Mock(side_effect=[feeds_data, operators_data]) - - processor.check_url_status = Mock(return_value=True) - - with patch.object(processor, "check_external_id", return_value=False): - payloads = processor.process_sync( - db_session=mock_db_session, execution_id="exec123" - ) - assert len(payloads) == 1 - assert payloads[0].payload.payload_type == "new" - assert payloads[0].payload.external_id == "onestop1" + 
mock_db_session.query.return_value.all.return_value = [] + feeds_data = [ + { + "id": "feed1", + "urls": {"static_current": "http://example.com"}, + "spec": "gtfs", + "onestop_id": "onestop1", + "authorization": {}, + } + ] + operators_data = [ + { + "name": "Operator 1", + "feeds": [{"id": "feed1"}], + "agencies": [{"places": [{"adm0_name": "USA"}]}], + } + ] + processor.get_data = Mock( + return_value={"feeds": feeds_data, "operators": operators_data} + ) + processor.check_external_id = Mock(return_value=False) + payloads = processor.process_sync(mock_db_session, "exec123") + assert len(payloads) == 1, "Expected 1 payload" + assert payloads[0].payload.payload_type == "new" def test_process_sync_updated_feed(processor): mock_db_session = Mock(spec=DBSession) - feeds_data = { - "feeds": [ - { - "id": "feed1", - "urls": {"static_current": "http://example.com/feed1_updated"}, - "spec": "gtfs", - "onestop_id": "onestop1", - "authorization": {}, - } - ], - "operators": [], - } - operators_data = { - "operators": [ - { - "name": "Operator 1", - "feeds": [{"id": "feed1"}], - "agencies": [{"places": [{"adm0_name": "USA"}]}], - } - ], - "feeds": [], - } - - processor.get_data = Mock(side_effect=[feeds_data, operators_data]) - - processor.check_url_status = Mock(return_value=True) - - processor.check_external_id = Mock(return_value=True) - - processor.get_mbd_feed_url = Mock(return_value="http://example.com/feed1") - - payloads = processor.process_sync( - db_session=mock_db_session, execution_id="exec123" + mock_db_session.query.return_value.all.return_value = [] + feeds_data = [ + { + "id": "feed1", + "urls": {"static_current": "http://example.com"}, + "spec": "gtfs", + "onestop_id": "onestop1", + "authorization": {}, + } + ] + operators_data = [ + { + "name": "Operator 1", + "feeds": [{"id": "feed1"}], + "agencies": [{"places": [{"adm0_name": "USA"}]}], + } + ] + processor.get_data = Mock( + return_value={"feeds": feeds_data, "operators": operators_data} ) - - assert len(payloads) == 1 + processor.check_external_id = Mock(return_value=True) + processor.get_mbd_feed_url = Mock(return_value="http://example-2.com") + payloads = processor.process_sync(mock_db_session, "exec123") + assert len(payloads) == 1, "Expected 1 payload" assert payloads[0].payload.payload_type == "update" - assert payloads[0].payload.external_id == "onestop1" -@patch("feed_sync_dispatcher_transitland.src.main.TransitFeedSyncProcessor.get_data") -def test_process_sync_unchanged_feed(mock_get_data, processor): +def test_process_sync_unchanged_feed(processor): mock_db_session = Mock(spec=DBSession) - feeds_data = { - "feeds": [ - { - "id": "feed1", - "urls": {"static_current": "http://example.com/feed1"}, - "spec": "gtfs", - "onestop_id": "onestop1", - "authorization": {}, - } - ], - "operators": [], - } - operators_data = { - "operators": [ - { - "name": "Operator 1", - "feeds": [{"id": "feed1"}], - "agencies": [{"places": [{"adm0_name": "USA"}]}], - } - ], - "feeds": [], - } - - processor.get_data = Mock(side_effect=[feeds_data, operators_data]) - processor.check_url_status = Mock(return_value=True) - processor.check_external_id = Mock(return_value=True) - processor.get_mbd_feed_url = Mock(return_value="http://example.com/feed1") - processor.get_mbd_feed_url = Mock(return_value="http://example.com/feed1") - payloads = processor.process_sync( - db_session=mock_db_session, execution_id="exec123" - ) - - assert len(payloads) == 0 - - processor.get_mbd_feed_url.assert_called_once_with( - mock_db_session, "onestop1", "TLD" + 
mock_db_session.query.return_value.all.return_value = [] + feeds_data = [ + { + "id": "feed1", + "urls": {"static_current": "http://example.com"}, + "spec": "gtfs", + "onestop_id": "onestop1", + "authorization": {}, + } + ] + operators_data = [ + { + "name": "Operator 1", + "feeds": [{"id": "feed1"}], + "agencies": [{"places": [{"adm0_name": "USA"}]}], + } + ] + processor.get_data = Mock( + return_value={"feeds": feeds_data, "operators": operators_data} ) - - -@patch("feed_sync_dispatcher_transitland.src.main.requests.head") -def test_check_url_status(mock_head, processor): - mock_head.return_value.status_code = 200 - result = processor.check_url_status("http://example.com") - assert result is True - - mock_head.return_value.status_code = 404 - result = processor.check_url_status("http://example.com") - assert result is False + processor.check_external_id = Mock(return_value=True) + processor.get_mbd_feed_url = Mock(return_value="http://example.com") + payloads = processor.process_sync(mock_db_session, "exec123") + assert len(payloads) == 0, "No payloads expected" def test_merge_and_filter_dataframes(processor): diff --git a/functions-python/feed_sync_process_transitland/.coveragerc b/functions-python/feed_sync_process_transitland/.coveragerc index c52988ffd..89dac199f 100644 --- a/functions-python/feed_sync_process_transitland/.coveragerc +++ b/functions-python/feed_sync_process_transitland/.coveragerc @@ -1,6 +1,7 @@ [run] omit = */test*/* + */database_gen/* */dataset_service/* */helpers/* diff --git a/functions-python/feed_sync_process_transitland/function_config.json b/functions-python/feed_sync_process_transitland/function_config.json index 088c8bd32..adddc6cfa 100644 --- a/functions-python/feed_sync_process_transitland/function_config.json +++ b/functions-python/feed_sync_process_transitland/function_config.json @@ -12,7 +12,7 @@ } ], "ingress_settings": "ALLOW_INTERNAL_AND_GCLB", - "max_instance_request_concurrency": 20, + "max_instance_request_concurrency": 1, "max_instance_count": 10, "min_instance_count": 0, "available_cpu": 1 diff --git a/functions-python/feed_sync_process_transitland/src/feed_processor_utils.py b/functions-python/feed_sync_process_transitland/src/feed_processor_utils.py new file mode 100644 index 000000000..056bc7840 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/src/feed_processor_utils.py @@ -0,0 +1,101 @@ +import logging +import uuid +from datetime import datetime +from typing import Tuple, Optional +from sqlalchemy.orm import Session +import requests +from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload +from database_gen.sqlacodegen_models import ( + Gtfsfeed, + Gtfsrealtimefeed, + Externalid, + Entitytype, + Feed, +) +from helpers.locations import create_or_get_location + + +def check_url_status(url: str) -> bool: + """Check if a URL is reachable.""" + try: + response = requests.head(url, timeout=10) + return response.status_code < 400 or response.status_code == 403 + except requests.RequestException: + logging.warning(f"Failed to reach URL: {url}") + return False + + +def get_feed_model(spec: str) -> Tuple[type, str]: + """Map feed specification to model and type.""" + spec_lower = spec.lower().replace("-", "_") + if spec_lower == "gtfs": + return Gtfsfeed, spec_lower + if spec_lower == "gtfs_rt": + return Gtfsrealtimefeed, spec_lower + raise ValueError(f"Invalid feed specification: {spec}") + + +def get_tlnd_authentication_type(auth_type: Optional[str]) -> str: + """Map TransitLand authentication type to 
database format.""" + if auth_type in (None, ""): + return "0" + if auth_type == "query_param": + return "1" + if auth_type == "header": + return "2" + raise ValueError(f"Invalid authentication type: {auth_type}") + + +def create_new_feed(session: Session, stable_id: str, payload: FeedPayload) -> Feed: + """Create a new feed and its dependencies.""" + feed_type, data_type = get_feed_model(payload.spec) + + # Create new feed + new_feed = feed_type( + id=str(uuid.uuid4()), + stable_id=stable_id, + producer_url=payload.feed_url, + data_type=data_type, + authentication_type=get_tlnd_authentication_type(payload.type), + authentication_info_url=payload.auth_info_url, + api_key_parameter_name=payload.auth_param_name, + status="active", + provider=payload.operator_name, + operational_status="wip", + created_at=datetime.now(), + ) + + # Add external ID relationship + external_id = Externalid( + feed_id=new_feed.id, + associated_id=payload.external_id, + source=payload.source, + ) + new_feed.externalids = [external_id] + + # Add entity types if applicable + if feed_type == Gtfsrealtimefeed and payload.entity_types: + entity_type_names = payload.entity_types.split(",") + for entity_name in entity_type_names: + entity = session.query(Entitytype).filter_by(name=entity_name).first() + if not entity: + entity = Entitytype(name=entity_name) + session.add(entity) + new_feed.entitytypes.append(entity) + + # Add location if provided + location = create_or_get_location( + session, + payload.country, + payload.state_province, + payload.city_name, + ) + if location: + new_feed.locations = [location] + logging.debug(f"Added location for feed {new_feed.id}") + + # Persist the new feed + session.add(new_feed) + session.flush() + logging.info(f"Created new feed with ID: {new_feed.id}") + return new_feed diff --git a/functions-python/feed_sync_process_transitland/src/main.py b/functions-python/feed_sync_process_transitland/src/main.py index 1a6a3b6c0..50160e105 100644 --- a/functions-python/feed_sync_process_transitland/src/main.py +++ b/functions-python/feed_sync_process_transitland/src/main.py @@ -13,58 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - import base64 import json import logging import os -import uuid -from typing import Optional, Tuple +from typing import Optional, List import functions_framework from google.cloud import pubsub_v1 -from sqlalchemy.orm import Session -from database_gen.sqlacodegen_models import Feed, Externalid, Redirectingid from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session -from helpers.database import start_db_session, close_db_session -from helpers.logger import Logger, StableIdFilter +from database_gen.sqlacodegen_models import Feed +from helpers.database import start_db_session, configure_polymorphic_mappers from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload -from helpers.locations import create_or_get_location - -# Configure logging -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" -) - -logger = logging.getLogger("feed_processor") -handler = logging.StreamHandler() -handler.setFormatter( - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") -) -logger.addHandler(handler) -logger.setLevel(logging.INFO) - -# Initialize GCP logger for cloud environment -Logger.init_logger() -gcp_logger = Logger("feed_processor").get_logger() - - -def log_message(level, message): - """Log messages to both local and GCP loggers""" - if level == "info": - logger.info(message) - gcp_logger.info(message) - elif level == "error": - logger.error(message) - gcp_logger.error(message) - elif level == "warning": - logger.warning(message) - gcp_logger.warning(message) - elif level == "debug": - logger.debug(message) - gcp_logger.debug(message) - +from helpers.logger import Logger +from .feed_processor_utils import check_url_status, create_new_feed # Environment variables PROJECT_ID = os.getenv("PROJECT_ID") @@ -73,404 +37,158 @@ def log_message(level, message): class FeedProcessor: - """Handles feed processing operations""" - def __init__(self, db_session: Session): self.session = db_session self.publisher = pubsub_v1.PublisherClient() + self.feed_stable_id: Optional[str] = None def process_feed(self, payload: FeedPayload) -> None: - """ - Processes feed idempotently based on database state - - Args: - payload (FeedPayload): The feed payload to process - """ - gcp_logger.addFilter(StableIdFilter(payload.external_id)) + """Process a feed based on its database state.""" try: - log_message( - "info", - f"Starting feed processing for external_id: {payload.external_id}", + logging.info( + f"Processing feed: external_id={payload.external_id}, feed_id={payload.feed_id}" ) + if not check_url_status(payload.feed_url): + logging.error(f"Feed URL not reachable: {payload.feed_url}. 
Skipping.") + return - # Check current state of feed in database - current_feed_id, current_url = self.get_current_feed_info( - payload.external_id, payload.source - ) + self.feed_stable_id = f"{payload.source}-{payload.stable_id}".lower() + current_feeds = self._get_current_feeds(payload.external_id, payload.source) - if current_feed_id is None: - log_message("info", "Processing new feed") - # If no existing feed_id found - check if URL exists in any feed - if self.check_feed_url_exists(payload.feed_url): - log_message("error", f"Feed URL already exists: {payload.feed_url}") - return - self.process_new_feed(payload) + if not current_feeds: + new_feed = self._process_new_feed_or_skip(payload) else: - # If Feed exists - check if URL has changed - if current_url != payload.feed_url: - log_message("info", "Processing feed update") - log_message( - "debug", - f"Found existing feed: {current_feed_id} with different URL", - ) - self.process_feed_update(payload, current_feed_id) - else: - log_message( - "error", - f"Feed already exists with same URL: {payload.external_id}", - ) - return + new_feed = self._process_existing_feed_refs(payload, current_feeds) self.session.commit() - log_message("debug", "Database transaction committed successfully") - - # Publish to dataset_batch_topic if not authenticated - if not payload.auth_info_url: - self.publish_to_batch_topic(payload) - + self._publish_to_batch_topic_if_needed(payload, new_feed) except SQLAlchemyError as e: - error_msg = ( - f"Database error processing feed {payload.external_id}: {str(e)}" - ) - log_message("error", error_msg) - self.session.rollback() - log_message("error", "Database transaction rolled back due to error") - raise - except Exception as e: - error_msg = f"Error processing feed {payload.external_id}: {str(e)}" - log_message("error", error_msg) - self.session.rollback() - log_message("error", "Database transaction rolled back due to error") - raise - - def process_new_feed(self, payload: FeedPayload) -> None: - """ - Process creation of a new feed - - Args: - payload (FeedPayload): The feed payload for new feed - """ - try: - log_message( - "info", - f"Starting new feed creation for external_id: {payload.external_id}", - ) - - # Check if feed with same URL exists - if self.check_feed_url_exists(payload.feed_url): - log_message("error", f"Feed URL already exists: {payload.feed_url}") - return - - # Generate new feed ID and stable ID - feed_id = str(uuid.uuid4()) - stable_id = f"{payload.source}-{payload.external_id}" - - log_message( - "debug", f"Generated new feed_id: {feed_id} and stable_id: {stable_id}" - ) - - try: - # Create new feed - new_feed = Feed( - id=feed_id, - data_type=payload.spec, - producer_url=payload.feed_url, - authentication_type=payload.type if payload.type else "0", - authentication_info_url=payload.auth_info_url, - api_key_parameter_name=payload.auth_param_name, - stable_id=stable_id, - status="active", - provider=payload.operator_name, - operational_status="wip", - ) - - # external ID mapping - external_id = Externalid( - feed_id=feed_id, - associated_id=payload.external_id, - source=payload.source, - ) - - # Add relationships - new_feed.externalids.append(external_id) - - # Create or get location - location = create_or_get_location( - self.session, - payload.country, - payload.state_province, - payload.city_name, - ) - - if location is not None: - new_feed.locations.append(location) - log_message( - "debug", f"Added location information for feed: {feed_id}" - ) - else: - log_message( - "debug", f"No 
location information to add for feed: {feed_id}" - ) - - self.session.add(new_feed) - self.session.flush() - - log_message("debug", f"Successfully created feed with ID: {feed_id}") - log_message( - "info", - f"Created new feed with ID: {feed_id} for external_id: {payload.external_id}", - ) - - except SQLAlchemyError as e: - self.session.rollback() - error_msg = f"Database error creating feed for external_id {payload.external_id}: {str(e)}" - log_message("error", error_msg) - raise - + self._rollback_transaction(f"Database error: {str(e)}") except Exception as e: - error_msg = f"Database error creating feed for external_id {payload.external_id}: {str(e)}" - log_message("error", error_msg) - raise - - def process_feed_update(self, payload: FeedPayload, old_feed_id: str) -> None: - """ - Process feed update when URL has changed - - Args: - payload (FeedPayload): The feed payload for update - old_feed_id (str): The ID of the existing feed to be updated - """ - log_message( - "info", - f"Starting feed update process for external_id: {payload.external_id}", + self._rollback_transaction(f"Error processing feed: {str(e)}") + + def _process_new_feed_or_skip(self, payload: FeedPayload) -> Optional[Feed]: + """Process a new feed or skip if the URL already exists.""" + if self._check_feed_url_exists(payload.feed_url): + logging.error(f"Feed URL already exists: {payload.feed_url}. Skipping.") + return + logging.info(f"Creating new feed for external_id: {payload.external_id}") + return create_new_feed(self.session, self.feed_stable_id, payload) + + def _process_existing_feed_refs( + self, payload: FeedPayload, current_feeds: List[Feed] + ) -> Optional[Feed]: + """Process existing feeds, updating if necessary.""" + matching_feeds = [ + f for f in current_feeds if f.producer_url == payload.feed_url + ] + if matching_feeds: + logging.info(f"Feed with URL already exists: {payload.feed_url}. Skipping.") + return + + stable_id_matches = [ + f for f in current_feeds if self.feed_stable_id in f.stable_id + ] + reference_count = len(stable_id_matches) + active_match = [f for f in stable_id_matches if f.status == "active"] + if reference_count > 0: + logging.info(f"Updating feed for stable_id: {self.feed_stable_id}") + self.feed_stable_id = f"{self.feed_stable_id}_{reference_count}".lower() + new_feed = self._deprecate_old_feed(payload, active_match[0].id) + else: + logging.info( + f"No matching stable_id. Creating new feed for {payload.external_id}." 
+ ) + new_feed = create_new_feed(self.session, self.feed_stable_id, payload) + return new_feed + + def _check_feed_url_exists(self, feed_url: str) -> bool: + """Check if a feed with the given URL exists.""" + existing_feeds = ( + self.session.query(Feed).filter_by(producer_url=feed_url).count() ) - log_message("debug", f"Old feed_id: {old_feed_id}, New URL: {payload.feed_url}") - - try: - # Get count of existing references to this external ID - reference_count = ( - self.session.query(Feed) - .join(Externalid) - .filter( - Externalid.associated_id == payload.external_id, - Externalid.source == payload.source, - ) - .count() - ) - - # Create new feed with updated URL - new_feed_id = str(uuid.uuid4()) - # Added counter to stable_id - stable_id = ( - f"{payload.source}-{payload.external_id}" - if reference_count == 1 - else f"{payload.source}-{payload.external_id}_{reference_count}" - ) - - log_message( - "debug", - f"Generated new stable_id: {stable_id} (reference count: {reference_count})", - ) + return existing_feeds > 0 - # Create new feed entry - new_feed = Feed( - id=new_feed_id, - data_type=payload.spec, - producer_url=payload.feed_url, - authentication_type=payload.type if payload.type else "0", - authentication_info_url=payload.auth_info_url, - api_key_parameter_name=payload.auth_param_name, - stable_id=stable_id, - status="active", - provider=payload.operator_name, - operational_status="wip", - ) - - # Add new feed to session - self.session.add(new_feed) - - # Update old feed status to deprecated - old_feed = self.session.get(Feed, old_feed_id) - if old_feed: - old_feed.status = "deprecated" - log_message("debug", f"Deprecating old feed ID: {old_feed_id}") - - # Create new external ID mapping for updated feed - new_external_id = Externalid( - feed_id=new_feed_id, - associated_id=payload.external_id, - source=payload.source, - ) - self.session.add(new_external_id) - log_message( - "debug", f"Created new external ID mapping for feed_id: {new_feed_id}" - ) - - # Create redirect - redirect = Redirectingid(source_id=old_feed_id, target_id=new_feed_id) - self.session.add(redirect) - log_message( - "debug", f"Created redirect from {old_feed_id} to {new_feed_id}" - ) - - # Create or get location and add to new feed - location = create_or_get_location( - self.session, payload.country, payload.state_province, payload.city_name - ) - - if location: - new_feed.locations.append(location) - log_message( - "debug", f"Added location information for feed: {new_feed_id}" - ) - - self.session.flush() - - log_message( - "info", - f"Updated feed for external_id: {payload.external_id}, new feed_id: {new_feed_id}", - ) - - except Exception as e: - log_message( - "error", - f"Error updating feed for external_id {payload.external_id}: {str(e)}", - ) - raise - - def check_feed_url_exists(self, feed_url: str) -> bool: - """ - Check if a feed with the given URL exists in any state (active or deprecated). - This check is used to prevent creating new feeds with URLs that are already in use. 
- - Args: - feed_url (str): The URL to check - - Returns: - bool: True if any feed with this URL exists (either active or deprecated), - preventing creation of new feeds with duplicate URLs - """ - results = self.session.query(Feed).filter(Feed.producer_url == feed_url).all() - - if results: - if len(results) > 1: - log_message("warning", f"Multiple feeds found with URL: {feed_url}") - return True - - result = results[0] - if result.status == "active": - log_message( - "info", f"Found existing feed with URL: {feed_url} (status: active)" - ) - return True - elif result.status == "deprecated": - log_message( - "error", - f"Feed URL {feed_url} exists in deprecated feed (id: {result.id}). " - "Cannot reuse URLs from deprecated feeds.", - ) - return True - - log_message("debug", f"No existing feed found with URL: {feed_url}") - return False - - def get_current_feed_info( - self, external_id: str, source: str - ) -> Tuple[Optional[str], Optional[str]]: - """ - Get current feed ID and URL for given external ID - - Args: - external_id (str): The external ID to look up - source (str): The source of the feed - - Returns: - Tuple[Optional[str], Optional[str]]: Tuple of (feed_id, feed_url) - """ - result = ( + def _get_current_feeds(self, external_id: str, source: str) -> List[Feed]: + """Retrieve current feeds for a given external ID and source.""" + return ( self.session.query(Feed) .filter(Feed.externalids.any(associated_id=external_id, source=source)) - .first() + .all() ) - if result is not None: - log_message( - "info", - f"Retrieved feed {result.stable_id} " - f"info for external_id: {external_id} (status: {result.status})", - ) - return result.id, result.producer_url - log_message("info", f"No existing feed found for external_id: {external_id}") - return None, None - - def publish_to_batch_topic(self, payload: FeedPayload) -> None: - """ - Publish feed to dataset batch topic - Args: - payload (FeedPayload): The feed payload to publish - """ + def _deprecate_old_feed( + self, payload: FeedPayload, old_feed_id: Optional[str] + ) -> Feed: + """Update the status of an old feed and create a new one.""" + if old_feed_id: + old_feed = self.session.get(Feed, old_feed_id) + if old_feed: + old_feed.status = "deprecated" + logging.info(f"Deprecated old feed: {old_feed.id}") + return create_new_feed(self.session, self.feed_stable_id, payload) + + def _publish_to_batch_topic_if_needed( + self, payload: FeedPayload, feed: Optional[Feed] + ) -> None: + """Publishes a feed to the dataset batch topic if it meets the necessary criteria.""" + if ( + feed is not None + and feed.authentication_type == "0" # Authentication type check + and payload.spec == "gtfs" # Only for GTFS feeds + ): + self._publish_to_topic(feed, payload) + + def _publish_to_topic(self, feed: Feed, payload: FeedPayload) -> None: + """Publishes the feed to the configured Pub/Sub topic.""" topic_path = self.publisher.topic_path(PROJECT_ID, DATASET_BATCH_TOPIC) - log_message("debug", f"Publishing to topic: {topic_path}") + logging.debug(f"Publishing to Pub/Sub topic: {topic_path}") - # Prepare message data in the expected format message_data = { "execution_id": payload.execution_id, - "producer_url": payload.feed_url, - "feed_stable_id": f"{payload.source}-{payload.external_id}", - "feed_id": payload.feed_id, + "producer_url": feed.producer_url, + "feed_stable_id": feed.stable_id, + "feed_id": feed.id, "dataset_id": None, "dataset_hash": None, - "authentication_type": payload.type if payload.type else "0", - "authentication_info_url": 
payload.auth_info_url, - "api_key_parameter_name": payload.auth_param_name, + "authentication_type": feed.authentication_type, + "authentication_info_url": feed.authentication_info_url, + "api_key_parameter_name": feed.api_key_parameter_name, } try: - log_message("debug", f"Preparing to publish feed_id: {payload.feed_id}") - # Convert to JSON string and encode as base64 - json_str = json.dumps(message_data) - encoded_data = base64.b64encode(json_str.encode("utf-8")) - - future = self.publisher.publish(topic_path, data=encoded_data) - future.result() - log_message( - "info", f"Published feed {payload.feed_id} to dataset batch topic" + # Convert to JSON string + json_message = json.dumps(message_data) + future = self.publisher.publish( + topic_path, data=json_message.encode("utf-8") + ) + future.add_done_callback( + lambda _: logging.info( + f"Published feed {feed.stable_id} to dataset batch topic" + ) ) + future.result() + logging.info(f"Message published for feed {feed.stable_id}") except Exception as e: - error_msg = f"Error publishing to dataset batch topic: {str(e)}" - log_message("error", error_msg) + logging.error(f"Error publishing to dataset batch topic: {str(e)}") raise + def _rollback_transaction(self, message: str) -> None: + """Rollback the current transaction and log an error.""" + logging.error(message) + self.session.rollback() -@functions_framework.cloud_event -def process_feed_event(cloud_event): - """ - Cloud Function to process feed events from Pub/Sub - Args: - cloud_event (CloudEvent): The cloud event - containing the Pub/Sub message - """ +@functions_framework.cloud_event +def process_feed_event(cloud_event) -> None: + """Cloud Function entry point for feed processing.""" + Logger.init_logger() + configure_polymorphic_mappers() try: - # Decode payload from Pub/Sub message - pubsub_message = base64.b64decode(cloud_event.data["message"]["data"]).decode() - message_data = json.loads(pubsub_message) - - payload = FeedPayload(**message_data) - + message_data = base64.b64decode(cloud_event.data["message"]["data"]).decode() + payload = FeedPayload(**json.loads(message_data)) db_session = start_db_session(FEEDS_DATABASE_URL) - - try: - processor = FeedProcessor(db_session) - processor.process_feed(payload) - - log_message("info", f"Successfully processed feed: {payload.external_id}") - return "Success", 200 - - finally: - close_db_session(db_session) - + processor = FeedProcessor(db_session) + processor.process_feed(payload) except Exception as e: - error_msg = f"Error processing feed event: {str(e)}" - log_message("error", error_msg) - return error_msg, 500 + logging.error(f"Error processing feed event: {str(e)}") diff --git a/functions-python/feed_sync_process_transitland/tests/test_feed_processor_utils.py b/functions-python/feed_sync_process_transitland/tests/test_feed_processor_utils.py new file mode 100644 index 000000000..2d1a87733 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/tests/test_feed_processor_utils.py @@ -0,0 +1,76 @@ +from unittest.mock import patch + +import requests + +from database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed +from feed_sync_process_transitland.src.feed_processor_utils import ( + check_url_status, + get_feed_model, + get_tlnd_authentication_type, + create_new_feed, +) +from helpers.database import start_db_session, configure_polymorphic_mappers +from helpers.feed_sync.models import TransitFeedSyncPayload +from test_utils.database_utils import default_db_url + + +@patch("requests.head") +def 
test_check_url_status(mock_head): + mock_head.return_value.status_code = 200 + assert check_url_status("http://example.com") + mock_head.return_value.status_code = 404 + assert not check_url_status("http://example.com/404") + mock_head.return_value.status_code = 403 + assert check_url_status("http://example.com/403") + mock_head.side_effect = requests.RequestException("Error") + assert not check_url_status("http://example.com/exception") + + +def test_get_feed_model(): + assert get_feed_model("gtfs") == (Gtfsfeed, "gtfs") + assert get_feed_model("gtfs_rt") == (Gtfsrealtimefeed, "gtfs_rt") + try: + get_feed_model("invalid") + assert False + except ValueError: + assert True + + +def test_get_tlnd_authentication_type() -> str: + assert get_tlnd_authentication_type(None) == "0" + assert get_tlnd_authentication_type("") == "0" + assert get_tlnd_authentication_type("query_param") == "1" + assert get_tlnd_authentication_type("header") == "2" + try: + get_tlnd_authentication_type("invalid") + assert False + except ValueError: + assert True + + +@patch.dict("os.environ", {"FEEDS_DATABASE_URL": default_db_url}) +def test_create_new_feed_gtfs_rt(): + payload = { + "spec": "gtfs_rt", + "entity_types": "tu", + "feed_url": "http://example.com", + "feed_id": "102_tu", + "stable_id": "tld-102_tu", + "type": "query_param", + "auth_info_url": "http://example.com/info", + "auth_param_name": "key", + "operator_name": "Operator 1", + "external_id": "onestop1", + "source": "tld", + "country": "USA", + "state_province": "California", + "city_name": "San Francisco", + } + feed_payload = TransitFeedSyncPayload(**payload) + configure_polymorphic_mappers() + session = start_db_session(default_db_url, echo=False) + new_feed = create_new_feed(session, "tld-102_tu", feed_payload) + session.delete(new_feed) + assert new_feed.stable_id == "tld-102_tu" + assert new_feed.data_type == "gtfs_rt" + assert len(new_feed.entitytypes) == 1 diff --git a/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py b/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py index b4848ce56..5eb04dba7 100644 --- a/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py +++ b/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py @@ -1,24 +1,20 @@ import base64 import json import logging -import uuid from unittest import mock from unittest.mock import patch, Mock, MagicMock -import os import pytest -from google.api_core.exceptions import DeadlineExceeded from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session as DBSession -from database_gen.sqlacodegen_models import Feed +from database_gen.sqlacodegen_models import Feed, Gtfsfeed from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload with mock.patch("helpers.logger.Logger.init_logger") as mock_init_logger: from feed_sync_process_transitland.src.main import ( FeedProcessor, process_feed_event, - log_message, ) # Environment variables for tests @@ -68,10 +64,10 @@ def get_logger(self): @pytest.fixture(autouse=True) def mock_logging(): """Mock both local and GCP logging.""" - with patch("feed_sync_process_transitland.src.main.logger") as mock_log, patch( - "feed_sync_process_transitland.src.main.gcp_logger" - ) as mock_gcp_log, patch("helpers.logger.Logger", MockLogger): - for logger in [mock_log, mock_gcp_log]: + with patch("feed_sync_process_transitland.src.main.logging") as mock_log, patch( + "feed_sync_process_transitland.src.main.Logger", MockLogger + ): + 
for logger in [mock_log]: logger.info = MagicMock() logger.error = MagicMock() logger.warning = MagicMock() @@ -86,8 +82,9 @@ def feed_payload(): """Fixture for feed payload.""" return FeedPayload( external_id="test123", + stable_id="feed1", feed_id="feed1", - feed_url="https://example.com/feed1", + feed_url="https://example.com", execution_id="exec123", spec="gtfs", auth_info_url=None, @@ -157,135 +154,53 @@ def _create_payload_dict(feed_payload: FeedPayload) -> dict: def test_get_current_feed_info(self, processor, feed_payload, mock_logging): """Test retrieving current feed information.""" # Mock database query - processor.session.query.return_value.filter.return_value.first.return_value = ( - Mock( + processor.session.query.return_value.filter.return_value.all.return_value = [ + Feed( id="feed-uuid", producer_url="https://example.com/feed", stable_id="TLD-test123", status="active", ) - ) + ] - feed_id, url = processor.get_current_feed_info( + feeds = processor._get_current_feeds( feed_payload.external_id, feed_payload.source ) # Assertions + assert len(feeds) == 1 + feed_id, url = feeds[0].id, feeds[0].producer_url assert feed_id == "feed-uuid" assert url == "https://example.com/feed" - mock_logging.info.assert_called_with( - "Retrieved feed TLD-test123 " - f"info for external_id: {feed_payload.external_id} (status: active)" - ) # Test case when feed does not exist - processor.session.query.return_value.filter.return_value.first.return_value = ( - None - ) - feed_id, url = processor.get_current_feed_info( + processor.session.query.return_value.filter.return_value.all.return_value = [] + feeds = processor._get_current_feeds( feed_payload.external_id, feed_payload.source ) - - assert feed_id is None - assert url is None - mock_logging.info.assert_called_with( - f"No existing feed found for external_id: {feed_payload.external_id}" - ) + assert len(feeds) == 0 def test_check_feed_url_exists_comprehensive(self, processor, mock_logging): """Test comprehensive feed URL existence checks.""" test_url = "https://example.com/feed" # Test case 1: Active feed exists - mock_feed = Mock(id="test-id", status="active") - processor.session.query.return_value.filter.return_value.all.return_value = [ - mock_feed - ] - - result = processor.check_feed_url_exists(test_url) - assert result is True - mock_logging.info.assert_called_with( - f"Found existing feed with URL: {test_url} (status: active)" + processor.session.query.return_value.filter_by.return_value.count.return_value = ( + 1 ) - # Test case 2: Deprecated feed exists - mock_logging.info.reset_mock() - mock_feed.status = "deprecated" - result = processor.check_feed_url_exists(test_url) + result = processor._check_feed_url_exists(test_url) assert result is True - mock_logging.error.assert_called_with( - f"Feed URL {test_url} exists in deprecated feed (id: {mock_feed.id}). " - "Cannot reuse URLs from deprecated feeds." 
- ) - - # Test case 3: No feed exists - mock_logging.error.reset_mock() - processor.session.query.return_value.filter.return_value.all.return_value = [] - result = processor.check_feed_url_exists(test_url) - assert result is False - mock_logging.debug.assert_called_with( - f"No existing feed found with URL: {test_url}" - ) - - # Test case 4: Multiple feeds with same URL - mock_logging.debug.reset_mock() - mock_feeds = [ - Mock(id="feed1", status="active"), - Mock(id="feed2", status="deprecated"), - ] - processor.session.query.return_value.filter.return_value.all.return_value = ( - mock_feeds - ) - result = processor.check_feed_url_exists(test_url) - assert result is True - mock_logging.warning.assert_called_with( - f"Multiple feeds found with URL: {test_url}" - ) - - def test_log_message_function(self, mock_logging): - """Test the log_message function for different log levels.""" - levels = ["info", "error", "warning", "debug"] - messages = ["Info message", "Error message", "Warning message", "Debug message"] - - for level, message in zip(levels, messages): - log_message(level, message) - - if level == "info": - mock_logging.info.assert_called_with(message) - elif level == "error": - mock_logging.error.assert_called_with(message) - elif level == "warning": - mock_logging.warning.assert_called_with(message) - elif level == "debug": - mock_logging.debug.assert_called_with(message) def test_database_error_handling(self, processor, feed_payload, mock_logging): """Test database error handling in different scenarios.""" # Test case 1: General database error during feed processing processor.session.query.side_effect = SQLAlchemyError("Database error") - with pytest.raises(SQLAlchemyError, match="Database error"): - processor.process_feed(feed_payload) - - processor.session.rollback.assert_called_once() - mock_logging.error.assert_called_with( - "Database transaction rolled back due to error" - ) + processor._rollback_transaction = MagicMock(return_value=None) + processor.process_feed(feed_payload) - # Reset mocks for next test - processor.session.rollback.reset_mock() - mock_logging.error.reset_mock() - - # Test case 2: Connection failure during feed processing - processor.session.query.side_effect = SQLAlchemyError("Connection refused") - - with pytest.raises(SQLAlchemyError, match="Connection refused"): - processor.process_feed(feed_payload) - - processor.session.rollback.assert_called_once() - mock_logging.error.assert_called_with( - "Database transaction rolled back due to error" - ) + processor._rollback_transaction.assert_called_once() def test_publish_to_batch_topic_comprehensive( self, processor, feed_payload, mock_logging @@ -297,45 +212,21 @@ def test_publish_to_batch_topic_comprehensive( mock_future = Mock() processor.publisher.publish.return_value = mock_future - processor.publish_to_batch_topic(feed_payload) - - # Verify publish was called and message format - call_args = processor.publisher.publish.call_args - assert call_args is not None - _, kwargs = call_args - - # Decode and verify message content - message_data = json.loads(base64.b64decode(kwargs["data"]).decode("utf-8")) - assert message_data["execution_id"] == feed_payload.execution_id - assert message_data["producer_url"] == feed_payload.feed_url - assert ( - message_data["feed_stable_id"] - == f"{feed_payload.source}-{feed_payload.external_id}" - ) - - mock_logging.info.assert_called_with( - f"Published feed {feed_payload.feed_id} to dataset batch topic" + processor._publish_to_batch_topic_if_needed( + feed_payload, + 
Feed( + id="test-id", + authentication_type="0", + producer_url=feed_payload.feed_url, + stable_id=f"{feed_payload.source}-{feed_payload.feed_id}".lower(), + ), ) - # Test case 2: Publish error - processor.publisher.publish.side_effect = Exception("Pub/Sub error") - - with pytest.raises(Exception, match="Pub/Sub error"): - processor.publish_to_batch_topic(feed_payload) - - mock_logging.error.assert_called_with( - "Error publishing to dataset batch topic: Pub/Sub error" - ) - - # Test case 3: Timeout error - processor.publisher.publish.side_effect = DeadlineExceeded("Timeout error") - - with pytest.raises(DeadlineExceeded, match="Timeout error"): - processor.publish_to_batch_topic(feed_payload) - - mock_logging.error.assert_called_with( - "Error publishing to dataset batch topic: 504 Timeout error" - ) + # Verify publish was called and message format + topic_arg, message_arg = processor.publisher.publish.call_args + assert topic_arg == ("test_topic",) + assert "feed_stable_id" in json.loads(message_arg["data"]) + assert "tld-feed1" == json.loads(message_arg["data"])["feed_stable_id"] def test_process_feed_event_validation(self, mock_logging): """Test feed event processing with various invalid payloads.""" @@ -345,267 +236,24 @@ def test_process_feed_event_validation(self, mock_logging): cloud_event = Mock() cloud_event.data = {"message": {"data": empty_payload_data}} - result = process_feed_event(cloud_event) - assert result[1] == 500 - mock_logging.error.assert_called_with( - "Error processing feed event: TransitFeedSyncPayload.__init__() missing 14 " - "required positional arguments: 'external_id', 'feed_id', 'feed_url', " - "'execution_id', 'spec', 'auth_info_url', 'auth_param_name', 'type', " - "'operator_name', 'country', 'state_province', 'city_name', 'source', and " - "'payload_type'" - ) + process_feed_event(cloud_event) # Test case 2: Invalid field - mock_logging.error.reset_mock() invalid_payload_data = base64.b64encode( json.dumps({"invalid": "data"}).encode("utf-8") ).decode() cloud_event.data = {"message": {"data": invalid_payload_data}} - result = process_feed_event(cloud_event) - assert result[1] == 500 - mock_logging.error.assert_called_with( - "Error processing feed event: TransitFeedSyncPayload.__init__() got an " - "unexpected keyword argument 'invalid'" - ) + process_feed_event(cloud_event) # Test case 3: Type error - mock_logging.error.reset_mock() type_error_payload = {"external_id": 12345, "feed_url": True, "feed_id": None} payload_data = base64.b64encode( json.dumps(type_error_payload).encode("utf-8") ).decode() cloud_event.data = {"message": {"data": payload_data}} - result = process_feed_event(cloud_event) - assert result[1] == 500 - mock_logging.error.assert_called_with( - "Error processing feed event: TransitFeedSyncPayload.__init__() missing 11 " - "required positional arguments: 'execution_id', 'spec', 'auth_info_url', " - "'auth_param_name', 'type', 'operator_name', 'country', 'state_province', " - "'city_name', 'source', and 'payload_type'" - ) - - def test_process_new_feed_with_location( - self, processor, feed_payload, mock_logging - ): - """Test creating a new feed with location information.""" - # Mock UUID generation - new_feed_id = str(uuid.uuid4()) - - # Mock database query to return no existing feeds - processor.session.query.return_value.filter.return_value.all.return_value = [] - - with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)): - # Mock Location class - mock_location_cls = Mock(name="Location") - mock_location = 
mock_location_cls.return_value - mock_location.id = "US-CA-Test City" - mock_location.country_code = "US" - mock_location.country = "United States" - mock_location.subdivision_name = "CA" - mock_location.municipality = "Test City" - mock_location.__eq__ = ( - lambda self, other: isinstance(other, Mock) and self.id == other.id - ) - - # Create a Feed class with a real list for locations - class MockFeed: - def __init__(self): - self.locations = [] - self.externalids = [] - self.id = new_feed_id - self.producer_url = feed_payload.feed_url - self.data_type = feed_payload.spec - self.provider = feed_payload.operator_name - self.status = "active" - self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" - - mock_feed = MockFeed() - - with patch( - "database_gen.sqlacodegen_models.Feed", return_value=mock_feed - ), patch( - "database_gen.sqlacodegen_models.Location", mock_location_cls - ), patch( - "helpers.locations.create_or_get_location", return_value=mock_location - ): - processor.process_new_feed(feed_payload) - - # Verify feed creation - created_feed = processor.session.add.call_args[0][0] - assert created_feed.id == new_feed_id - assert created_feed.producer_url == feed_payload.feed_url - assert created_feed.data_type == feed_payload.spec - assert created_feed.provider == feed_payload.operator_name - - # Verify location was added to feed - assert len(created_feed.locations) == 1 - assert created_feed.locations[0].id == "US-CA-Test City" - assert created_feed.locations[0].country_code == "US" - assert created_feed.locations[0].country == "United States" - assert created_feed.locations[0].subdivision_name == "CA" - assert created_feed.locations[0].municipality == "Test City" - mock_logging.debug.assert_any_call( - f"Added location information for feed: {new_feed_id}" - ) - - def test_process_new_feed_without_location( - self, processor, feed_payload, mock_logging - ): - """Test creating a new feed without location information.""" - # Modify payload to have no location info - feed_payload.country = None - feed_payload.state_province = None - feed_payload.city_name = None - - # Mock database query to return no existing feeds - processor.session.query.return_value.filter.return_value.all.return_value = [] - - # Mock UUID generation - new_feed_id = str(uuid.uuid4()) - - # Create a Feed class with a real list for locations - class MockFeed: - def __init__(self): - self.locations = [] - self.externalids = [] - self.id = new_feed_id - self.producer_url = feed_payload.feed_url - self.data_type = feed_payload.spec - self.provider = feed_payload.operator_name - self.status = "active" - self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" - - mock_feed = MockFeed() - - with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( - "database_gen.sqlacodegen_models.Feed", return_value=mock_feed - ), patch("helpers.locations.create_or_get_location", return_value=None): - processor.process_new_feed(feed_payload) - - # Verify feed creation - created_feed = processor.session.add.call_args[0][0] - assert created_feed.id == new_feed_id - assert not created_feed.locations - - def test_process_feed_update_with_location( - self, processor, feed_payload, mock_logging - ): - """Test updating a feed with location information.""" - old_feed_id = str(uuid.uuid4()) - new_feed_id = str(uuid.uuid4()) - - # Mock database query to return no existing feeds - processor.session.query.return_value.filter.return_value.all.return_value = [] - - # Mock old feed - mock_old_feed = 
Mock(id=old_feed_id, status="active") - processor.session.get.return_value = mock_old_feed - - # Mock Location class - mock_location_cls = Mock(name="Location") - mock_location = mock_location_cls.return_value - mock_location.id = "US-CA-Test City" - mock_location.country_code = "US" - mock_location.country = "United States" - mock_location.subdivision_name = "CA" - mock_location.municipality = "Test City" - mock_location.__eq__ = ( - lambda self, other: isinstance(other, Mock) and self.id == other.id - ) - - # Create a Feed class with a real list for locations - class MockFeed: - def __init__(self): - self.locations = [] - self.externalids = [] - self.id = new_feed_id - self.producer_url = feed_payload.feed_url - self.data_type = feed_payload.spec - self.provider = feed_payload.operator_name - self.status = "active" - self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" - - mock_new_feed = MockFeed() - - with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( - "database_gen.sqlacodegen_models.Feed", return_value=mock_new_feed - ), patch("database_gen.sqlacodegen_models.Location", mock_location_cls), patch( - "helpers.locations.create_or_get_location", return_value=mock_location - ): - processor.process_feed_update(feed_payload, old_feed_id) - - # Verify feed update - assert mock_old_feed.status == "deprecated" - - # Find the Feed object in the add calls - feed_add_call = None - for call in processor.session.add.call_args_list: - obj = call[0][0] - if hasattr(obj, "locations"): # This is our Feed object - feed_add_call = call - break - - assert ( - feed_add_call is not None - ), "Feed object not found in session.add calls" - created_feed = feed_add_call[0][0] - - # Verify new feed creation with location - assert len(created_feed.locations) == 1 - assert created_feed.locations[0].id == "US-CA-Test City" - assert created_feed.locations[0].country_code == "US" - assert created_feed.locations[0].country == "United States" - assert created_feed.locations[0].subdivision_name == "CA" - assert created_feed.locations[0].municipality == "Test City" - mock_logging.debug.assert_any_call( - f"Added location information for feed: {new_feed_id}" - ) - - def test_process_feed_update_without_location( - self, processor, feed_payload, mock_logging - ): - """Test updating a feed without location information.""" - old_feed_id = str(uuid.uuid4()) - new_feed_id = str(uuid.uuid4()) - - # Mock database query to return no existing feeds - processor.session.query.return_value.filter.return_value.all.return_value = [] - - # Modify payload to have no location info - feed_payload.country = None - feed_payload.state_province = None - feed_payload.city_name = None - - # Mock old feed - mock_old_feed = Mock(id=old_feed_id, status="active") - processor.session.get.return_value = mock_old_feed - - # Create a Feed class with a real list for locations - class MockFeed: - def __init__(self): - self.locations = [] - self.externalids = [] - self.id = new_feed_id - self.producer_url = feed_payload.feed_url - self.data_type = feed_payload.spec - self.provider = feed_payload.operator_name - self.status = "active" - self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" - - mock_new_feed = MockFeed() - - with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( - "database_gen.sqlacodegen_models.Feed", return_value=mock_new_feed - ), patch("helpers.locations.create_or_get_location", return_value=None): - processor.process_feed_update(feed_payload, old_feed_id) - - # Verify 
feed update - assert mock_old_feed.status == "deprecated" - - # Verify new feed creation without location - assert not mock_new_feed.locations + process_feed_event(cloud_event) def test_process_feed_event_database_connection_error( self, processor, feed_payload, mock_logging @@ -627,11 +275,7 @@ def test_process_feed_event_database_connection_error( "Database connection error" ) - result = process_feed_event(cloud_event) - assert result[1] == 500 - mock_logging.error.assert_called_with( - "Error processing feed event: Database connection error" - ) + process_feed_event(cloud_event) def test_process_feed_event_pubsub_error( self, processor, feed_payload, mock_logging @@ -656,11 +300,7 @@ def test_process_feed_event_pubsub_error( "feed_sync_process_transitland.src.main.start_db_session", return_value=mock_session, ): - result = process_feed_event(cloud_event) - assert result[1] == 500 - mock_logging.error.assert_called_with( - "Error processing feed event: File dummy-credentials.json was not found." - ) + process_feed_event(cloud_event) def test_process_feed_event_malformed_cloud_event(self, mock_logging): """Test feed event processing with malformed cloud event.""" @@ -668,138 +308,12 @@ def test_process_feed_event_malformed_cloud_event(self, mock_logging): cloud_event = Mock() cloud_event.data = {} - result = process_feed_event(cloud_event) - assert result[1] == 500 - mock_logging.error.assert_called_with("Error processing feed event: 'message'") + process_feed_event(cloud_event) # Test case 2: Invalid base64 data - mock_logging.error.reset_mock() cloud_event.data = {"message": {"data": "invalid-base64"}} - result = process_feed_event(cloud_event) - error_msg = ( - "Error processing feed event: Invalid base64-encoded string: " - "number of data characters (13) cannot be 1 more than a multiple of 4" - ) - mock_logging.error.assert_called_with(error_msg) - - def test_publish_to_batch_topic(self, processor, feed_payload, mock_logging): - """Test publishing feed to batch topic.""" - # Mock the topic path - topic_path = "projects/test-project/topics/test-topic" - processor.publisher.topic_path.return_value = topic_path - - # Mock the publish future - mock_future = Mock() - mock_future.result.return_value = "message_id" - processor.publisher.publish.return_value = mock_future - - # Call the method - processor.publish_to_batch_topic(feed_payload) - - # Verify topic path was created correctly - processor.publisher.topic_path.assert_called_once_with( - os.getenv("PROJECT_ID"), os.getenv("DATASET_BATCH_TOPIC") - ) - - # Expected message data - expected_data = { - "execution_id": feed_payload.execution_id, - "producer_url": feed_payload.feed_url, - "feed_stable_id": f"{feed_payload.source}-{feed_payload.external_id}", - "feed_id": feed_payload.feed_id, - "dataset_id": None, - "dataset_hash": None, - "authentication_type": "0", # default value when type is None - "authentication_info_url": feed_payload.auth_info_url, - "api_key_parameter_name": feed_payload.auth_param_name, - } - - # Verify publish was called with correct data - encoded_data = base64.b64encode(json.dumps(expected_data).encode("utf-8")) - processor.publisher.publish.assert_called_once_with( - topic_path, data=encoded_data - ) - - # Verify success was logged - mock_logging.info.assert_called_with( - f"Published feed {feed_payload.feed_id} to dataset batch topic" - ) - - def test_publish_to_batch_topic_error(self, processor, feed_payload, mock_logging): - """Test error handling when publishing to batch topic fails.""" - # Mock the topic 
path - topic_path = "projects/test-project/topics/test-topic" - processor.publisher.topic_path.return_value = topic_path - - # Mock publish to raise an error - error_msg = "Failed to publish" - processor.publisher.publish.side_effect = Exception(error_msg) - - # Call the method and verify it raises the error - with pytest.raises(Exception) as exc_info: - processor.publish_to_batch_topic(feed_payload) - - assert str(exc_info.value) == error_msg - - # Verify error was logged - mock_logging.error.assert_called_with( - f"Error publishing to dataset batch topic: {error_msg}" - ) - - def test_process_feed_update_with_multiple_references( - self, processor, feed_payload, mock_logging - ): - """Test updating feed with multiple external ID references""" - old_feed_id = "old-feed-uuid" - - # Mock multiple references to the external ID - processor.session.query.return_value.join.return_value.filter.return_value.count.return_value = ( - 3 - ) - - # Mock getting old feed - mock_old_feed = Mock(spec=Feed) - processor.session.get.return_value = mock_old_feed - - # Process the update - processor.process_feed_update(feed_payload, old_feed_id) - - # Verify stable_id includes reference count - expected_stable_id = f"{feed_payload.source}-{feed_payload.external_id}_3" - mock_logging.debug.assert_any_call( - f"Generated new stable_id: {expected_stable_id} (reference count: 3)" - ) - - # Verify old feed was deprecated - assert mock_old_feed.status == "deprecated" - - def test_process_feed_with_auth_info(self, processor, feed_payload, mock_logging): - """Test processing feed with authentication info""" - # Modify payload to include auth info - feed_payload.auth_info_url = "https://auth.example.com" - feed_payload.type = "oauth2" - feed_payload.auth_param_name = "access_token" - - # Mock the methods - with patch.object( - processor, "get_current_feed_info", return_value=(None, None) - ), patch.object( - processor, "check_feed_url_exists", return_value=False - ), patch.object( - processor, "process_new_feed" - ) as mock_process_new_feed: - # Process the feed - processor.process_feed(feed_payload) - - # Verify feed was processed - mock_process_new_feed.assert_called_once_with(feed_payload) - mock_logging.debug.assert_any_call( - "Database transaction committed successfully" - ) - - # Verify not published to batch topic (because auth_info_url is set) - processor.publisher.publish.assert_not_called() + process_feed_event(cloud_event) def test_process_feed_event_invalid_json(self, mock_logging): """Test handling of invalid JSON in cloud event""" @@ -810,30 +324,101 @@ def test_process_feed_event_invalid_json(self, mock_logging): cloud_event.data = {"message": {"data": invalid_json}} # Process the event - result, status_code = process_feed_event(cloud_event) + process_feed_event(cloud_event) # Verify error handling - assert status_code == 500 - assert "Error processing feed event" in result mock_logging.error.assert_called() - def test_process_feed_update_without_old_feed( - self, processor, feed_payload, mock_logging + @patch("feed_sync_process_transitland.src.main.create_new_feed") + def test_process_new_feed_or_skip( + self, create_new_feed_mock, processor, feed_payload, mock_logging ): - """Test feed update when old feed is not found""" - old_feed_id = "non-existent-feed" - - # Mock old feed not found - processor.session.get.return_value = None - - # Process the update - processor.process_feed_update(feed_payload, old_feed_id) + """Test processing new feed or skipping existing feed.""" + 
processor._check_feed_url_exists = MagicMock() + # Test case 1: New feed + processor._check_feed_url_exists.return_value = False + processor._process_new_feed_or_skip(feed_payload) + create_new_feed_mock.assert_called_once() + + @patch("feed_sync_process_transitland.src.main.create_new_feed") + def test_process_new_feed_skip( + self, create_new_feed_mock, processor, feed_payload, mock_logging + ): + """Test processing new feed or skipping existing feed.""" + processor._check_feed_url_exists = MagicMock() + # Test case 2: Existing feed + processor._check_feed_url_exists.return_value = True + processor._process_new_feed_or_skip(feed_payload) + create_new_feed_mock.assert_not_called() + + @patch("feed_sync_process_transitland.src.main.create_new_feed") + def test_process_existing_feed_refs( + self, create_new_feed_mock, processor, feed_payload, mock_logging + ): + """Test processing existing feed references.""" + # 1. Existing feed with same url + matching_feeds = [ + Gtfsfeed( + id="feed-uuid", + producer_url="https://example.com", + stable_id="TLD-test123", + status="active", + ) + ] + new_feed = processor._process_existing_feed_refs(feed_payload, matching_feeds) + assert new_feed is None - # Verify processing continued without error - mock_logging.debug.assert_any_call( - f"Old feed_id: {old_feed_id}, New URL: {feed_payload.feed_url}" + # 2. Existing feed with same stable_id + matching_feeds = [ + Gtfsfeed( + id="feed-uuid", + producer_url="https://example.com/different", + stable_id="tld-feed1", + status="active", + ) + ] + processor.feed_stable_id = "tld-feed1" + processor._deprecate_old_feed = MagicMock( + return_value=Feed( + id="feed-uuid", + producer_url="https://example.com/different", + stable_id="tld-feed1_2", + status="active", + ) ) + new_feed = processor._process_existing_feed_refs(feed_payload, matching_feeds) + assert new_feed is not None - # Verify no deprecation log since old feed wasn't found - deprecation_log = f"Deprecating old feed ID: {old_feed_id}" - assert mock.call(deprecation_log) not in mock_logging.debug.call_args_list + # 3. No existing feed with same stable_id + matching_feeds = [ + Gtfsfeed( + id="feed-uuid", + producer_url="https://example.com/different", + stable_id="tld-different", + status="active", + ) + ] + processor.feed_stable_id = "tld-feed1" + _ = processor._process_existing_feed_refs(feed_payload, matching_feeds) + create_new_feed_mock.assert_called_once() + + @patch("feed_sync_process_transitland.src.main.create_new_feed") + def test_update_feed(self, create_new_feed_mock, processor, feed_payload): + """Test updating an existing feed.""" + # No matching feed + processor._deprecate_old_feed(feed_payload, None) + create_new_feed_mock.assert_called_once() + # Provided id but no db entity + processor.session.get.return_value = None + processor._deprecate_old_feed(feed_payload, "feed-uuid") + create_new_feed_mock.assert_called() + # Update existing feed + returned_feed = Gtfsfeed( + id="feed-uuid", + producer_url="https://example.com", + stable_id="TLD-test123", + status="active", + ) + processor.session.get.return_value = returned_feed + processor._deprecate_old_feed(feed_payload, "feed-uuid") + assert returned_feed.status == "deprecated" diff --git a/functions-python/helpers/database.py b/functions-python/helpers/database.py index 92a31e7db..2d89a03d3 100644 --- a/functions-python/helpers/database.py +++ b/functions-python/helpers/database.py @@ -14,13 +14,13 @@ # limitations under the License. 
# +import logging import os import threading from typing import Final from sqlalchemy import create_engine, text, event from sqlalchemy.orm import sessionmaker, mapper, class_mapper -import logging from database_gen.sqlacodegen_models import Feed, Gtfsfeed, Gtfsrealtimefeed, Gbfsfeed @@ -97,7 +97,7 @@ def start_new_db_session(database_url: str = None, echo: bool = True): return sessionmaker(bind=get_db_engine(database_url, echo=echo))() -def start_singleton_db_session(database_url: str = None): +def start_singleton_db_session(database_url: str = None, echo: bool = True): """ :return: Database singleton session """ @@ -106,7 +106,7 @@ def start_singleton_db_session(database_url: str = None): if global_session is not None: logging.info("Database session reused.") return global_session - global_session = start_new_db_session(database_url) + global_session = start_new_db_session(database_url, echo) logging.info("Singleton Database session started.") return global_session except Exception as error: @@ -121,7 +121,7 @@ def start_db_session(database_url: str = None, echo: bool = True): try: lock.acquire() if is_session_reusable(): - return start_singleton_db_session(database_url) + return start_singleton_db_session(database_url, echo=echo) logging.info("Not reusing the previous session, starting new database session.") return start_new_db_session(database_url, echo) except Exception as error: diff --git a/functions-python/helpers/feed_sync/models.py b/functions-python/helpers/feed_sync/models.py index 54f769dec..ce009f001 100644 --- a/functions-python/helpers/feed_sync/models.py +++ b/functions-python/helpers/feed_sync/models.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +import json +from dataclasses import dataclass, asdict from typing import Optional @@ -8,15 +9,23 @@ class TransitFeedSyncPayload: external_id: str feed_id: str - feed_url: str - execution_id: Optional[str] - spec: str - auth_info_url: Optional[str] - auth_param_name: Optional[str] - type: Optional[str] - operator_name: Optional[str] - country: Optional[str] - state_province: Optional[str] - city_name: Optional[str] - source: str - payload_type: str + stable_id: str + entity_types: Optional[str] = None + feed_url: Optional[str] = None + execution_id: Optional[str] = None + spec: Optional[str] = None + auth_info_url: Optional[str] = None + auth_param_name: Optional[str] = None + type: Optional[str] = None + operator_name: Optional[str] = None + country: Optional[str] = None + state_province: Optional[str] = None + city_name: Optional[str] = None + source: Optional[str] = None + payload_type: Optional[str] = None + + def to_dict(self): + return asdict(self) + + def to_json(self): + return json.dumps(self.to_dict()) diff --git a/infra/.terraform.lock.hcl b/infra/.terraform.lock.hcl index 0743df4cd..d973264ea 100644 --- a/infra/.terraform.lock.hcl +++ b/infra/.terraform.lock.hcl @@ -2,58 +2,59 @@ # Manual edits may be lost in future updates. 
provider "registry.terraform.io/hashicorp/external" { - version = "2.3.2" + version = "2.3.4" hashes = [ - "h1:cy50n4q+Ir4GYppAfuYhQbBJVxMZbJUlIvM6FVK2axs=", - "zh:020bf652739ecd841d696e6c1b85ce7dd803e9177136df8fb03aa08b87365389", - "zh:0c7ea5a1cbf2e01a8627b8a84df69c93683f39fe947b288e958e72b9d12a827f", - "zh:25a68604c7d6aa736d6e99225051279eaac3a7cf4cab33b00ff7eae7096166f6", - "zh:34f46d82ca34604f6522de3b36eda19b7ad3be1e38947afc6ac31656eab58c8a", - "zh:6959f8f2f3de93e61e0abb90dbec41e28a66daec1607c46f43976bd6da50bcfd", + "h1:cCabxnWQ5fX1lS7ZqgUzsvWmKZw9FA7NRxAZ94vcTcc=", + "zh:037fd82cd86227359bc010672cd174235e2d337601d4686f526d0f53c87447cb", + "zh:0ea1db63d6173d01f2fa8eb8989f0809a55135a0d8d424b08ba5dabad73095fa", + "zh:17a4d0a306566f2e45778fbac48744b6fd9c958aaa359e79f144c6358cb93af0", + "zh:298e5408ab17fd2e90d2cd6d406c6d02344fe610de5b7dae943a58b958e76691", + "zh:38ecfd29ee0785fd93164812dcbe0664ebbe5417473f3b2658087ca5a0286ecb", + "zh:59f6a6f31acf66f4ea3667a555a70eba5d406c6e6d93c2c641b81d63261eeace", "zh:78d5eefdd9e494defcb3c68d282b8f96630502cac21d1ea161f53cfe9bb483b3", - "zh:a81e5d65a343da9caa6f1d17ae0aced9faecb36b4f8554bd445dbd4f8be21ab6", - "zh:b1d3f1557214d652c9120862ce27e9a7b61cb5aec5537a28240a5a37bf0b1413", - "zh:b71588d006471ae2d4a7eca2c51d69fd7c5dec9b088315599b794e2ad0cc5e90", - "zh:cfdaae4028b644dff3530c77b49d31f7e6f4c4e2a9e5c8ac6a88e383c80c9e9c", - "zh:dbde15154c2eb38a5f54d0e7646bc67510004179696f3cc2bc1d877cecacf83b", - "zh:fb681b363f83fb5f64dfa6afbf32d100d0facd2a766cf3493b8ddb0398e1b0f7", + "zh:ad0279dfd09d713db0c18469f585e58d04748ca72d9ada83883492e0dd13bd58", + "zh:c69f66fd21f5e2c8ecf7ca68d9091c40f19ad913aef21e3ce23836e91b8cbb5f", + "zh:d4a56f8c48aa86fc8e0c233d56850f5783f322d6336f3bf1916e293246b6b5d4", + "zh:f2b394ebd4af33f343835517e80fc876f79361f4688220833bc3c77655dd2202", + "zh:f31982f29f12834e5d21e010856eddd19d59cd8f449adf470655bfd19354377e", ] } provider "registry.terraform.io/hashicorp/google" { - version = "5.14.0" + version = "5.34.0" + constraints = "5.34.0" hashes = [ - "h1:T6EW5HOI1IrE4zHzQ/5kLyul+U2ByEaIgqMu4Ja7JFI=", - "zh:3927ef7417d9d8a56077e6655d76c99f4175f9746e39226a00ee0555f8c63f8f", - "zh:4b4f521f0779a1797047a8c531afda093aade934b4a49c080fe8d38680b3a52f", - "zh:7e880c5b72684fc8342e03180a1fbbec65c6afeb70511b9c16181d5e168269e6", - "zh:81a7f2efc30e698f476d3e240ee2d82f14eda374852059429fe808ad77b6addd", - "zh:826d4ea55b4afceefb332646f21c6b6dc590b39b16e8d9b5d4a4211beb91dc5e", - "zh:865600ef669fcdd4ae77515c3fd12565fab0f2a263fa2a6dae562f6fe68ed093", - "zh:8e933d1d10fd316e62340175667264f093e4d24457b63d5adf3c424cce22b495", - "zh:bf261924f7350074a355e5b9337f3a8054efb20d316e9085f2b5766dfb5126c4", - "zh:e28e67dcbd4bbd82798561baf86d3dd04f97e08bbf523dfb9f355564ef27d3d6", - "zh:f33cdd3117af8a15f33d375dbe398a5e558730cf6a7a145a479ab68e77572c12", + "h1:t48NNfGkdHByEWWiKx6GtlZPlzEB1Dha3cq44Uidev0=", + "zh:143c88bb74631041c291ebf7b6213467bf376d3775a33815785939dc102fac09", + "zh:1616ac79345f472b33fcc388eaf4a8a3002e4cc3a5e8748d60d6f4786d0d16dc", + "zh:554ce78e73349ac2c893a74b6981f5e55169ca16f4d0f239e6ccdecadbe1c9e1", + "zh:8022f97aa907685b2eb6c39d5411cf2be2448c6f3b7fbeaf9c06618d376ac4bc", + "zh:85f1fe3628954c35379cc05b895091ec8fe8ba0a5610bc9660492d5be65d4902", + "zh:873fb64fca79695aa930cd868b41ac498809eb76bc3292e41460d916c6fa3538", + "zh:8d3c5112a4abf14b42d769f78373e66f2c2f5f03a7e6544d80019a695bd9b198", + "zh:93cbcfa38991965b976d1973bc528d666006b5247c3fda00c714d0f3a2b62d3e", + "zh:b7710246637aee522a4ea4c1b4f0effb83b701546124ae90c8b4afb97ce03aba", + 
"zh:e4e02fe946ccbe192b6bbc6bed5715cf68084f1faadc134ed99d5e732429d2ca", "zh:f569b65999264a9416862bca5cd2a6177d94ccb0424f3a4ef424428912b9cb3c", - "zh:f913a0e0708391ccd26fc3458158cc1e10d68dc621bef3a1583328c61a77225d", + "zh:fb6b1e4fb2d019d2740aa21b5ecd5f0609f562633a78604a96c14c94aff762b4", ] } provider "registry.terraform.io/hashicorp/google-beta" { - version = "5.14.0" + version = "6.12.0" hashes = [ - "h1:DnzCevNKIci+oXY2/UgGV5Op5T1nMeRKuMuMNjRpKFk=", - "zh:04fc82cc77d944d6e636bd43dcd77c59030920e2b51d404290439bf26e8418cc", - "zh:22a2c9c8b95313c302a4e82bc785d7db2775e199ff0a006c21352d1f4bdb43ea", - "zh:3b65e6135ecdf0e93f431d1ced843caaf569b9c74fcc36ffbc3d100334f0745d", - "zh:56ad10bbc239758ac500afa24cd5948241c4441d6eb6b8450ae3e5e834f73b08", - "zh:612e22f969cd9f6a01a109b81ecbccfec6f9dec812eb5e78810b695a0fce4df0", - "zh:688282d0dade872542dbc1e5eba4e9f3e279a27b6c5a1c2970a86fc526f8d4f0", - "zh:78158d2af1b3e94531bccc91858e002ddd7f6d12f391b7e872c9eeffc9457611", - "zh:7c80aa9c1a4e5009dc8e39aa43abca3dd589005137424377f38848e3b088daff", - "zh:989eedca70db644d94bfbd713e8d746d001fac5002e6ab8a0b10358ce2881276", - "zh:9fd531998aae684bd49b650a40d2d72ef5235002670af90d716bf909564c29ab", - "zh:c51551923f4227c5efc0795528ed21a1b9184e18c05007168b02e0ffb8102187", + "h1:VapaEnmxoM9kGlZ/BH7hZmUy4fvutIOkPcuRRwF/TOo=", + "zh:01ccef122918871d26a00dd7418fdbd62aa5433b31d2baf58ca6b8b512d7567d", + "zh:1d5b72c26dd5143a7d55674912ee4ffab0aaf44f7a998b4878ea0c37256740eb", + "zh:45588f2ad7e5c24ed444ce17041c6d3d02fea116bf0cb1fa416d2d6df78923c0", + "zh:6552cf328df297f9dec9b251a02a9be50f59d4ecc99cc8da48ec580d37b24067", + "zh:982d7adc9be96d47a4425bd1d32ca67a38b72d2ca535f66b3de5a99c7bf5213b", + "zh:99028261de774304d536e25f9d65dee1ce13f3e5111ded8afb691294a6bfbdf4", + "zh:a1c5e3efe2b3403883c3ba98d8b1d2a9599b327cfae4f67ed41c35f9c9971473", + "zh:a8b30370f4cc22af70a9054f773f15b96ee40fcc9f292e1443b872ce8ea369ab", + "zh:ac7101061e9a54c28b6ff634de6fab38c4f71f23c9dbc88828c1a62cbe0871c5", + "zh:f3adfa744e9da50bfe5cf334c8ffcb6ccc57e8fcbd4ddd0d3a548f44a229c849", "zh:f569b65999264a9416862bca5cd2a6177d94ccb0424f3a4ef424428912b9cb3c", + "zh:fefc9adf1719d4e3c0d4018e5c3483668874e4f1d57b8d0f052d3469b43d99e1", ] } diff --git a/infra/functions-python/main.tf b/infra/functions-python/main.tf index c0eb7645c..583a0b53b 100644 --- a/infra/functions-python/main.tf +++ b/infra/functions-python/main.tf @@ -1,3 +1,11 @@ +terraform { + required_providers { + google = { + source = "hashicorp/google" + version = "5.34.0" + } + } +} # # MobilityData 2023 # @@ -37,6 +45,9 @@ locals { function_feed_sync_dispatcher_transitland_config = jsondecode(file("${path.module}/../../functions-python/feed_sync_dispatcher_transitland/function_config.json")) function_feed_sync_dispatcher_transitland_zip = "${path.module}/../../functions-python/feed_sync_dispatcher_transitland/.dist/feed_sync_dispatcher_transitland.zip" + function_feed_sync_process_transitland_config = jsondecode(file("${path.module}/../../functions-python/feed_sync_process_transitland/function_config.json")) + function_feed_sync_process_transitland_zip = "${path.module}/../../functions-python/feed_sync_process_transitland/.dist/feed_sync_process_transitland.zip" + function_operations_api_config = jsondecode(file("${path.module}/../../functions-python/operations_api/function_config.json")) function_operations_api_zip = "${path.module}/../../functions-python/operations_api/.dist/operations_api.zip" } @@ -61,6 +72,9 @@ data "google_vpc_access_connector" "vpc_connector" { project = local.vpc_connector_project } +data 
"google_pubsub_topic" "datasets_batch_topic" { + name = "datasets-batch-topic-${var.environment}" +} # Service account to execute the cloud functions resource "google_service_account" "functions_service_account" { @@ -119,7 +133,14 @@ resource "google_storage_bucket_object" "feed_sync_dispatcher_transitland_zip" { source = local.function_feed_sync_dispatcher_transitland_zip } -# 7. Operations API +# 7. Feed sync process transitland +resource "google_storage_bucket_object" "feed_sync_process_transitland_zip" { + bucket = google_storage_bucket.functions_bucket.name + name = "feed-sync-process-transitland-${substr(filebase64sha256(local.function_feed_sync_process_transitland_zip), 0, 10)}.zip" + source = local.function_feed_sync_process_transitland_zip +} + +# 8. Operations API resource "google_storage_bucket_object" "operations_api_zip" { bucket = google_storage_bucket.functions_bucket.name name = "operations-api-${substr(filebase64sha256(local.function_operations_api_zip), 0, 10)}.zip" @@ -193,7 +214,7 @@ resource "google_cloudfunctions2_function" "extract_location" { } event_filters { attribute = "resourceName" - value = "projects/_/buckets/mobilitydata-datasets-${var.environment}/objects/mdb-*/mdb-*/mdb-*.zip" + value = "projects/_/buckets/mobilitydata-datasets-${var.environment}/objects/*/*/*.zip" operator = "match-path-pattern" } } @@ -490,6 +511,27 @@ resource "google_cloud_scheduler_job" "gbfs_validator_batch_scheduler" { attempt_deadline = "320s" } +resource "google_cloud_scheduler_job" "transit_land_scraping_scheduler" { + name = "transitland-scraping-scheduler-${var.environment}" + description = "Schedule the transitland scraping function" + time_zone = "Etc/UTC" + schedule = var.transitland_scraping_schedule + region = var.gcp_region + paused = var.environment == "prod" ? false : true + depends_on = [google_cloudfunctions2_function.feed_sync_dispatcher_transitland, google_cloudfunctions2_function_iam_member.transitland_feeds_dispatcher_invoker] + http_target { + http_method = "POST" + uri = google_cloudfunctions2_function.feed_sync_dispatcher_transitland.url + oidc_token { + service_account_email = google_service_account.functions_service_account.email + } + headers = { + "Content-Type" = "application/json" + } + } + attempt_deadline = "320s" +} + # 5.3 Create function that subscribes to the Pub/Sub topic resource "google_cloudfunctions2_function" "gbfs_validator_pubsub" { name = "${local.function_gbfs_validation_report_config.name}-pubsub" @@ -592,6 +634,7 @@ resource "google_cloudfunctions2_function" "feed_sync_dispatcher_transitland" { } } +# 7. functions/operations_api cloud function resource "google_cloudfunctions2_function" "operations_api" { name = "${local.function_operations_api_config.name}" description = local.function_operations_api_config.description @@ -635,6 +678,58 @@ resource "google_cloudfunctions2_function" "operations_api" { } } } +# 8. 
functions/feed_sync_process_transitland cloud function +resource "google_cloudfunctions2_function" "feed_sync_process_transitland" { + name = "${local.function_feed_sync_process_transitland_config.name}-pubsub" + description = local.function_feed_sync_process_transitland_config.description + location = var.gcp_region + depends_on = [google_project_iam_member.event-receiving, google_secret_manager_secret_iam_member.secret_iam_member] + event_trigger { + trigger_region = var.gcp_region + service_account_email = google_service_account.functions_service_account.email + event_type = "google.cloud.pubsub.topic.v1.messagePublished" + pubsub_topic = google_pubsub_topic.transitland_feeds_dispatch.id + retry_policy = "RETRY_POLICY_RETRY" + } + build_config { + runtime = var.python_runtime + entry_point = local.function_feed_sync_process_transitland_config.entry_point + source { + storage_source { + bucket = google_storage_bucket.functions_bucket.name + object = google_storage_bucket_object.feed_sync_process_transitland_zip.name + } + } + } + service_config { + available_memory = local.function_feed_sync_process_transitland_config.memory + timeout_seconds = local.function_feed_sync_process_transitland_config.timeout + available_cpu = local.function_feed_sync_process_transitland_config.available_cpu + max_instance_request_concurrency = local.function_feed_sync_process_transitland_config.max_instance_request_concurrency + max_instance_count = local.function_feed_sync_process_transitland_config.max_instance_count + min_instance_count = local.function_feed_sync_process_transitland_config.min_instance_count + service_account_email = google_service_account.functions_service_account.email + ingress_settings = var.environment == "dev" ? "ALLOW_ALL" : local.function_feed_sync_process_transitland_config.ingress_settings + vpc_connector = data.google_vpc_access_connector.vpc_connector.id + vpc_connector_egress_settings = "PRIVATE_RANGES_ONLY" + environment_variables = { + PYTHONNODEBUGRANGES = 0 + DB_REUSE_SESSION = "True" + PROJECT_ID = var.project_id + PUBSUB_TOPIC_NAME = google_pubsub_topic.transitland_feeds_dispatch.name + DATASET_BATCH_TOPIC_NAME = data.google_pubsub_topic.datasets_batch_topic.name + } + dynamic "secret_environment_variables" { + for_each = local.function_feed_sync_process_transitland_config.secret_environment_variables + content { + key = secret_environment_variables.value["key"] + project_id = var.project_id + secret = "${upper(var.environment)}_${secret_environment_variables.value["key"]}" + version = "latest" + } + } + } +} # IAM entry for all users to invoke the function resource "google_cloudfunctions2_function_iam_member" "tokens_invoker" { @@ -759,12 +854,21 @@ resource "google_cloudfunctions2_function_iam_member" "gbfs_validator_batch_invo member = "serviceAccount:${google_service_account.functions_service_account.email}" } +resource "google_cloudfunctions2_function_iam_member" "transitland_feeds_dispatcher_invoker" { + project = var.project_id + location = var.gcp_region + cloud_function = google_cloudfunctions2_function.feed_sync_dispatcher_transitland.name + role = "roles/cloudfunctions.invoker" + member = "serviceAccount:${google_service_account.functions_service_account.email}" +} + # Grant permissions to the service account to publish to the pubsub topic resource "google_pubsub_topic_iam_member" "functions_publisher" { for_each = { dataset_updates = google_pubsub_topic.dataset_updates.name validate_gbfs_feed = google_pubsub_topic.validate_gbfs_feed.name 
feed_sync_dispatcher_transitland = google_pubsub_topic.transitland_feeds_dispatch.name
+ dataset_batch = data.google_pubsub_topic.datasets_batch_topic.name
 }
 project = var.project_id
diff --git a/infra/functions-python/vars.tf b/infra/functions-python/vars.tf
index 8c68c2a3d..59dc8b0cf 100644
--- a/infra/functions-python/vars.tf
+++ b/infra/functions-python/vars.tf
@@ -65,6 +65,12 @@ variable "gbfs_scheduler_schedule" {
 default = "0 0 1 * *" # every month on the first day at 00:00
 }

+variable "transitland_scraping_schedule" {
+ type = string
+ description = "Schedule for the Transitland scraping job"
+ default = "0 0 3 * *" # every month on the 3rd day at 00:00
+}
+
 variable "transitland_api_key" {
 type = string
 description = "Transitland API key"
diff --git a/infra/workflows/main.tf b/infra/workflows/main.tf
index f3d93bee5..9c293b6b8 100644
--- a/infra/workflows/main.tf
+++ b/infra/workflows/main.tf
@@ -95,7 +95,7 @@ resource "google_eventarc_trigger" "gtfs_validator_trigger" {
 }
 matching_criteria {
 attribute = "resourceName"
- value = "projects/_/buckets/${var.datasets_bucket_name}-${var.environment}/objects/mdb-*/mdb-*/mdb-*.zip"
+ value = "projects/_/buckets/${var.datasets_bucket_name}-${var.environment}/objects/*/*/*.zip"
 operator = "match-path-pattern"
 }

diff --git a/integration-tests/src/endpoints/integration_tests.py b/integration-tests/src/endpoints/integration_tests.py
index 40c6519d2..6e0cf9779 100644
--- a/integration-tests/src/endpoints/integration_tests.py
+++ b/integration-tests/src/endpoints/integration_tests.py
@@ -143,7 +143,7 @@ def _sample_municipalities(df, n):
 num_samples = min(len(unique_country_codes), n)
 return pandas.Series(unique_country_codes).sample(n=num_samples, random_state=1)

- def get_response(self, url_suffix, params=None, timeout=10):
+ def get_response(self, url_suffix, params=None, timeout=15):
 """Helper function to get response from the API."""
 url = self.base_url + "/" + url_suffix
 headers = {