diff --git a/functions-python/batch_process_dataset/requirements.txt b/functions-python/batch_process_dataset/requirements.txt
index 5309c4e65..0c90abf30 100644
--- a/functions-python/batch_process_dataset/requirements.txt
+++ b/functions-python/batch_process_dataset/requirements.txt
@@ -21,4 +21,4 @@ google-api-core
 google-cloud-firestore
 google-cloud-datastore
 google-cloud-bigquery
-cloudevents~=1.10.1
\ No newline at end of file
+cloudevents~=1.10.1
diff --git a/functions-python/feed_sync_process_transitland/.coveragerc b/functions-python/feed_sync_process_transitland/.coveragerc
new file mode 100644
index 000000000..c52988ffd
--- /dev/null
+++ b/functions-python/feed_sync_process_transitland/.coveragerc
@@ -0,0 +1,9 @@
+[run]
+omit =
+    */test*/*
+    */dataset_service/*
+    */helpers/*
+
+[report]
+exclude_lines =
+    if __name__ == .__main__.:
\ No newline at end of file
diff --git a/functions-python/feed_sync_process_transitland/.env.rename_me b/functions-python/feed_sync_process_transitland/.env.rename_me
new file mode 100644
index 000000000..601002cd5
--- /dev/null
+++ b/functions-python/feed_sync_process_transitland/.env.rename_me
@@ -0,0 +1,5 @@
+# Environment variables for the feed-sync-process function to run locally. Delete this line after renaming the file.
+FEEDS_DATABASE_URL=postgresql://postgres:postgres@localhost:54320/MobilityDatabase
+PROJECT_ID=mobility-feeds-dev
+PUBSUB_TOPIC_NAME=my-topic
+DATASET_BATCH_TOPIC_NAME=dataset_batch_topic_{env}_
diff --git a/functions-python/feed_sync_process_transitland/README.md b/functions-python/feed_sync_process_transitland/README.md
new file mode 100644
index 000000000..8420508f3
--- /dev/null
+++ b/functions-python/feed_sync_process_transitland/README.md
@@ -0,0 +1,107 @@
+# TLD Feed Sync Process
+
+Subscribed to the topic set in the `feed-sync-dispatcher` function, `feed-sync-process` is triggered for each message published. It handles the processing of feed updates, ensuring data consistency and integrity. The function performs the following operations:
+
+1. **Feed Status Check**: Verifies the current state of the feed in the database using `external_id` and `source`.
+2. **URL Validation**: Checks if the feed URL already exists in the database.
+3. **Feed Processing**: Based on the current state:
+   - If no existing feed is found, creates a new feed entry
+   - If the feed exists with a different URL, creates a new feed and deprecates the old one
+   - If the feed exists with the same URL, no action is taken
+4. **Batch Processing Trigger**: For non-authenticated feeds, publishes events to the dataset batch topic for further processing.
+
+The function maintains feed history through the `redirectingid` table and ensures proper status tracking with 'active' and 'deprecated' states.
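+
+For reference, the sketch below shows how a payload travels through the Pub/Sub envelope: the publisher serializes it to JSON and base64-encodes it under `message.data`, and `process_feed_event` reverses those two steps before acting on the feed. This is an illustrative, standard-library-only snippet; the payload here is a placeholder subset of the fields listed in the Message Format section below.
+
+```python
+import base64
+import json
+
+# Placeholder payload (subset of the full message format shown below).
+payload = {"external_id": "feed-identifier", "feed_url": "http://example.com/feed"}
+
+# Publisher side: JSON-serialize, then base64-encode under message.data.
+envelope = {
+    "message": {
+        "data": base64.b64encode(json.dumps(payload).encode("utf-8")).decode("utf-8")
+    }
+}
+
+# Function side: base64-decode message.data, then parse the JSON payload.
+decoded = json.loads(base64.b64decode(envelope["message"]["data"]))
+assert decoded == payload
+```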
+
+# Message Format
+The function expects a Pub/Sub message with the following format:
+```json
+{
+    "message": {
+        "data": {
+            "external_id": "feed-identifier",
+            "feed_id": "unique-feed-id",
+            "feed_url": "http://example.com/feed",
+            "execution_id": "execution-identifier",
+            "spec": "gtfs",
+            "auth_info_url": null,
+            "auth_param_name": null,
+            "type": null,
+            "operator_name": "Transit Agency Name",
+            "country": "Country Name",
+            "state_province": "State/Province",
+            "city_name": "City Name",
+            "source": "TLD",
+            "payload_type": "new|update"
+        }
+    }
+}
+```
+
+# Function Configuration
+The function is configured using the following environment variables:
+- `PROJECT_ID`: The Google Cloud project ID
+- `DATASET_BATCH_TOPIC_NAME`: The name of the topic for batch processing triggers
+- `FEEDS_DATABASE_URL`: The URL of the feeds database
+- `ENV`: [Optional] Environment identifier (e.g., 'dev', 'prod')
+
+# Database Schema
+The function interacts with the following tables:
+1. `feed`: Stores feed information
+   - Contains fields like id, data_type, feed_name, producer_url, etc.
+   - Tracks feed status ('active' or 'deprecated')
+   - Uses CURRENT_TIMESTAMP for created_at
+
+2. `externalid`: Maps external identifiers to feed IDs
+   - Links external_id and source to feed entries
+   - Maintains source tracking
+
+3. `redirectingid`: Tracks feed updates
+   - Maps old feed IDs to new ones
+   - Maintains update history
+
+# Local development
+The local development of this function follows the same steps as the other functions.
+
+To install the Google Pub/Sub emulator, refer to the [README.md](../README.md) file for more information.
+
+## Python requirements
+
+- Install the requirements
+```bash
+  pip install -r ./functions-python/feed_sync_process_transitland/requirements.txt
+```
+
+## Test locally with Google Cloud Emulators
+
+- Execute the following commands to start the emulators:
+```bash
+  gcloud beta emulators pubsub start --project=test-project --host-port='localhost:8043'
+```
+
+- Create a Pub/Sub topic in the emulator:
+```bash
+  curl -X PUT "http://localhost:8043/v1/projects/test-project/topics/feed-sync-transitland"
+```
+
+- Start the function
+```bash
+  export PUBSUB_EMULATOR_HOST=localhost:8043 && ./scripts/function-python-run.sh --function_name feed_sync_process_transitland
+```
+
+- [Optional]: Create a local subscription to print published messages:
+```bash
+./scripts/pubsub_message_print.sh feed-sync-process-transitland
+```
+
+- Execute the function
+```bash
+  curl http://localhost:8080
+```
+
+- To run/debug from your IDE, use the file `main_local_debug.py`
+
+# Test
+- Run the tests
+```bash
+  ./scripts/api-tests.sh --folder functions-python/feed_sync_process_transitland
+```
diff --git a/functions-python/feed_sync_process_transitland/function_config.json b/functions-python/feed_sync_process_transitland/function_config.json
new file mode 100644
index 000000000..088c8bd32
--- /dev/null
+++ b/functions-python/feed_sync_process_transitland/function_config.json
@@ -0,0 +1,19 @@
+{
+  "name": "feed-sync-process-transitland",
+  "description": "Feed Sync process for Transitland feeds",
+  "entry_point": "process_feed_event",
+  "timeout": 540,
+  "memory": "512Mi",
+  "trigger_http": true,
+  "include_folders": ["database_gen", "helpers"],
+  "secret_environment_variables": [
+    {
+      "key": "FEEDS_DATABASE_URL"
+    }
+  ],
+  "ingress_settings": "ALLOW_INTERNAL_AND_GCLB",
+  "max_instance_request_concurrency": 20,
+  "max_instance_count": 10,
+  "min_instance_count": 0,
+  "available_cpu": 1
+}
diff --git a/functions-python/feed_sync_process_transitland/main_local_debug.py b/functions-python/feed_sync_process_transitland/main_local_debug.py
new file mode 100644
index 000000000..60a3b1723
--- /dev/null
+++ b/functions-python/feed_sync_process_transitland/main_local_debug.py
@@ -0,0 +1,173 @@
+"""
+Code to debug locally without affecting the runtime cloud function.
+
+Requirements:
+- Google Cloud SDK installed
+- Make sure to have the following environment variables set in your .env.rename_me file:
+  - PROJECT_ID
+  - DATASET_BATCH_TOPIC_NAME
+  - FEEDS_DATABASE_URL
+- Local database up and running
+
+Usage:
+- python feed_sync_process_transitland/main_local_debug.py
+"""
+
+import base64
+import json
+import os
+from unittest.mock import MagicMock, patch
+import logging
+import sys
+
+import pytest
+from dotenv import load_dotenv
+
+# Configure local logging first
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+    stream=sys.stdout,
+)
+
+logger = logging.getLogger("feed_processor")
+
+
+# Mock the Google Cloud Logger
+class MockLogger:
+    """Mock logger class"""
+
+    @staticmethod
+    def init_logger():
+        return MagicMock()
+
+    def __init__(self, name):
+        self.name = name
+
+    def get_logger(self):
+        return logger
+
+    def addFilter(self, filter):
+        pass
+
+
+with patch("helpers.logger.Logger", MockLogger):
+    from feed_sync_process_transitland.src.main import process_feed_event
+
+# Load environment variables
+load_dotenv(dotenv_path=".env.rename_me")
+
+
+class CloudEvent:
+    """Cloud Event data structure."""
+
+    def __init__(self, attributes: dict, data: dict):
+        self.attributes = attributes
+        self.data = data
+
+
+@pytest.fixture
+def mock_pubsub():
+    """Fixture to mock PubSub client"""
+    with patch("google.cloud.pubsub_v1.PublisherClient") as mock_publisher:
+        publisher_instance = MagicMock()
+
+        def mock_topic_path(project_id, topic_id):
+            return f"projects/{project_id}/topics/{topic_id}"
+
+        def mock_publish(topic_path, data):
+            logger.info(
+                f"[LOCAL DEBUG] Would publish to {topic_path}: {data.decode('utf-8')}"
+            )
+            future = MagicMock()
+            future.result.return_value = "message_id"
+            return future
+
+        publisher_instance.topic_path.side_effect = mock_topic_path
+        publisher_instance.publish.side_effect = mock_publish
+        mock_publisher.return_value = publisher_instance
+
+        yield mock_publisher
+
+
+def process_event_safely(cloud_event, description=""):
+    """Process event with error handling."""
+    try:
+        logger.info(f"\nProcessing {description}:")
+        logger.info("-" * 50)
+        result = process_feed_event(cloud_event)
+        logger.info(f"Process result: {result}")
+        return True
+    except Exception as e:
+        logger.error(f"Error processing {description}: {str(e)}")
+        return False
+
+
+def main():
+    """Main function to run local debug tests"""
+    logger.info("Starting local debug session...")
+
+    # Define test event data
+    test_payload = {
+        "external_id": "test-feed-1",
+        "feed_id": "feed1",
+        "feed_url": "https://example.com/test-feed-2",
+        "execution_id": "local-debug-123",
+        "spec": "gtfs",
+        "auth_info_url": None,
+        "auth_param_name": None,
+        "type": None,
+        "operator_name": "Test Operator",
+        "country": "USA",
+        "state_province": "CA",
+        "city_name": "Test City",
+        "source": "TLD",
+        "payload_type": "new",
+    }
+
+    # Create cloud event
+    cloud_event = CloudEvent(
+        attributes={
+            "type": "com.google.cloud.pubsub.topic.publish",
+            "source": f"//pubsub.googleapis.com/projects/{os.getenv('PROJECT_ID')}/topics/test-topic",
+        },
+        data={
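+            # Mirror the real Pub/Sub envelope: the JSON payload is
+            # base64-encoded under message.data.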
"message": { + "data": base64.b64encode( + json.dumps(test_payload).encode("utf-8") + ).decode("utf-8") + } + }, + ) + + # Set up mocks + with patch( + "google.cloud.pubsub_v1.PublisherClient", new_callable=MagicMock + ) as mock_publisher, patch("google.cloud.logging.Client", MagicMock()): + publisher_instance = MagicMock() + + def mock_topic_path(project_id, topic_id): + return f"projects/{project_id}/topics/{topic_id}" + + def mock_publish(topic_path, data): + logger.info( + f"[LOCAL DEBUG] Would publish to {topic_path}: {data.decode('utf-8')}" + ) + future = MagicMock() + future.result.return_value = "message_id" + return future + + publisher_instance.topic_path.side_effect = mock_topic_path + publisher_instance.publish.side_effect = mock_publish + mock_publisher.return_value = publisher_instance + + # Process test event + process_event_safely(cloud_event, "test feed event") + + logger.info("Local debug session completed.") + + +if __name__ == "__main__": + main() diff --git a/functions-python/feed_sync_process_transitland/requirements.txt b/functions-python/feed_sync_process_transitland/requirements.txt new file mode 100644 index 000000000..b91a52224 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/requirements.txt @@ -0,0 +1,23 @@ +# Common packages +functions-framework==3.* +google-cloud-logging +psycopg2-binary==2.9.6 +aiohttp~=3.10.5 +asyncio~=3.4.3 +urllib3~=2.2.2 +requests~=2.32.3 +attrs~=23.1.0 +pluggy~=1.3.0 +certifi~=2024.8.30 + +# SQL Alchemy and Geo Alchemy +SQLAlchemy==2.0.23 +geoalchemy2==0.14.7 + +# Google specific packages for this function +google-cloud-pubsub +cloudevents~=1.10.1 + +# Additional packages for this function +pandas +pycountry diff --git a/functions-python/feed_sync_process_transitland/requirements_dev.txt b/functions-python/feed_sync_process_transitland/requirements_dev.txt new file mode 100644 index 000000000..9ee50adce --- /dev/null +++ b/functions-python/feed_sync_process_transitland/requirements_dev.txt @@ -0,0 +1,2 @@ +Faker +pytest~=7.4.3 \ No newline at end of file diff --git a/functions-python/feed_sync_process_transitland/src/__init__.py b/functions-python/feed_sync_process_transitland/src/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/functions-python/feed_sync_process_transitland/src/main.py b/functions-python/feed_sync_process_transitland/src/main.py new file mode 100644 index 000000000..1a6a3b6c0 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/src/main.py @@ -0,0 +1,476 @@ +# +# MobilityData 2024 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import base64 +import json +import logging +import os +import uuid +from typing import Optional, Tuple + +import functions_framework +from google.cloud import pubsub_v1 +from sqlalchemy.orm import Session +from database_gen.sqlacodegen_models import Feed, Externalid, Redirectingid +from sqlalchemy.exc import SQLAlchemyError + +from helpers.database import start_db_session, close_db_session +from helpers.logger import Logger, StableIdFilter +from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload +from helpers.locations import create_or_get_location + +# Configure logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +logger = logging.getLogger("feed_processor") +handler = logging.StreamHandler() +handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) +logger.addHandler(handler) +logger.setLevel(logging.INFO) + +# Initialize GCP logger for cloud environment +Logger.init_logger() +gcp_logger = Logger("feed_processor").get_logger() + + +def log_message(level, message): + """Log messages to both local and GCP loggers""" + if level == "info": + logger.info(message) + gcp_logger.info(message) + elif level == "error": + logger.error(message) + gcp_logger.error(message) + elif level == "warning": + logger.warning(message) + gcp_logger.warning(message) + elif level == "debug": + logger.debug(message) + gcp_logger.debug(message) + + +# Environment variables +PROJECT_ID = os.getenv("PROJECT_ID") +DATASET_BATCH_TOPIC = os.getenv("DATASET_BATCH_TOPIC_NAME") +FEEDS_DATABASE_URL = os.getenv("FEEDS_DATABASE_URL") + + +class FeedProcessor: + """Handles feed processing operations""" + + def __init__(self, db_session: Session): + self.session = db_session + self.publisher = pubsub_v1.PublisherClient() + + def process_feed(self, payload: FeedPayload) -> None: + """ + Processes feed idempotently based on database state + + Args: + payload (FeedPayload): The feed payload to process + """ + gcp_logger.addFilter(StableIdFilter(payload.external_id)) + try: + log_message( + "info", + f"Starting feed processing for external_id: {payload.external_id}", + ) + + # Check current state of feed in database + current_feed_id, current_url = self.get_current_feed_info( + payload.external_id, payload.source + ) + + if current_feed_id is None: + log_message("info", "Processing new feed") + # If no existing feed_id found - check if URL exists in any feed + if self.check_feed_url_exists(payload.feed_url): + log_message("error", f"Feed URL already exists: {payload.feed_url}") + return + self.process_new_feed(payload) + else: + # If Feed exists - check if URL has changed + if current_url != payload.feed_url: + log_message("info", "Processing feed update") + log_message( + "debug", + f"Found existing feed: {current_feed_id} with different URL", + ) + self.process_feed_update(payload, current_feed_id) + else: + log_message( + "error", + f"Feed already exists with same URL: {payload.external_id}", + ) + return + + self.session.commit() + log_message("debug", "Database transaction committed successfully") + + # Publish to dataset_batch_topic if not authenticated + if not payload.auth_info_url: + self.publish_to_batch_topic(payload) + + except SQLAlchemyError as e: + error_msg = ( + f"Database error processing feed {payload.external_id}: {str(e)}" + ) + log_message("error", error_msg) + self.session.rollback() + log_message("error", "Database transaction rolled back due to error") + raise + except 
Exception as e: + error_msg = f"Error processing feed {payload.external_id}: {str(e)}" + log_message("error", error_msg) + self.session.rollback() + log_message("error", "Database transaction rolled back due to error") + raise + + def process_new_feed(self, payload: FeedPayload) -> None: + """ + Process creation of a new feed + + Args: + payload (FeedPayload): The feed payload for new feed + """ + try: + log_message( + "info", + f"Starting new feed creation for external_id: {payload.external_id}", + ) + + # Check if feed with same URL exists + if self.check_feed_url_exists(payload.feed_url): + log_message("error", f"Feed URL already exists: {payload.feed_url}") + return + + # Generate new feed ID and stable ID + feed_id = str(uuid.uuid4()) + stable_id = f"{payload.source}-{payload.external_id}" + + log_message( + "debug", f"Generated new feed_id: {feed_id} and stable_id: {stable_id}" + ) + + try: + # Create new feed + new_feed = Feed( + id=feed_id, + data_type=payload.spec, + producer_url=payload.feed_url, + authentication_type=payload.type if payload.type else "0", + authentication_info_url=payload.auth_info_url, + api_key_parameter_name=payload.auth_param_name, + stable_id=stable_id, + status="active", + provider=payload.operator_name, + operational_status="wip", + ) + + # external ID mapping + external_id = Externalid( + feed_id=feed_id, + associated_id=payload.external_id, + source=payload.source, + ) + + # Add relationships + new_feed.externalids.append(external_id) + + # Create or get location + location = create_or_get_location( + self.session, + payload.country, + payload.state_province, + payload.city_name, + ) + + if location is not None: + new_feed.locations.append(location) + log_message( + "debug", f"Added location information for feed: {feed_id}" + ) + else: + log_message( + "debug", f"No location information to add for feed: {feed_id}" + ) + + self.session.add(new_feed) + self.session.flush() + + log_message("debug", f"Successfully created feed with ID: {feed_id}") + log_message( + "info", + f"Created new feed with ID: {feed_id} for external_id: {payload.external_id}", + ) + + except SQLAlchemyError as e: + self.session.rollback() + error_msg = f"Database error creating feed for external_id {payload.external_id}: {str(e)}" + log_message("error", error_msg) + raise + + except Exception as e: + error_msg = f"Database error creating feed for external_id {payload.external_id}: {str(e)}" + log_message("error", error_msg) + raise + + def process_feed_update(self, payload: FeedPayload, old_feed_id: str) -> None: + """ + Process feed update when URL has changed + + Args: + payload (FeedPayload): The feed payload for update + old_feed_id (str): The ID of the existing feed to be updated + """ + log_message( + "info", + f"Starting feed update process for external_id: {payload.external_id}", + ) + log_message("debug", f"Old feed_id: {old_feed_id}, New URL: {payload.feed_url}") + + try: + # Get count of existing references to this external ID + reference_count = ( + self.session.query(Feed) + .join(Externalid) + .filter( + Externalid.associated_id == payload.external_id, + Externalid.source == payload.source, + ) + .count() + ) + + # Create new feed with updated URL + new_feed_id = str(uuid.uuid4()) + # Added counter to stable_id + stable_id = ( + f"{payload.source}-{payload.external_id}" + if reference_count == 1 + else f"{payload.source}-{payload.external_id}_{reference_count}" + ) + + log_message( + "debug", + f"Generated new stable_id: {stable_id} (reference count: 
{reference_count})", + ) + + # Create new feed entry + new_feed = Feed( + id=new_feed_id, + data_type=payload.spec, + producer_url=payload.feed_url, + authentication_type=payload.type if payload.type else "0", + authentication_info_url=payload.auth_info_url, + api_key_parameter_name=payload.auth_param_name, + stable_id=stable_id, + status="active", + provider=payload.operator_name, + operational_status="wip", + ) + + # Add new feed to session + self.session.add(new_feed) + + # Update old feed status to deprecated + old_feed = self.session.get(Feed, old_feed_id) + if old_feed: + old_feed.status = "deprecated" + log_message("debug", f"Deprecating old feed ID: {old_feed_id}") + + # Create new external ID mapping for updated feed + new_external_id = Externalid( + feed_id=new_feed_id, + associated_id=payload.external_id, + source=payload.source, + ) + self.session.add(new_external_id) + log_message( + "debug", f"Created new external ID mapping for feed_id: {new_feed_id}" + ) + + # Create redirect + redirect = Redirectingid(source_id=old_feed_id, target_id=new_feed_id) + self.session.add(redirect) + log_message( + "debug", f"Created redirect from {old_feed_id} to {new_feed_id}" + ) + + # Create or get location and add to new feed + location = create_or_get_location( + self.session, payload.country, payload.state_province, payload.city_name + ) + + if location: + new_feed.locations.append(location) + log_message( + "debug", f"Added location information for feed: {new_feed_id}" + ) + + self.session.flush() + + log_message( + "info", + f"Updated feed for external_id: {payload.external_id}, new feed_id: {new_feed_id}", + ) + + except Exception as e: + log_message( + "error", + f"Error updating feed for external_id {payload.external_id}: {str(e)}", + ) + raise + + def check_feed_url_exists(self, feed_url: str) -> bool: + """ + Check if a feed with the given URL exists in any state (active or deprecated). + This check is used to prevent creating new feeds with URLs that are already in use. + + Args: + feed_url (str): The URL to check + + Returns: + bool: True if any feed with this URL exists (either active or deprecated), + preventing creation of new feeds with duplicate URLs + """ + results = self.session.query(Feed).filter(Feed.producer_url == feed_url).all() + + if results: + if len(results) > 1: + log_message("warning", f"Multiple feeds found with URL: {feed_url}") + return True + + result = results[0] + if result.status == "active": + log_message( + "info", f"Found existing feed with URL: {feed_url} (status: active)" + ) + return True + elif result.status == "deprecated": + log_message( + "error", + f"Feed URL {feed_url} exists in deprecated feed (id: {result.id}). 
" + "Cannot reuse URLs from deprecated feeds.", + ) + return True + + log_message("debug", f"No existing feed found with URL: {feed_url}") + return False + + def get_current_feed_info( + self, external_id: str, source: str + ) -> Tuple[Optional[str], Optional[str]]: + """ + Get current feed ID and URL for given external ID + + Args: + external_id (str): The external ID to look up + source (str): The source of the feed + + Returns: + Tuple[Optional[str], Optional[str]]: Tuple of (feed_id, feed_url) + """ + result = ( + self.session.query(Feed) + .filter(Feed.externalids.any(associated_id=external_id, source=source)) + .first() + ) + if result is not None: + log_message( + "info", + f"Retrieved feed {result.stable_id} " + f"info for external_id: {external_id} (status: {result.status})", + ) + return result.id, result.producer_url + log_message("info", f"No existing feed found for external_id: {external_id}") + return None, None + + def publish_to_batch_topic(self, payload: FeedPayload) -> None: + """ + Publish feed to dataset batch topic + + Args: + payload (FeedPayload): The feed payload to publish + """ + topic_path = self.publisher.topic_path(PROJECT_ID, DATASET_BATCH_TOPIC) + log_message("debug", f"Publishing to topic: {topic_path}") + + # Prepare message data in the expected format + message_data = { + "execution_id": payload.execution_id, + "producer_url": payload.feed_url, + "feed_stable_id": f"{payload.source}-{payload.external_id}", + "feed_id": payload.feed_id, + "dataset_id": None, + "dataset_hash": None, + "authentication_type": payload.type if payload.type else "0", + "authentication_info_url": payload.auth_info_url, + "api_key_parameter_name": payload.auth_param_name, + } + + try: + log_message("debug", f"Preparing to publish feed_id: {payload.feed_id}") + # Convert to JSON string and encode as base64 + json_str = json.dumps(message_data) + encoded_data = base64.b64encode(json_str.encode("utf-8")) + + future = self.publisher.publish(topic_path, data=encoded_data) + future.result() + log_message( + "info", f"Published feed {payload.feed_id} to dataset batch topic" + ) + except Exception as e: + error_msg = f"Error publishing to dataset batch topic: {str(e)}" + log_message("error", error_msg) + raise + + +@functions_framework.cloud_event +def process_feed_event(cloud_event): + """ + Cloud Function to process feed events from Pub/Sub + + Args: + cloud_event (CloudEvent): The cloud event + containing the Pub/Sub message + """ + try: + # Decode payload from Pub/Sub message + pubsub_message = base64.b64decode(cloud_event.data["message"]["data"]).decode() + message_data = json.loads(pubsub_message) + + payload = FeedPayload(**message_data) + + db_session = start_db_session(FEEDS_DATABASE_URL) + + try: + processor = FeedProcessor(db_session) + processor.process_feed(payload) + + log_message("info", f"Successfully processed feed: {payload.external_id}") + return "Success", 200 + + finally: + close_db_session(db_session) + + except Exception as e: + error_msg = f"Error processing feed event: {str(e)}" + log_message("error", error_msg) + return error_msg, 500 diff --git a/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py b/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py new file mode 100644 index 000000000..b4848ce56 --- /dev/null +++ b/functions-python/feed_sync_process_transitland/tests/test_feed_sync_process.py @@ -0,0 +1,839 @@ +import base64 +import json +import logging +import uuid +from unittest import mock +from 
unittest.mock import patch, Mock, MagicMock +import os + +import pytest +from google.api_core.exceptions import DeadlineExceeded +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session as DBSession + +from database_gen.sqlacodegen_models import Feed +from helpers.feed_sync.models import TransitFeedSyncPayload as FeedPayload + +with mock.patch("helpers.logger.Logger.init_logger") as mock_init_logger: + from feed_sync_process_transitland.src.main import ( + FeedProcessor, + process_feed_event, + log_message, + ) + +# Environment variables for tests +TEST_DB_URL = "postgresql://test:test@localhost:54320/test" + + +@pytest.fixture +def mock_feed(): + """Fixture for a Feed model instance""" + return Mock() + + +@pytest.fixture +def mock_external_id(): + """Fixture for an ExternalId model instance""" + return Mock() + + +@pytest.fixture +def mock_location(): + """Fixture for a Location model instance""" + return Mock() + + +class MockLogger: + """Mock logger for testing""" + + @staticmethod + def init_logger(): + return MagicMock() + + def __init__(self, name): + self.name = name + self._logger = logging.getLogger(name) + + def get_logger(self): + mock_logger = MagicMock() + # Add all required logging methods + mock_logger.info = MagicMock() + mock_logger.error = MagicMock() + mock_logger.warning = MagicMock() + mock_logger.debug = MagicMock() + mock_logger.addFilter = MagicMock() + return mock_logger + + +@pytest.fixture(autouse=True) +def mock_logging(): + """Mock both local and GCP logging.""" + with patch("feed_sync_process_transitland.src.main.logger") as mock_log, patch( + "feed_sync_process_transitland.src.main.gcp_logger" + ) as mock_gcp_log, patch("helpers.logger.Logger", MockLogger): + for logger in [mock_log, mock_gcp_log]: + logger.info = MagicMock() + logger.error = MagicMock() + logger.warning = MagicMock() + logger.debug = MagicMock() + logger.addFilter = MagicMock() + + yield mock_log + + +@pytest.fixture +def feed_payload(): + """Fixture for feed payload.""" + return FeedPayload( + external_id="test123", + feed_id="feed1", + feed_url="https://example.com/feed1", + execution_id="exec123", + spec="gtfs", + auth_info_url=None, + auth_param_name=None, + type=None, + operator_name="Test Operator", + country="United States", + state_province="CA", + city_name="Test City", + source="TLD", + payload_type="new", + ) + + +@mock.patch.dict( + "os.environ", + { + "FEEDS_DATABASE_URL": TEST_DB_URL, + "GOOGLE_APPLICATION_CREDENTIALS": "dummy-credentials.json", + }, +) +class TestFeedProcessor: + """Test suite for FeedProcessor.""" + + @pytest.fixture + def processor(self): + """Fixture for FeedProcessor with mocked dependencies.""" + # mock for the database session + mock_session = Mock(spec=DBSession) + + # Mock the PublisherClient + with patch("google.cloud.pubsub_v1.PublisherClient") as MockPublisherClient: + mock_publisher = MockPublisherClient.return_value + processor = FeedProcessor(mock_session) + processor.publisher = mock_publisher + mock_publisher.topic_path = Mock() + mock_publisher.publish = Mock() + + mock_query = Mock() + mock_filter = Mock() + mock_query.filter.return_value = mock_filter + mock_filter.first.return_value = None + mock_session.query.return_value = mock_query + + return processor + + @staticmethod + def _create_payload_dict(feed_payload: FeedPayload) -> dict: + """Helper method to create a payload dictionary from a FeedPayload object.""" + return { + "external_id": feed_payload.external_id, + "feed_id": feed_payload.feed_id, + "feed_url": 
feed_payload.feed_url, + "execution_id": feed_payload.execution_id, + "spec": feed_payload.spec, + "auth_info_url": feed_payload.auth_info_url, + "auth_param_name": feed_payload.auth_param_name, + "type": feed_payload.type, + "operator_name": feed_payload.operator_name, + "country": feed_payload.country, + "state_province": feed_payload.state_province, + "city_name": feed_payload.city_name, + "source": feed_payload.source, + "payload_type": feed_payload.payload_type, + } + + def test_get_current_feed_info(self, processor, feed_payload, mock_logging): + """Test retrieving current feed information.""" + # Mock database query + processor.session.query.return_value.filter.return_value.first.return_value = ( + Mock( + id="feed-uuid", + producer_url="https://example.com/feed", + stable_id="TLD-test123", + status="active", + ) + ) + + feed_id, url = processor.get_current_feed_info( + feed_payload.external_id, feed_payload.source + ) + + # Assertions + assert feed_id == "feed-uuid" + assert url == "https://example.com/feed" + mock_logging.info.assert_called_with( + "Retrieved feed TLD-test123 " + f"info for external_id: {feed_payload.external_id} (status: active)" + ) + + # Test case when feed does not exist + processor.session.query.return_value.filter.return_value.first.return_value = ( + None + ) + feed_id, url = processor.get_current_feed_info( + feed_payload.external_id, feed_payload.source + ) + + assert feed_id is None + assert url is None + mock_logging.info.assert_called_with( + f"No existing feed found for external_id: {feed_payload.external_id}" + ) + + def test_check_feed_url_exists_comprehensive(self, processor, mock_logging): + """Test comprehensive feed URL existence checks.""" + test_url = "https://example.com/feed" + + # Test case 1: Active feed exists + mock_feed = Mock(id="test-id", status="active") + processor.session.query.return_value.filter.return_value.all.return_value = [ + mock_feed + ] + + result = processor.check_feed_url_exists(test_url) + assert result is True + mock_logging.info.assert_called_with( + f"Found existing feed with URL: {test_url} (status: active)" + ) + + # Test case 2: Deprecated feed exists + mock_logging.info.reset_mock() + mock_feed.status = "deprecated" + result = processor.check_feed_url_exists(test_url) + assert result is True + mock_logging.error.assert_called_with( + f"Feed URL {test_url} exists in deprecated feed (id: {mock_feed.id}). " + "Cannot reuse URLs from deprecated feeds." 
+ ) + + # Test case 3: No feed exists + mock_logging.error.reset_mock() + processor.session.query.return_value.filter.return_value.all.return_value = [] + result = processor.check_feed_url_exists(test_url) + assert result is False + mock_logging.debug.assert_called_with( + f"No existing feed found with URL: {test_url}" + ) + + # Test case 4: Multiple feeds with same URL + mock_logging.debug.reset_mock() + mock_feeds = [ + Mock(id="feed1", status="active"), + Mock(id="feed2", status="deprecated"), + ] + processor.session.query.return_value.filter.return_value.all.return_value = ( + mock_feeds + ) + result = processor.check_feed_url_exists(test_url) + assert result is True + mock_logging.warning.assert_called_with( + f"Multiple feeds found with URL: {test_url}" + ) + + def test_log_message_function(self, mock_logging): + """Test the log_message function for different log levels.""" + levels = ["info", "error", "warning", "debug"] + messages = ["Info message", "Error message", "Warning message", "Debug message"] + + for level, message in zip(levels, messages): + log_message(level, message) + + if level == "info": + mock_logging.info.assert_called_with(message) + elif level == "error": + mock_logging.error.assert_called_with(message) + elif level == "warning": + mock_logging.warning.assert_called_with(message) + elif level == "debug": + mock_logging.debug.assert_called_with(message) + + def test_database_error_handling(self, processor, feed_payload, mock_logging): + """Test database error handling in different scenarios.""" + + # Test case 1: General database error during feed processing + processor.session.query.side_effect = SQLAlchemyError("Database error") + with pytest.raises(SQLAlchemyError, match="Database error"): + processor.process_feed(feed_payload) + + processor.session.rollback.assert_called_once() + mock_logging.error.assert_called_with( + "Database transaction rolled back due to error" + ) + + # Reset mocks for next test + processor.session.rollback.reset_mock() + mock_logging.error.reset_mock() + + # Test case 2: Connection failure during feed processing + processor.session.query.side_effect = SQLAlchemyError("Connection refused") + + with pytest.raises(SQLAlchemyError, match="Connection refused"): + processor.process_feed(feed_payload) + + processor.session.rollback.assert_called_once() + mock_logging.error.assert_called_with( + "Database transaction rolled back due to error" + ) + + def test_publish_to_batch_topic_comprehensive( + self, processor, feed_payload, mock_logging + ): + """Test publishing to batch topic including success, error, and message format validation.""" + + # Test case 1: Successful publish with message format validation + processor.publisher.topic_path.return_value = "test_topic" + mock_future = Mock() + processor.publisher.publish.return_value = mock_future + + processor.publish_to_batch_topic(feed_payload) + + # Verify publish was called and message format + call_args = processor.publisher.publish.call_args + assert call_args is not None + _, kwargs = call_args + + # Decode and verify message content + message_data = json.loads(base64.b64decode(kwargs["data"]).decode("utf-8")) + assert message_data["execution_id"] == feed_payload.execution_id + assert message_data["producer_url"] == feed_payload.feed_url + assert ( + message_data["feed_stable_id"] + == f"{feed_payload.source}-{feed_payload.external_id}" + ) + + mock_logging.info.assert_called_with( + f"Published feed {feed_payload.feed_id} to dataset batch topic" + ) + + # Test case 2: Publish error + 
processor.publisher.publish.side_effect = Exception("Pub/Sub error") + + with pytest.raises(Exception, match="Pub/Sub error"): + processor.publish_to_batch_topic(feed_payload) + + mock_logging.error.assert_called_with( + "Error publishing to dataset batch topic: Pub/Sub error" + ) + + # Test case 3: Timeout error + processor.publisher.publish.side_effect = DeadlineExceeded("Timeout error") + + with pytest.raises(DeadlineExceeded, match="Timeout error"): + processor.publish_to_batch_topic(feed_payload) + + mock_logging.error.assert_called_with( + "Error publishing to dataset batch topic: 504 Timeout error" + ) + + def test_process_feed_event_validation(self, mock_logging): + """Test feed event processing with various invalid payloads.""" + + # Test case 1: Empty payload + empty_payload_data = base64.b64encode(json.dumps({}).encode("utf-8")).decode() + cloud_event = Mock() + cloud_event.data = {"message": {"data": empty_payload_data}} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: TransitFeedSyncPayload.__init__() missing 14 " + "required positional arguments: 'external_id', 'feed_id', 'feed_url', " + "'execution_id', 'spec', 'auth_info_url', 'auth_param_name', 'type', " + "'operator_name', 'country', 'state_province', 'city_name', 'source', and " + "'payload_type'" + ) + + # Test case 2: Invalid field + mock_logging.error.reset_mock() + invalid_payload_data = base64.b64encode( + json.dumps({"invalid": "data"}).encode("utf-8") + ).decode() + cloud_event.data = {"message": {"data": invalid_payload_data}} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: TransitFeedSyncPayload.__init__() got an " + "unexpected keyword argument 'invalid'" + ) + + # Test case 3: Type error + mock_logging.error.reset_mock() + type_error_payload = {"external_id": 12345, "feed_url": True, "feed_id": None} + payload_data = base64.b64encode( + json.dumps(type_error_payload).encode("utf-8") + ).decode() + cloud_event.data = {"message": {"data": payload_data}} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: TransitFeedSyncPayload.__init__() missing 11 " + "required positional arguments: 'execution_id', 'spec', 'auth_info_url', " + "'auth_param_name', 'type', 'operator_name', 'country', 'state_province', " + "'city_name', 'source', and 'payload_type'" + ) + + def test_process_new_feed_with_location( + self, processor, feed_payload, mock_logging + ): + """Test creating a new feed with location information.""" + # Mock UUID generation + new_feed_id = str(uuid.uuid4()) + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)): + # Mock Location class + mock_location_cls = Mock(name="Location") + mock_location = mock_location_cls.return_value + mock_location.id = "US-CA-Test City" + mock_location.country_code = "US" + mock_location.country = "United States" + mock_location.subdivision_name = "CA" + mock_location.municipality = "Test City" + mock_location.__eq__ = ( + lambda self, other: isinstance(other, Mock) and self.id == other.id + ) + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id 
+ self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_feed = MockFeed() + + with patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_feed + ), patch( + "database_gen.sqlacodegen_models.Location", mock_location_cls + ), patch( + "helpers.locations.create_or_get_location", return_value=mock_location + ): + processor.process_new_feed(feed_payload) + + # Verify feed creation + created_feed = processor.session.add.call_args[0][0] + assert created_feed.id == new_feed_id + assert created_feed.producer_url == feed_payload.feed_url + assert created_feed.data_type == feed_payload.spec + assert created_feed.provider == feed_payload.operator_name + + # Verify location was added to feed + assert len(created_feed.locations) == 1 + assert created_feed.locations[0].id == "US-CA-Test City" + assert created_feed.locations[0].country_code == "US" + assert created_feed.locations[0].country == "United States" + assert created_feed.locations[0].subdivision_name == "CA" + assert created_feed.locations[0].municipality == "Test City" + mock_logging.debug.assert_any_call( + f"Added location information for feed: {new_feed_id}" + ) + + def test_process_new_feed_without_location( + self, processor, feed_payload, mock_logging + ): + """Test creating a new feed without location information.""" + # Modify payload to have no location info + feed_payload.country = None + feed_payload.state_province = None + feed_payload.city_name = None + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + # Mock UUID generation + new_feed_id = str(uuid.uuid4()) + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_feed = MockFeed() + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_feed + ), patch("helpers.locations.create_or_get_location", return_value=None): + processor.process_new_feed(feed_payload) + + # Verify feed creation + created_feed = processor.session.add.call_args[0][0] + assert created_feed.id == new_feed_id + assert not created_feed.locations + + def test_process_feed_update_with_location( + self, processor, feed_payload, mock_logging + ): + """Test updating a feed with location information.""" + old_feed_id = str(uuid.uuid4()) + new_feed_id = str(uuid.uuid4()) + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + # Mock old feed + mock_old_feed = Mock(id=old_feed_id, status="active") + processor.session.get.return_value = mock_old_feed + + # Mock Location class + mock_location_cls = Mock(name="Location") + mock_location = mock_location_cls.return_value + mock_location.id = "US-CA-Test City" + mock_location.country_code = "US" + mock_location.country = "United States" + mock_location.subdivision_name = "CA" + mock_location.municipality = "Test City" + mock_location.__eq__ = ( + lambda self, other: isinstance(other, Mock) and self.id == other.id + ) 
+ + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_new_feed = MockFeed() + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_new_feed + ), patch("database_gen.sqlacodegen_models.Location", mock_location_cls), patch( + "helpers.locations.create_or_get_location", return_value=mock_location + ): + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify feed update + assert mock_old_feed.status == "deprecated" + + # Find the Feed object in the add calls + feed_add_call = None + for call in processor.session.add.call_args_list: + obj = call[0][0] + if hasattr(obj, "locations"): # This is our Feed object + feed_add_call = call + break + + assert ( + feed_add_call is not None + ), "Feed object not found in session.add calls" + created_feed = feed_add_call[0][0] + + # Verify new feed creation with location + assert len(created_feed.locations) == 1 + assert created_feed.locations[0].id == "US-CA-Test City" + assert created_feed.locations[0].country_code == "US" + assert created_feed.locations[0].country == "United States" + assert created_feed.locations[0].subdivision_name == "CA" + assert created_feed.locations[0].municipality == "Test City" + mock_logging.debug.assert_any_call( + f"Added location information for feed: {new_feed_id}" + ) + + def test_process_feed_update_without_location( + self, processor, feed_payload, mock_logging + ): + """Test updating a feed without location information.""" + old_feed_id = str(uuid.uuid4()) + new_feed_id = str(uuid.uuid4()) + + # Mock database query to return no existing feeds + processor.session.query.return_value.filter.return_value.all.return_value = [] + + # Modify payload to have no location info + feed_payload.country = None + feed_payload.state_province = None + feed_payload.city_name = None + + # Mock old feed + mock_old_feed = Mock(id=old_feed_id, status="active") + processor.session.get.return_value = mock_old_feed + + # Create a Feed class with a real list for locations + class MockFeed: + def __init__(self): + self.locations = [] + self.externalids = [] + self.id = new_feed_id + self.producer_url = feed_payload.feed_url + self.data_type = feed_payload.spec + self.provider = feed_payload.operator_name + self.status = "active" + self.stable_id = f"{feed_payload.source}-{feed_payload.external_id}" + + mock_new_feed = MockFeed() + + with patch("uuid.uuid4", return_value=uuid.UUID(new_feed_id)), patch( + "database_gen.sqlacodegen_models.Feed", return_value=mock_new_feed + ), patch("helpers.locations.create_or_get_location", return_value=None): + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify feed update + assert mock_old_feed.status == "deprecated" + + # Verify new feed creation without location + assert not mock_new_feed.locations + + def test_process_feed_event_database_connection_error( + self, processor, feed_payload, mock_logging + ): + """Test feed event processing with database connection error.""" + # Create cloud event with valid payload + payload_dict = self._create_payload_dict(feed_payload) + payload_data = base64.b64encode( + json.dumps(payload_dict).encode("utf-8") + ).decode() + 
cloud_event = Mock() + cloud_event.data = {"message": {"data": payload_data}} + + # Mock database session to raise error + with patch( + "feed_sync_process_transitland.src.main.start_db_session" + ) as mock_start_session: + mock_start_session.side_effect = SQLAlchemyError( + "Database connection error" + ) + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: Database connection error" + ) + + def test_process_feed_event_pubsub_error( + self, processor, feed_payload, mock_logging + ): + """Test feed event processing handles missing credentials error.""" + # Create cloud event with valid payload + payload_dict = self._create_payload_dict(feed_payload) + payload_data = base64.b64encode( + json.dumps(payload_dict).encode("utf-8") + ).decode() + + # Create cloud event mock with minimal required structure + cloud_event = Mock() + cloud_event.data = {"message": {"data": payload_data}} + + # Mock database session with minimal setup + mock_session = Mock() + mock_session.query.return_value.filter.return_value.all.return_value = [] + + # Process event and verify error handling + with patch( + "feed_sync_process_transitland.src.main.start_db_session", + return_value=mock_session, + ): + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with( + "Error processing feed event: File dummy-credentials.json was not found." + ) + + def test_process_feed_event_malformed_cloud_event(self, mock_logging): + """Test feed event processing with malformed cloud event.""" + # Test case 1: Missing message data + cloud_event = Mock() + cloud_event.data = {} + + result = process_feed_event(cloud_event) + assert result[1] == 500 + mock_logging.error.assert_called_with("Error processing feed event: 'message'") + + # Test case 2: Invalid base64 data + mock_logging.error.reset_mock() + cloud_event.data = {"message": {"data": "invalid-base64"}} + + result = process_feed_event(cloud_event) + error_msg = ( + "Error processing feed event: Invalid base64-encoded string: " + "number of data characters (13) cannot be 1 more than a multiple of 4" + ) + mock_logging.error.assert_called_with(error_msg) + + def test_publish_to_batch_topic(self, processor, feed_payload, mock_logging): + """Test publishing feed to batch topic.""" + # Mock the topic path + topic_path = "projects/test-project/topics/test-topic" + processor.publisher.topic_path.return_value = topic_path + + # Mock the publish future + mock_future = Mock() + mock_future.result.return_value = "message_id" + processor.publisher.publish.return_value = mock_future + + # Call the method + processor.publish_to_batch_topic(feed_payload) + + # Verify topic path was created correctly + processor.publisher.topic_path.assert_called_once_with( + os.getenv("PROJECT_ID"), os.getenv("DATASET_BATCH_TOPIC") + ) + + # Expected message data + expected_data = { + "execution_id": feed_payload.execution_id, + "producer_url": feed_payload.feed_url, + "feed_stable_id": f"{feed_payload.source}-{feed_payload.external_id}", + "feed_id": feed_payload.feed_id, + "dataset_id": None, + "dataset_hash": None, + "authentication_type": "0", # default value when type is None + "authentication_info_url": feed_payload.auth_info_url, + "api_key_parameter_name": feed_payload.auth_param_name, + } + + # Verify publish was called with correct data + encoded_data = base64.b64encode(json.dumps(expected_data).encode("utf-8")) + 
processor.publisher.publish.assert_called_once_with( + topic_path, data=encoded_data + ) + + # Verify success was logged + mock_logging.info.assert_called_with( + f"Published feed {feed_payload.feed_id} to dataset batch topic" + ) + + def test_publish_to_batch_topic_error(self, processor, feed_payload, mock_logging): + """Test error handling when publishing to batch topic fails.""" + # Mock the topic path + topic_path = "projects/test-project/topics/test-topic" + processor.publisher.topic_path.return_value = topic_path + + # Mock publish to raise an error + error_msg = "Failed to publish" + processor.publisher.publish.side_effect = Exception(error_msg) + + # Call the method and verify it raises the error + with pytest.raises(Exception) as exc_info: + processor.publish_to_batch_topic(feed_payload) + + assert str(exc_info.value) == error_msg + + # Verify error was logged + mock_logging.error.assert_called_with( + f"Error publishing to dataset batch topic: {error_msg}" + ) + + def test_process_feed_update_with_multiple_references( + self, processor, feed_payload, mock_logging + ): + """Test updating feed with multiple external ID references""" + old_feed_id = "old-feed-uuid" + + # Mock multiple references to the external ID + processor.session.query.return_value.join.return_value.filter.return_value.count.return_value = ( + 3 + ) + + # Mock getting old feed + mock_old_feed = Mock(spec=Feed) + processor.session.get.return_value = mock_old_feed + + # Process the update + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify stable_id includes reference count + expected_stable_id = f"{feed_payload.source}-{feed_payload.external_id}_3" + mock_logging.debug.assert_any_call( + f"Generated new stable_id: {expected_stable_id} (reference count: 3)" + ) + + # Verify old feed was deprecated + assert mock_old_feed.status == "deprecated" + + def test_process_feed_with_auth_info(self, processor, feed_payload, mock_logging): + """Test processing feed with authentication info""" + # Modify payload to include auth info + feed_payload.auth_info_url = "https://auth.example.com" + feed_payload.type = "oauth2" + feed_payload.auth_param_name = "access_token" + + # Mock the methods + with patch.object( + processor, "get_current_feed_info", return_value=(None, None) + ), patch.object( + processor, "check_feed_url_exists", return_value=False + ), patch.object( + processor, "process_new_feed" + ) as mock_process_new_feed: + # Process the feed + processor.process_feed(feed_payload) + + # Verify feed was processed + mock_process_new_feed.assert_called_once_with(feed_payload) + mock_logging.debug.assert_any_call( + "Database transaction committed successfully" + ) + + # Verify not published to batch topic (because auth_info_url is set) + processor.publisher.publish.assert_not_called() + + def test_process_feed_event_invalid_json(self, mock_logging): + """Test handling of invalid JSON in cloud event""" + # Create invalid base64 encoded JSON + invalid_json = base64.b64encode(b'{"invalid": "json"').decode() + + cloud_event = Mock() + cloud_event.data = {"message": {"data": invalid_json}} + + # Process the event + result, status_code = process_feed_event(cloud_event) + + # Verify error handling + assert status_code == 500 + assert "Error processing feed event" in result + mock_logging.error.assert_called() + + def test_process_feed_update_without_old_feed( + self, processor, feed_payload, mock_logging + ): + """Test feed update when old feed is not found""" + old_feed_id = "non-existent-feed" + + # Mock old 
feed not found + processor.session.get.return_value = None + + # Process the update + processor.process_feed_update(feed_payload, old_feed_id) + + # Verify processing continued without error + mock_logging.debug.assert_any_call( + f"Old feed_id: {old_feed_id}, New URL: {feed_payload.feed_url}" + ) + + # Verify no deprecation log since old feed wasn't found + deprecation_log = f"Deprecating old feed ID: {old_feed_id}" + assert mock.call(deprecation_log) not in mock_logging.debug.call_args_list diff --git a/functions-python/helpers/feed_sync/models.py b/functions-python/helpers/feed_sync/models.py new file mode 100644 index 000000000..54f769dec --- /dev/null +++ b/functions-python/helpers/feed_sync/models.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class TransitFeedSyncPayload: + """Data class for transit feed processing payload""" + + external_id: str + feed_id: str + feed_url: str + execution_id: Optional[str] + spec: str + auth_info_url: Optional[str] + auth_param_name: Optional[str] + type: Optional[str] + operator_name: Optional[str] + country: Optional[str] + state_province: Optional[str] + city_name: Optional[str] + source: str + payload_type: str diff --git a/functions-python/helpers/locations.py b/functions-python/helpers/locations.py index 9042b67b5..e73cebcc1 100644 --- a/functions-python/helpers/locations.py +++ b/functions-python/helpers/locations.py @@ -1,13 +1,109 @@ -from typing import Dict +from typing import Dict, Optional +from sqlalchemy.orm import Session +import pycountry +from database_gen.sqlacodegen_models import Feed, Location +import logging -from database_gen.sqlacodegen_models import Feed + +def get_country_code(country_name: str) -> Optional[str]: + """ + Get ISO 3166 country code from country name + + Args: + country_name (str): Full country name + + Returns: + Optional[str]: Two-letter ISO country code or None if not found + """ + # Return None for empty or whitespace-only strings + if not country_name or not country_name.strip(): + logging.error("Could not find country code for: empty string") + return None + + try: + # Try exact match first + country = pycountry.countries.get(name=country_name) + if country: + return country.alpha_2 + + # Try searching by name + countries = pycountry.countries.search_fuzzy(country_name) + if countries: + return countries[0].alpha_2 + + except LookupError: + logging.error(f"Could not find country code for: {country_name}") + return None + + +def create_or_get_location( + session: Session, + country: Optional[str], + state_province: Optional[str], + city_name: Optional[str], +) -> Optional[Location]: + """ + Create a new location or get existing one + + Args: + session: Database session + country: Country name + state_province: State/province name + city_name: City name + + Returns: + Optional[Location]: Location object or None if creation failed + """ + if not any([country, state_province, city_name]): + return None + + # Generate location_id using the specified pattern + location_components = [] + if country: + country_code = get_country_code(country) + if country_code: + location_components.append(country_code) + else: + logging.error(f"Could not determine country code for {country}") + return None + + if state_province: + location_components.append(state_province) + if city_name: + location_components.append(city_name) + + location_id = "-".join(location_components) + + # First check if location already exists + existing_location = ( + 
diff --git a/functions-python/helpers/locations.py b/functions-python/helpers/locations.py
index 9042b67b5..e73cebcc1 100644
--- a/functions-python/helpers/locations.py
+++ b/functions-python/helpers/locations.py
@@ -1,13 +1,110 @@
-from typing import Dict
+from typing import Dict, Optional
+from sqlalchemy.orm import Session
+import pycountry
+from database_gen.sqlacodegen_models import Feed, Location
+import logging
 
-from database_gen.sqlacodegen_models import Feed
+
+def get_country_code(country_name: str) -> Optional[str]:
+    """
+    Get ISO 3166 country code from country name
+
+    Args:
+        country_name (str): Full country name
+
+    Returns:
+        Optional[str]: Two-letter ISO country code or None if not found
+    """
+    # Return None for empty or whitespace-only strings
+    if not country_name or not country_name.strip():
+        logging.error("Could not find country code for: empty string")
+        return None
+
+    try:
+        # Try exact match first
+        country = pycountry.countries.get(name=country_name)
+        if country:
+            return country.alpha_2
+
+        # Try searching by name
+        countries = pycountry.countries.search_fuzzy(country_name)
+        if countries:
+            return countries[0].alpha_2
+
+    except LookupError:
+        logging.error(f"Could not find country code for: {country_name}")
+    return None
+
+
+def create_or_get_location(
+    session: Session,
+    country: Optional[str],
+    state_province: Optional[str],
+    city_name: Optional[str],
+) -> Optional[Location]:
+    """
+    Create a new location or get an existing one
+
+    Args:
+        session: Database session
+        country: Country name
+        state_province: State/province name
+        city_name: City name
+
+    Returns:
+        Optional[Location]: Location object or None if creation failed
+    """
+    if not any([country, state_province, city_name]):
+        return None
+
+    # Generate location_id using the specified pattern
+    location_components = []
+    country_code = None  # Keep country_code defined when no country is given
+    if country:
+        country_code = get_country_code(country)
+        if country_code:
+            location_components.append(country_code)
+        else:
+            logging.error(f"Could not determine country code for {country}")
+            return None
+
+    if state_province:
+        location_components.append(state_province)
+    if city_name:
+        location_components.append(city_name)
+
+    location_id = "-".join(location_components)
+
+    # First check if location already exists
+    existing_location = (
+        session.query(Location).filter(Location.id == location_id).first()
+    )
+
+    if existing_location:
+        logging.debug(f"Using existing location: {location_id}")
+        return existing_location
+
+    # Create new location
+    location = Location(
+        id=location_id,
+        country_code=country_code,
+        country=country,
+        subdivision_name=state_province,
+        municipality=city_name,
+    )
+    session.add(location)
+    logging.debug(f"Created new location: {location_id}")
+
+    return location
 
 
 def translate_feed_locations(feed: Feed, location_translations: Dict):
     """
     Translate the locations of a feed.
-    :param feed: The feed object
-    :param location_translations: The location translations
+
+    Args:
+        feed: The feed object
+        location_translations: The location translations
     """
     for location in feed.locations:
         location_translation = location_translations.get(location.id)
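For orientation, here is how the two new helpers compose. This is a sketch under assumptions, not part of the diff: the connection string and session setup are placeholders, and any SQLAlchemy session bound to the feeds database would work.

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from helpers.locations import create_or_get_location, get_country_code

# Placeholder connection string for a local feeds database.
engine = create_engine("postgresql://postgres:postgres@localhost:5432/feeds")
session = sessionmaker(bind=engine)()

assert get_country_code("United States") == "US"

# Returns the existing "US-California-San Francisco" row if present;
# otherwise stages a new Location on the session (the caller commits).
location = create_or_get_location(
    session,
    country="United States",
    state_province="California",
    city_name="San Francisco",
)
if location is not None:
    session.commit()
```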
States", + state_province="California", + city_name="San Francisco", + ) + + self.assertIsNotNone(result) + self.assertEqual(result.id, "US-California-San Francisco") + self.assertEqual(result.country_code, "US") + self.assertEqual(result.country, "United States") + self.assertEqual(result.subdivision_name, "California") + self.assertEqual(result.municipality, "San Francisco") + self.session.add.assert_called_once() + def test_create_or_get_location_no_inputs(self): + """Test with no location information provided.""" + result = create_or_get_location( + self.session, country=None, state_province=None, city_name=None + ) + self.assertIsNone(result) + + def test_create_or_get_location_invalid_country(self): + """Test with invalid country name.""" + result = create_or_get_location( + self.session, + country="Invalid Country", + state_province="State", + city_name="City", + ) + self.assertIsNone(result) -class TestTranslateFeedLocations(unittest.TestCase): def test_translate_feed_locations(self): - # Mock a location object with specific attributes + """Test translating feed locations with all translations available.""" mock_location = MagicMock(spec=Location) mock_location.id = 1 mock_location.subdivision_name = "Original Subdivision" mock_location.municipality = "Original Municipality" mock_location.country = "Original Country" - # Mock a feed object with locations mock_feed = MagicMock(spec=Feed) mock_feed.locations = [mock_location] - # Define a translation dictionary location_translations = { 1: { "subdivision_name_translation": "Translated Subdivision", @@ -26,27 +110,23 @@ def test_translate_feed_locations(self): } } - # Call the translate_feed_locations function translate_feed_locations(mock_feed, location_translations) - # Assert that the location's attributes were updated with translations self.assertEqual(mock_location.subdivision_name, "Translated Subdivision") self.assertEqual(mock_location.municipality, "Translated Municipality") self.assertEqual(mock_location.country, "Translated Country") def test_translate_feed_locations_with_missing_translations(self): - # Mock a location object with specific attributes + """Test translating feed locations with some missing translations.""" mock_location = MagicMock(spec=Location) mock_location.id = 1 mock_location.subdivision_name = "Original Subdivision" mock_location.municipality = "Original Municipality" mock_location.country = "Original Country" - # Mock a feed object with locations mock_feed = MagicMock(spec=Feed) mock_feed.locations = [mock_location] - # Define a translation dictionary with missing translations location_translations = { 1: { "subdivision_name_translation": None, @@ -55,37 +135,145 @@ def test_translate_feed_locations_with_missing_translations(self): } } - # Call the translate_feed_locations function translate_feed_locations(mock_feed, location_translations) - # Assert that the location's attributes were updated correctly - self.assertEqual( - mock_location.subdivision_name, "Original Subdivision" - ) # No translation - self.assertEqual( - mock_location.municipality, "Original Municipality" - ) # No translation - self.assertEqual(mock_location.country, "Translated Country") # Translated + self.assertEqual(mock_location.subdivision_name, "Original Subdivision") + self.assertEqual(mock_location.municipality, "Original Municipality") + self.assertEqual(mock_location.country, "Translated Country") def test_translate_feed_locations_with_no_translation(self): - # Mock a location object with specific attributes + """Test 
     def test_translate_feed_locations_with_no_translation(self):
-        # Mock a location object with specific attributes
+        """Test translating feed locations with no translations available."""
         mock_location = MagicMock(spec=Location)
         mock_location.id = 1
         mock_location.subdivision_name = "Original Subdivision"
         mock_location.municipality = "Original Municipality"
         mock_location.country = "Original Country"
 
-        # Mock a feed object with locations
         mock_feed = MagicMock(spec=Feed)
         mock_feed.locations = [mock_location]
 
-        # Define an empty translation dictionary
         location_translations = {}
 
-        # Call the translate_feed_locations function
         translate_feed_locations(mock_feed, location_translations)
 
-        # Assert that the location's attributes remain unchanged
         self.assertEqual(mock_location.subdivision_name, "Original Subdivision")
         self.assertEqual(mock_location.municipality, "Original Municipality")
         self.assertEqual(mock_location.country, "Original Country")
+
+    def test_get_country_code_fuzzy_match_partial(self):
+        """Test getting country code with partial name matches."""
+        # Test partial name matches
+        self.assertEqual(get_country_code("United"), "US")  # Should match United States
+        self.assertEqual(get_country_code("South Korea"), "KR")  # Republic of Korea
+        self.assertEqual(
+            get_country_code("North Korea"), "KP"
+        )  # Democratic People's Republic of Korea
+        self.assertEqual(
+            get_country_code("Great Britain"), "GB"
+        )  # Should match United Kingdom
+
+    @patch("helpers.locations.logging.error")
+    def test_get_country_code_empty_string(self, mock_logging):
+        """Test getting country code with empty string."""
+        self.assertIsNone(get_country_code(""))
+        mock_logging.assert_called_with("Could not find country code for: empty string")
+
+    def test_create_or_get_location_partial_info(self):
+        """Test creating location with partial information."""
+        self.session.query.return_value.filter.return_value.first.return_value = None
+
+        # Test with only country
+        result = create_or_get_location(
+            self.session, country="United States", state_province=None, city_name=None
+        )
+        self.assertEqual(result.id, "US")
+        self.assertEqual(result.country_code, "US")
+        self.assertEqual(result.country, "United States")
+        self.assertIsNone(result.subdivision_name)
+        self.assertIsNone(result.municipality)
+
+        # Test with country and state
+        result = create_or_get_location(
+            self.session,
+            country="United States",
+            state_province="California",
+            city_name=None,
+        )
+        self.assertEqual(result.id, "US-California")
+        self.assertEqual(result.country_code, "US")
+        self.assertEqual(result.country, "United States")
+        self.assertEqual(result.subdivision_name, "California")
+        self.assertIsNone(result.municipality)
+
+    def test_translate_feed_locations_partial_translations(self):
+        """Test translating feed locations with partial translations."""
+        mock_location = MagicMock(spec=Location)
+        mock_location.id = "loc1"
+        mock_location.subdivision_name = "Original State"
+        mock_location.municipality = "Original City"
+        mock_location.country = "Original Country"
+
+        mock_feed = MagicMock(spec=Feed)
+        mock_feed.locations = [mock_location]
+
+        # Test with only some fields translated
+        translations = {
+            "loc1": {
+                "subdivision_name_translation": "Translated State",
+                "municipality_translation": None,  # No translation
+                "country_translation": "Translated Country",
+            }
+        }
+
+        translate_feed_locations(mock_feed, translations)
+
+        # Verify partial translations
+        self.assertEqual(mock_location.subdivision_name, "Translated State")
+        self.assertEqual(
+            mock_location.municipality, "Original City"
+        )  # Should remain unchanged
+        self.assertEqual(mock_location.country, "Translated Country")
Country") + + def test_translate_feed_locations_multiple_locations(self): + """Test translating multiple locations in a feed""" + # Create multiple mock locations + mock_location1 = MagicMock(spec=Location) + mock_location1.id = "loc1" + mock_location1.subdivision_name = "Original State 1" + mock_location1.municipality = "Original City 1" + mock_location1.country = "Original Country 1" + + mock_location2 = MagicMock(spec=Location) + mock_location2.id = "loc2" + mock_location2.subdivision_name = "Original State 2" + mock_location2.municipality = "Original City 2" + mock_location2.country = "Original Country 2" + + mock_feed = MagicMock(spec=Feed) + mock_feed.locations = [mock_location1, mock_location2] + + # Translations for both locations + translations = { + "loc1": { + "subdivision_name_translation": "Translated State 1", + "municipality_translation": "Translated City 1", + "country_translation": "Translated Country 1", + }, + "loc2": { + "subdivision_name_translation": "Translated State 2", + "municipality_translation": "Translated City 2", + "country_translation": "Translated Country 2", + }, + } + + translate_feed_locations(mock_feed, translations) + + # Verify translations for first location + self.assertEqual(mock_location1.subdivision_name, "Translated State 1") + self.assertEqual(mock_location1.municipality, "Translated City 1") + self.assertEqual(mock_location1.country, "Translated Country 1") + + # Verify translations for second location + self.assertEqual(mock_location2.subdivision_name, "Translated State 2") + self.assertEqual(mock_location2.municipality, "Translated City 2") + self.assertEqual(mock_location2.country, "Translated Country 2") diff --git a/functions-python/preprocessed_analytics/requirements.txt b/functions-python/preprocessed_analytics/requirements.txt index a07655518..9ec5ce9fb 100644 --- a/functions-python/preprocessed_analytics/requirements.txt +++ b/functions-python/preprocessed_analytics/requirements.txt @@ -19,4 +19,5 @@ google-cloud-bigquery google-cloud-storage # Additional packages for this function -pandas \ No newline at end of file +pandas +pycountry \ No newline at end of file